
### Description This PR implements a minimal version of function signature conversion to unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4 depending on the original dimension). This PR also includes new unit tests that only check for function signature conversion. ### Future Plans - Check for capabilities that support vectors of size 8 or 16. - Set up `OneToNTypeConversion` and `DialectConversion` to replace the current implementation that uses `GreedyPatternRewriteDriver`. - Introduce other vector unrolling patterns to cancel out the `vector.insert_strided_slice` and `vector.extract_strided_slice` ops and fully legalize the vector types in the function body. - Handle `func::CallOp` and declarations. - Restructure the code in `SPIRVConversion.cpp`. - Create test passes for testing sets of patterns in isolation. - Optimize the way original shape is splitted into target shapes, e.g. `vector<5xi32>` can be splitted into `vector<4xi32>` and `vector<1xi32>`. --------- Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
82 lines
3.2 KiB
C++
82 lines
3.2 KiB
C++
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
|
|
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
|
|
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
|
|
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
|
|
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
|
|
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
|
|
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
|
|
#include "mlir/Dialect/Arith/Transforms/Passes.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
|
|
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include <memory>
|
|
|
|
#define DEBUG_TYPE "convert-to-spirv"
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
|
|
#include "mlir/Conversion/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
|
|
namespace {
|
|
|
|
/// A pass to perform the SPIR-V conversion.
|
|
struct ConvertToSPIRVPass final
|
|
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
|
|
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
Operation *op = getOperation();
|
|
|
|
if (runSignatureConversion) {
|
|
// Unroll vectors in function signatures to native vector size.
|
|
RewritePatternSet patterns(context);
|
|
populateFuncOpVectorRewritePatterns(patterns);
|
|
populateReturnOpVectorRewritePatterns(patterns);
|
|
GreedyRewriteConfig config;
|
|
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
|
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
|
return signalPassFailure();
|
|
}
|
|
|
|
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
|
|
std::unique_ptr<ConversionTarget> target =
|
|
SPIRVConversionTarget::get(targetAttr);
|
|
SPIRVTypeConverter typeConverter(targetAttr);
|
|
RewritePatternSet patterns(context);
|
|
ScfToSPIRVContext scfToSPIRVContext;
|
|
|
|
// Populate patterns for each dialect.
|
|
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
|
|
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
|
|
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
|
|
populateFuncToSPIRVPatterns(typeConverter, patterns);
|
|
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
|
|
populateVectorToSPIRVPatterns(typeConverter, patterns);
|
|
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
|
|
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
|
|
|
|
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|