//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::linalg; using namespace mlir::vector; namespace { struct TestVectorToVectorLowering : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering) TestVectorToVectorLowering() = default; TestVectorToVectorLowering(const TestVectorToVectorLowering &pass) : PassWrapper(pass) {} StringRef getArgument() const final { return "test-vector-to-vector-lowering"; } StringRef getDescription() const final { return "Test lowering patterns between ops in the vector dialect"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } Option unroll{*this, "unroll", llvm::cl::desc("Include unrolling"), llvm::cl::init(false)}; void runOnOperation() override { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); if (unroll) { populateVectorUnrollPatterns( patterns, UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( filter)); } populateVectorToVectorCanonicalizationPatterns(patterns); populateBubbleVectorBitCastOpPatterns(patterns); populateCastAwayVectorLeadingOneDimPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } private: // Return the target shape based on op type. static Optional> getShape(Operation *op) { if (isa(op)) return SmallVector(2, 2); if (isa(op)) return SmallVector(3, 2); // For transfer ops, just propagate the shape coming from // InsertStridedSlices/ExtractStridedSlices. if (auto readOp = dyn_cast(op)) { VectorType dstVec; for (Operation *users : readOp->getUsers()) { auto extract = dyn_cast(users); if (!extract) return llvm::None; auto vecType = extract.getResult().getType().cast(); if (dstVec && dstVec != vecType) return llvm::None; dstVec = vecType; } return SmallVector(dstVec.getShape().begin(), dstVec.getShape().end()); } if (auto writeOp = dyn_cast(op)) { auto insert = writeOp.getVector().getDefiningOp(); if (!insert) return llvm::None; ArrayRef shape = insert.getSourceVectorType().getShape(); return SmallVector(shape.begin(), shape.end()); } return llvm::None; } static LogicalResult filter(Operation *op) { return success(isa(op)); } }; struct TestVectorContractionLowering : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering) StringRef getArgument() const final { return "test-vector-contraction-lowering"; } StringRef getDescription() const final { return "Test lowering patterns that lower contract ops in the vector " "dialect"; } TestVectorContractionLowering() = default; TestVectorContractionLowering(const TestVectorContractionLowering &pass) : PassWrapper(pass) {} Option lowerToFlatMatrix{ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; Option lowerToOuterProduct{ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), llvm::cl::init(false)}; Option lowerToFilterOuterProduct{ *this, "vector-filter-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " "vectors of size 4."), llvm::cl::init(false)}; Option lowerToParallelArith{ *this, "vector-parallel-arith", llvm::cl::desc("Lower vector.contract to elementwise vector ops."), llvm::cl::init(false)}; void runOnOperation() override { RewritePatternSet patterns(&getContext()); // Test on one pattern in isolation. if (lowerToOuterProduct) { VectorContractLowering lowering = VectorContractLowering::OuterProduct; VectorTransformsOptions options{lowering}; patterns.add(options, &getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } // Test on one pattern in isolation. if (lowerToFilterOuterProduct) { VectorContractLowering lowering = VectorContractLowering::OuterProduct; VectorTransformsOptions options{lowering}; patterns.add( options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) { // Only lowers vector.contract where the lhs as a type vector // where M is not 4. if (op.getRhsType().getShape()[0] == 4) return failure(); return success(); }); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } if (lowerToParallelArith) { vector::populateVectorContractLoweringPatterns( patterns, vector::VectorTransformsOptions().setVectorTransformsOptions( vector::VectorContractLowering::ParallelArith)); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } // Test on all contract lowering patterns. VectorContractLowering contractLowering = VectorContractLowering::Dot; if (lowerToFlatMatrix) contractLowering = VectorContractLowering::Matmul; VectorMultiReductionLowering vectorMultiReductionLowering = VectorMultiReductionLowering::InnerParallel; VectorTransformsOptions options{contractLowering, vectorMultiReductionLowering, VectorTransposeLowering()}; populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, options); populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorTransposeLowering : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering) StringRef getArgument() const final { return "test-vector-transpose-lowering"; } StringRef getDescription() const final { return "Test lowering patterns that lower contract ops in the vector " "dialect"; } TestVectorTransposeLowering() = default; TestVectorTransposeLowering(const TestVectorTransposeLowering &pass) : PassWrapper(pass) {} Option lowerToEltwise{ *this, "eltwise", llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"), llvm::cl::init(false)}; Option lowerToFlatTranspose{ *this, "flat", llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"), llvm::cl::init(false)}; Option lowerToShuffleTranspose{ *this, "shuffle", llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"), llvm::cl::init(false)}; Option lowerToAvx2{ *this, "avx2", llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"), llvm::cl::init(false)}; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { func::FuncOp funcOp = getOperation(); MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); vector::VectorTransformsOptions vectorTransformOptions; if (lowerToEltwise) { vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( VectorTransposeLowering::EltWise); } if (lowerToFlatTranspose) { vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( VectorTransposeLowering::Flat); } if (lowerToShuffleTranspose) { vectorTransformOptions = vectorTransformOptions.setVectorTransposeLowering( VectorTransposeLowering::Shuffle); } vector::populateVectorTransposeLoweringPatterns(patterns, vectorTransformOptions); if (lowerToAvx2) { auto avx2LoweringOptions = x86vector::avx2::LoweringOptions().setTransposeOptions( x86vector::avx2::TransposeLoweringOptions() .lower4x8xf32() .lower8x8xf32()); x86vector::avx2::populateSpecializedTransposeLoweringPatterns( patterns, avx2LoweringOptions, /*benefit=*/10); } if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) return signalPassFailure(); } }; struct TestVectorUnrollingPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns) StringRef getArgument() const final { return "test-vector-unrolling-patterns"; } StringRef getDescription() const final { return "Test lowering patterns to unroll contract ops in the vector " "dialect"; } TestVectorUnrollingPatterns() = default; TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) : PassWrapper(pass) {} void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{2, 2}) .setFilterConstraint([](Operation *op) { return success(isa(op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{2}) .setFilterConstraint([](Operation *op) { return success(isa(op)); })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{1, 3, 4, 2}) .setFilterConstraint([](Operation *op) { return success(isa(op)); })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = [](Operation *op) -> Optional> { vector::ContractionOp contractOp = cast(op); SmallVector nativeShape( contractOp.getIteratorTypes().size(), 4); Type lhsType = contractOp.getLhsType().getElementType(); nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2; return nativeShape; }; UnrollVectorOptions opts; opts.setNativeShapeFn(nativeShapeFn) .setFilterConstraint( [](Operation *op) { return success(isa(op)); }); if (!unrollOrder.empty()) { opts.setUnrollTraversalOrderFn([this](Operation *op) -> Optional> { vector::ContractionOp contractOp = cast(op); if (contractOp.getIteratorTypes().size() == unrollOrder.size()) return SmallVector(unrollOrder.begin(), unrollOrder.end()); return None; }); } populateVectorUnrollPatterns(patterns, opts); } else { auto nativeShapeFn = [](Operation *op) -> Optional> { auto contractOp = dyn_cast(op); if (!contractOp) return None; return SmallVector(contractOp.getIteratorTypes().size(), 2); }; populateVectorUnrollPatterns(patterns, UnrollVectorOptions() .setNativeShapeFn(nativeShapeFn) .setFilterConstraint([](Operation *op) { return success(isa(op)); })); } populateVectorToVectorCanonicalizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } ListOption unrollOrder{*this, "unroll-order", llvm::cl::desc("set the unroll order")}; Option unrollBasedOnType{ *this, "unroll-based-on-type", llvm::cl::desc("Set the unroll factor based on type of the operation"), llvm::cl::init(false)}; }; struct TestVectorTransferUnrollingPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferUnrollingPatterns) TestVectorTransferUnrollingPatterns() = default; TestVectorTransferUnrollingPatterns( const TestVectorTransferUnrollingPatterns &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-vector-transfer-unrolling-patterns"; } StringRef getDescription() const final { return "Test lowering patterns to unroll transfer ops in the vector " "dialect"; } void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); UnrollVectorOptions opts; opts.setNativeShape(ArrayRef{2, 2}) .setFilterConstraint([](Operation *op) { return success( isa(op)); }); if (reverseUnrollOrder.getValue()) { opts.setUnrollTraversalOrderFn( [](Operation *op) -> Optional> { int64_t numLoops = 0; if (auto readOp = dyn_cast(op)) numLoops = readOp.getVectorType().getRank(); else if (auto writeOp = dyn_cast(op)) numLoops = writeOp.getVectorType().getRank(); else return None; auto order = llvm::reverse(llvm::seq(0, numLoops)); return llvm::to_vector(order); }); } populateVectorUnrollPatterns(patterns, opts); populateVectorToVectorCanonicalizationPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } Option reverseUnrollOrder{ *this, "reverse-unroll-order", llvm::cl::desc( "reverse the order of unrolling of vector transfer operations"), llvm::cl::init(false)}; }; struct TestVectorTransferFullPartialSplitPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferFullPartialSplitPatterns) StringRef getArgument() const final { return "test-vector-transfer-full-partial-split"; } StringRef getDescription() const final { return "Test lowering patterns to split " "transfer ops via scf.if + linalg ops"; } TestVectorTransferFullPartialSplitPatterns() = default; TestVectorTransferFullPartialSplitPatterns( const TestVectorTransferFullPartialSplitPatterns &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } Option useLinalgOps{ *this, "use-memref-copy", llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + " "memref.copy operations."), llvm::cl::init(false)}; void runOnOperation() override { MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); VectorTransformsOptions options; if (useLinalgOps) options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); else options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); patterns.add(ctx, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorTransferOpt : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt) StringRef getArgument() const final { return "test-vector-transferop-opt"; } StringRef getDescription() const final { return "Test optimization transformations for transfer ops"; } void runOnOperation() override { transferOpflowOpt(getOperation()); } }; struct TestVectorTransferLoweringPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferLoweringPatterns) void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-vector-transfer-lowering-patterns"; } StringRef getDescription() const final { return "Test lowering patterns to lower transfer ops to other vector ops"; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorTransferLoweringPatterns(patterns); populateVectorTransferPermutationMapLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorMultiReductionLoweringPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorMultiReductionLoweringPatterns) TestVectorMultiReductionLoweringPatterns() = default; TestVectorMultiReductionLoweringPatterns( const TestVectorMultiReductionLoweringPatterns &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-vector-multi-reduction-lowering-patterns"; } StringRef getDescription() const final { return "Test lowering patterns to lower vector.multi_reduction to other " "vector ops"; } Option useOuterReductions{ *this, "use-outer-reductions", llvm::cl::desc("Move reductions to outer most dimensions"), llvm::cl::init(false)}; void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorMultiReductionLoweringPatterns( patterns, useOuterReductions ? vector::VectorMultiReductionLowering::InnerParallel : vector::VectorMultiReductionLowering::InnerReduction); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorTransferCollapseInnerMostContiguousDims : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferCollapseInnerMostContiguousDims) TestVectorTransferCollapseInnerMostContiguousDims() = default; TestVectorTransferCollapseInnerMostContiguousDims( const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-vector-transfer-collapse-inner-most-dims"; } StringRef getDescription() const final { return "Test lowering patterns that reducedes the rank of the vector " "transfer memory and vector operands."; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorReduceToContractPatternsPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorReduceToContractPatternsPatterns) StringRef getArgument() const final { return "test-vector-reduction-to-contract-patterns"; } StringRef getDescription() const final { return "Test patterns to convert multireduce op to contract and combine " "broadcast/transpose to contract"; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorReductionToContractPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorTransferDropUnitDimsPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestVectorTransferDropUnitDimsPatterns) StringRef getArgument() const final { return "test-vector-transfer-drop-unit-dims-patterns"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorTransferDropUnitDimsPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestFlattenVectorTransferPatterns : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestFlattenVectorTransferPatterns) StringRef getArgument() const final { return "test-vector-transfer-flatten-patterns"; } StringRef getDescription() const final { return "Test patterns to rewrite contiguous row-major N-dimensional " "vector.transfer_{read,write} ops into 1D transfers"; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateFlattenVectorTransferPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; struct TestVectorScanLowering : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering) StringRef getArgument() const final { return "test-vector-scan-lowering"; } StringRef getDescription() const final { return "Test lowering patterns that lower the scan op in the vector " "dialect"; } void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorScanLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; /// Allocate shared memory for a single warp to test lowering of /// WarpExecuteOnLane0Op. static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder, WarpExecuteOnLane0Op warpOp, Type type) { static constexpr int64_t kSharedMemorySpace = 3; // Compute type of shared memory buffer. MemRefType memrefType; if (auto vectorType = type.dyn_cast()) { memrefType = MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {}, kSharedMemorySpace); } else { memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace); } // Get symbol table holding all shared memory globals. ModuleOp moduleOp = warpOp->getParentOfType(); SymbolTable symbolTable(moduleOp); // Create a pretty name. SmallString<64> buf; llvm::raw_svector_ostream os(buf); interleave(memrefType.getShape(), os, "x"); os << "x" << memrefType.getElementType(); std::string symbolName = (Twine("__shared_") + os.str()).str(); auto ip = builder.saveInsertionPoint(); builder.setInsertionPoint(moduleOp); auto global = builder.create( loc, /*sym_name=*/symbolName, /*sym_visibility=*/builder.getStringAttr("private"), /*type=*/memrefType, /*initial_value=*/Attribute(), /*constant=*/false, /*alignment=*/IntegerAttr()); symbolTable.insert(global); // The symbol table inserts at the end of the module, but globals are a bit // nicer if they are at the beginning. global->moveBefore(&moduleOp.front()); builder.restoreInsertionPoint(ip); return builder.create(loc, memrefType, symbolName); } static Value warpReduction(Location loc, OpBuilder &builder, Value input, CombiningKind kind, uint32_t size) { Value laneVal = input; // Parallel reduction using butterfly shuffles. for (uint64_t i = 1; i < size; i <<= 1) { Value shuffled = builder .create(loc, laneVal, i, /*width=*/size, /*mode=*/gpu::ShuffleMode::XOR) .getShuffleResult(); laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled); } return laneVal; } struct TestVectorDistribution : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution) void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } StringRef getArgument() const final { return "test-vector-warp-distribute"; } StringRef getDescription() const final { return "Test vector warp distribute transformation and lowering patterns"; } TestVectorDistribution() = default; TestVectorDistribution(const TestVectorDistribution &pass) : PassWrapper(pass) {} Option warpOpToSCF{ *this, "rewrite-warp-ops-to-scf-if", llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"), llvm::cl::init(false)}; Option distributeTransferWriteOps{ *this, "distribute-transfer-write", llvm::cl::desc("Test distribution of transfer write"), llvm::cl::init(false)}; Option hoistUniform{*this, "hoist-uniform", llvm::cl::desc("Test hoist uniform"), llvm::cl::init(false)}; Option propagateDistribution{ *this, "propagate-distribution", llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)}; void runOnOperation() override { RewritePatternSet patterns(&getContext()); getOperation().walk([&](Operation *op) { if (auto warpOp = dyn_cast(op)) { if (hoistUniform) { moveScalarUniformCode(warpOp); } WalkResult::interrupt(); } }); MLIRContext *ctx = &getContext(); if (distributeTransferWriteOps) { auto distributionFn = [](vector::TransferWriteOp writeOp) { // Create a map (d0, d1) -> (d1) to distribute along the inner // dimension. Once we support n-d distribution we can add more // complex cases. int64_t vecRank = writeOp.getVectorType().getRank(); OpBuilder builder(writeOp.getContext()); auto map = AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); return map; }; RewritePatternSet patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } if (propagateDistribution) { RewritePatternSet patterns(ctx); vector::populatePropagateWarpVectorDistributionPatterns(patterns); vector::populateDistributeReduction(patterns, warpReduction); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } WarpExecuteOnLane0LoweringOptions options; options.warpAllocationFn = allocateGlobalSharedMemory; options.warpSyncronizationFn = [](Location loc, OpBuilder &builder, WarpExecuteOnLane0Op warpOp) { builder.create(loc); }; // Test on one pattern in isolation. if (warpOpToSCF) { populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } } }; } // namespace namespace mlir { namespace test { void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); PassRegistration(); } } // namespace test } // namespace mlir