llvm-project/mlir/lib/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.cpp
Angel Zhang 6867e49fc8
[mlir][spirv] Implement vector type legalization for function signatures (#98337)
### Description
This PR implements a minimal version of function signature conversion to
unroll vectors into 1D and with a size supported by SPIR-V (2, 3 or 4
depending on the original dimension). This PR also includes new unit
tests that only check for function signature conversion.

### Future Plans
- Check for capabilities that support vectors of size 8 or 16.
- Set up `OneToNTypeConversion` and `DialectConversion` to replace the
current implementation that uses `GreedyPatternRewriteDriver`.
- Introduce other vector unrolling patterns to cancel out the
`vector.insert_strided_slice` and `vector.extract_strided_slice` ops and
fully legalize the vector types in the function body.
- Handle `func::CallOp` and declarations.
- Restructure the code in `SPIRVConversion.cpp`.
- Create test passes for testing sets of patterns in isolation.
- Optimize the way original shape is splitted into target shapes, e.g.
`vector<5xi32>` can be splitted into `vector<4xi32>` and
`vector<1xi32>`.

---------

Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
2024-07-17 13:09:15 -04:00

82 lines
3.2 KiB
C++

//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <memory>
#define DEBUG_TYPE "convert-to-spirv"
namespace mlir {
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
namespace {
/// A pass to perform the SPIR-V conversion.
struct ConvertToSPIRVPass final
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> {
using ConvertToSPIRVPassBase::ConvertToSPIRVPassBase;
void runOnOperation() override {
MLIRContext *context = &getContext();
Operation *op = getOperation();
if (runSignatureConversion) {
// Unroll vectors in function signatures to native vector size.
RewritePatternSet patterns(context);
populateFuncOpVectorRewritePatterns(patterns);
populateReturnOpVectorRewritePatterns(patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
return signalPassFailure();
}
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
ScfToSPIRVContext scfToSPIRVContext;
// Populate patterns for each dialect.
arith::populateCeilFloorDivExpandOpsPatterns(patterns);
arith::populateArithToSPIRVPatterns(typeConverter, patterns);
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
index::populateIndexToSPIRVPatterns(typeConverter, patterns);
populateVectorToSPIRVPatterns(typeConverter, patterns);
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns);
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(op, *target, std::move(patterns))))
return signalPassFailure();
}
};
} // namespace