
In particular for Graph Regions, the need for a terminator is just a historical artifact of the generalization of MLIR from CFG regions. Operations like Module don't need a terminator, and before Module migrated to being an operation with a region, none was needed. To validate the feature, the ModuleOp is migrated to use this trait and the ModuleTerminator operation is deleted. This patch is likely to break clients. If you're in this case: - you may iterate on a ModuleOp with `getBody()->without_terminator()`; the solution is simple: just remove the `->without_terminator()`. - you created a builder with `Builder::atBlockTerminator(module_body)`; just use `Builder::atBlockEnd(module_body)` instead. - you were handling ModuleTerminator: it isn't needed anymore. - for generic code, a `Block::mayNotHaveTerminator()` may be used. Differential Revision: https://reviews.llvm.org/D98468
135 lines
5.3 KiB
C++
135 lines
5.3 KiB
C++
//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/LoopUtils.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
|
|
using namespace mlir;
|
|
using namespace vector;
|
|
|
|
namespace {
|
|
/// A pass converting MLIR Linalg ops into Vector ops.
|
|
class TestConvVectorization
|
|
: public PassWrapper<TestConvVectorization, OperationPass<ModuleOp>> {
|
|
public:
|
|
TestConvVectorization() = default;
|
|
TestConvVectorization(const TestConvVectorization &) {}
|
|
explicit TestConvVectorization(ArrayRef<int64_t> tileSizesParam) {
|
|
tileSizes = tileSizesParam;
|
|
}
|
|
|
|
void runOnOperation() override;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<VectorDialect>();
|
|
registry.insert<linalg::LinalgDialect>();
|
|
registry.insert<memref::MemRefDialect>();
|
|
registry.insert<scf::SCFDialect>();
|
|
registry.insert<AffineDialect>();
|
|
registry.insert<StandardOpsDialect>();
|
|
}
|
|
|
|
ListOption<int64_t> tileSizes{
|
|
*this, "tile-sizes", llvm::cl::desc("Vectorization sizes."),
|
|
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated};
|
|
};
|
|
} // namespace
|
|
|
|
void TestConvVectorization::runOnOperation() {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp module = getOperation();
|
|
|
|
ConversionTarget target(*context);
|
|
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
|
|
VectorDialect>();
|
|
target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
|
|
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
|
|
|
|
SmallVector<RewritePatternSet, 4> stage1Patterns;
|
|
linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes);
|
|
SmallVector<FrozenRewritePatternSet, 4> frozenStage1Patterns;
|
|
llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns));
|
|
|
|
RewritePatternSet stage2Patterns =
|
|
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
|
stage2Patterns.add<linalg::AffineMinSCFCanonicalizationPattern>(context);
|
|
|
|
auto stage3Transforms = [](Operation *op) {
|
|
PassManager pm(op->getContext());
|
|
pm.addPass(createLoopInvariantCodeMotionPass());
|
|
if (failed(pm.run(cast<ModuleOp>(op))))
|
|
llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
|
|
op->walk([](FuncOp func) {
|
|
promoteSingleIterationLoops(func);
|
|
linalg::hoistRedundantVectorTransfers(func);
|
|
});
|
|
return success();
|
|
};
|
|
|
|
(void)linalg::applyStagedPatterns(module, frozenStage1Patterns,
|
|
std::move(stage2Patterns),
|
|
stage3Transforms);
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Post staged patterns transforms
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
VectorTransformsOptions vectorTransformsOptions{
|
|
VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
|
|
|
|
RewritePatternSet vectorTransferPatterns(context);
|
|
// Pattern is not applied because rank-reducing vector transfer is not yet
|
|
// supported as can be seen in splitFullAndPartialTransferPrecondition,
|
|
// VectorTransforms.cpp
|
|
vectorTransferPatterns.add<VectorTransferFullPartialRewriter>(
|
|
context, vectorTransformsOptions);
|
|
(void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns));
|
|
|
|
// Programmatic controlled lowering of linalg.copy and linalg.fill.
|
|
PassManager pm(context);
|
|
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
|
if (failed(pm.run(module)))
|
|
llvm_unreachable("Unexpected failure in linalg to loops pass.");
|
|
|
|
// Programmatic controlled lowering of vector.contract only.
|
|
RewritePatternSet vectorContractLoweringPatterns(context);
|
|
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
|
|
vectorTransformsOptions);
|
|
(void)applyPatternsAndFoldGreedily(module,
|
|
std::move(vectorContractLoweringPatterns));
|
|
|
|
// Programmatic controlled lowering of vector.transfer only.
|
|
RewritePatternSet vectorToLoopsPatterns(context);
|
|
populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
|
|
VectorTransferToSCFOptions());
|
|
(void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
|
|
|
|
// Ensure we drop the marker in the end.
|
|
module.walk([](linalg::LinalgOp op) {
|
|
op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
|
|
});
|
|
}
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestConvVectorization() {
|
|
PassRegistration<TestConvVectorization> testTransformPatternsPass(
|
|
"test-conv-vectorization", "Test vectorization of convolutions");
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|