llvm-project/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

822 lines
31 KiB
C++

//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <type_traits>
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::vector;
namespace {
struct TestVectorToVectorLowering
: public PassWrapper<TestVectorToVectorLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
TestVectorToVectorLowering() = default;
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
: PassWrapper(pass) {}
StringRef getArgument() const final {
return "test-vector-to-vector-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns between ops in the vector dialect";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
llvm::cl::init(false)};
void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
if (unroll) {
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
}
populateVectorToVectorCanonicalizationPatterns(patterns);
populateBubbleVectorBitCastOpPatterns(patterns);
populateCastAwayVectorLeadingOneDimPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
private:
// Return the target shape based on op type.
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
return SmallVector<int64_t, 4>(2, 2);
if (isa<vector::ContractionOp>(op))
return SmallVector<int64_t, 4>(3, 2);
// For transfer ops, just propagate the shape coming from
// InsertStridedSlices/ExtractStridedSlices.
if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
VectorType dstVec;
for (Operation *users : readOp->getUsers()) {
auto extract = dyn_cast<ExtractStridedSliceOp>(users);
if (!extract)
return llvm::None;
auto vecType = extract.getResult().getType().cast<VectorType>();
if (dstVec && dstVec != vecType)
return llvm::None;
dstVec = vecType;
}
return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
dstVec.getShape().end());
}
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
if (!insert)
return llvm::None;
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
return SmallVector<int64_t, 4>(shape.begin(), shape.end());
}
return llvm::None;
}
static LogicalResult filter(Operation *op) {
return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
ContractionOp, TransferReadOp, TransferWriteOp>(op));
}
};
struct TestVectorContractionLowering
: public PassWrapper<TestVectorContractionLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
StringRef getArgument() const final {
return "test-vector-contraction-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorContractionLowering() = default;
TestVectorContractionLowering(const TestVectorContractionLowering &pass)
: PassWrapper(pass) {}
Option<bool> lowerToFlatMatrix{
*this, "vector-lower-matrix-intrinsics",
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
llvm::cl::init(false)};
Option<bool> lowerToOuterProduct{
*this, "vector-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
llvm::cl::init(false)};
Option<bool> lowerToFilterOuterProduct{
*this, "vector-filter-outerproduct",
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
"vectors of size 4."),
llvm::cl::init(false)};
Option<bool> lowerToParallelArith{
*this, "vector-parallel-arith",
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
// Test on one pattern in isolation.
if (lowerToOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.add<ContractionOpToOuterProductOpLowering>(options,
&getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
// Test on one pattern in isolation.
if (lowerToFilterOuterProduct) {
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
VectorTransformsOptions options{lowering};
patterns.add<ContractionOpToOuterProductOpLowering>(
options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
// where M is not 4.
if (op.getRhsType().getShape()[0] == 4)
return failure();
return success();
});
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
if (lowerToParallelArith) {
vector::populateVectorContractLoweringPatterns(
patterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::ParallelArith));
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
// Test on all contract lowering patterns.
VectorContractLowering contractLowering = VectorContractLowering::Dot;
if (lowerToFlatMatrix)
contractLowering = VectorContractLowering::Matmul;
VectorMultiReductionLowering vectorMultiReductionLowering =
VectorMultiReductionLowering::InnerParallel;
VectorTransformsOptions options{contractLowering,
vectorMultiReductionLowering,
VectorTransposeLowering()};
populateVectorBroadcastLoweringPatterns(patterns);
populateVectorContractLoweringPatterns(patterns, options);
populateVectorMaskOpLoweringPatterns(patterns);
populateVectorShapeCastLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorTransposeLowering
: public PassWrapper<TestVectorTransposeLowering,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
StringRef getArgument() const final {
return "test-vector-transpose-lowering";
}
StringRef getDescription() const final {
return "Test lowering patterns that lower contract ops in the vector "
"dialect";
}
TestVectorTransposeLowering() = default;
TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
: PassWrapper(pass) {}
Option<bool> lowerToEltwise{
*this, "eltwise",
llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
llvm::cl::init(false)};
Option<bool> lowerToFlatTranspose{
*this, "flat",
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
llvm::cl::init(false)};
Option<bool> lowerToShuffleTranspose{
*this, "shuffle",
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
llvm::cl::init(false)};
Option<bool> lowerToAvx2{
*this, "avx2",
llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
llvm::cl::init(false)};
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
vector::VectorTransformsOptions vectorTransformOptions;
if (lowerToEltwise) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::EltWise);
}
if (lowerToFlatTranspose) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::Flat);
}
if (lowerToShuffleTranspose) {
vectorTransformOptions =
vectorTransformOptions.setVectorTransposeLowering(
VectorTransposeLowering::Shuffle);
}
vector::populateVectorTransposeLoweringPatterns(patterns,
vectorTransformOptions);
if (lowerToAvx2) {
auto avx2LoweringOptions =
x86vector::avx2::LoweringOptions().setTransposeOptions(
x86vector::avx2::TransposeLoweringOptions()
.lower4x8xf32()
.lower8x8xf32());
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
patterns, avx2LoweringOptions, /*benefit=*/10);
}
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
return signalPassFailure();
}
};
struct TestVectorUnrollingPatterns
: public PassWrapper<TestVectorUnrollingPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
StringRef getArgument() const final {
return "test-vector-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to unroll contract ops in the vector "
"dialect";
}
TestVectorUnrollingPatterns() = default;
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
: PassWrapper(pass) {}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<arith::AddFOp, vector::FMAOp,
vector::MultiDimReductionOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::ReductionOp>(op));
}));
populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {
return success(isa<vector::TransposeOp>(op));
}));
if (unrollBasedOnType) {
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
SmallVector<int64_t, 4> nativeShape(
contractOp.getIteratorTypes().size(), 4);
Type lhsType = contractOp.getLhsType().getElementType();
nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
return nativeShape;
};
UnrollVectorOptions opts;
opts.setNativeShapeFn(nativeShapeFn)
.setFilterConstraint(
[](Operation *op) { return success(isa<ContractionOp>(op)); });
if (!unrollOrder.empty()) {
opts.setUnrollTraversalOrderFn([this](Operation *op)
-> Optional<SmallVector<int64_t>> {
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
if (contractOp.getIteratorTypes().size() == unrollOrder.size())
return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
return None;
});
}
populateVectorUnrollPatterns(patterns, opts);
} else {
auto nativeShapeFn =
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
auto contractOp = dyn_cast<ContractionOp>(op);
if (!contractOp)
return None;
return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
};
populateVectorUnrollPatterns(patterns,
UnrollVectorOptions()
.setNativeShapeFn(nativeShapeFn)
.setFilterConstraint([](Operation *op) {
return success(isa<ContractionOp>(op));
}));
}
populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
ListOption<int64_t> unrollOrder{*this, "unroll-order",
llvm::cl::desc("set the unroll order")};
Option<bool> unrollBasedOnType{
*this, "unroll-based-on-type",
llvm::cl::desc("Set the unroll factor based on type of the operation"),
llvm::cl::init(false)};
};
struct TestVectorTransferUnrollingPatterns
: public PassWrapper<TestVectorTransferUnrollingPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferUnrollingPatterns)
TestVectorTransferUnrollingPatterns() = default;
TestVectorTransferUnrollingPatterns(
const TestVectorTransferUnrollingPatterns &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-unrolling-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to unroll transfer ops in the vector "
"dialect";
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
UnrollVectorOptions opts;
opts.setNativeShape(ArrayRef<int64_t>{2, 2})
.setFilterConstraint([](Operation *op) {
return success(
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
});
if (reverseUnrollOrder.getValue()) {
opts.setUnrollTraversalOrderFn(
[](Operation *op) -> Optional<SmallVector<int64_t>> {
int64_t numLoops = 0;
if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
numLoops = readOp.getVectorType().getRank();
else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
numLoops = writeOp.getVectorType().getRank();
else
return None;
auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
return llvm::to_vector(order);
});
}
populateVectorUnrollPatterns(patterns, opts);
populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
Option<bool> reverseUnrollOrder{
*this, "reverse-unroll-order",
llvm::cl::desc(
"reverse the order of unrolling of vector transfer operations"),
llvm::cl::init(false)};
};
struct TestVectorTransferFullPartialSplitPatterns
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferFullPartialSplitPatterns)
StringRef getArgument() const final {
return "test-vector-transfer-full-partial-split";
}
StringRef getDescription() const final {
return "Test lowering patterns to split "
"transfer ops via scf.if + linalg ops";
}
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect>();
}
Option<bool> useLinalgOps{
*this, "use-memref-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
"memref.copy operations."),
llvm::cl::init(false)};
void runOnOperation() override {
MLIRContext *ctx = &getContext();
RewritePatternSet patterns(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
else
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorTransferOpt
: public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
StringRef getDescription() const final {
return "Test optimization transformations for transfer ops";
}
void runOnOperation() override { transferOpflowOpt(getOperation()); }
};
struct TestVectorTransferLoweringPatterns
: public PassWrapper<TestVectorTransferLoweringPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferLoweringPatterns)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-lowering-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to lower transfer ops to other vector ops";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferLoweringPatterns(patterns);
populateVectorTransferPermutationMapLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorMultiReductionLoweringPatterns
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorMultiReductionLoweringPatterns)
TestVectorMultiReductionLoweringPatterns() = default;
TestVectorMultiReductionLoweringPatterns(
const TestVectorMultiReductionLoweringPatterns &pass)
: PassWrapper(pass) {}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
StringRef getArgument() const final {
return "test-vector-multi-reduction-lowering-patterns";
}
StringRef getDescription() const final {
return "Test lowering patterns to lower vector.multi_reduction to other "
"vector ops";
}
Option<bool> useOuterReductions{
*this, "use-outer-reductions",
llvm::cl::desc("Move reductions to outer most dimensions"),
llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorMultiReductionLoweringPatterns(
patterns, useOuterReductions
? vector::VectorMultiReductionLowering::InnerParallel
: vector::VectorMultiReductionLowering::InnerReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorTransferCollapseInnerMostContiguousDims
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferCollapseInnerMostContiguousDims)
TestVectorTransferCollapseInnerMostContiguousDims() = default;
TestVectorTransferCollapseInnerMostContiguousDims(
const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect, AffineDialect>();
}
StringRef getArgument() const final {
return "test-vector-transfer-collapse-inner-most-dims";
}
StringRef getDescription() const final {
return "Test lowering patterns that reducedes the rank of the vector "
"transfer memory and vector operands.";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorReduceToContractPatternsPatterns
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorReduceToContractPatternsPatterns)
StringRef getArgument() const final {
return "test-vector-reduction-to-contract-patterns";
}
StringRef getDescription() const final {
return "Test patterns to convert multireduce op to contract and combine "
"broadcast/transpose to contract";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorReductionToContractPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorTransferDropUnitDimsPatterns
: public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestVectorTransferDropUnitDimsPatterns)
StringRef getArgument() const final {
return "test-vector-transfer-drop-unit-dims-patterns";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorTransferDropUnitDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestFlattenVectorTransferPatterns
: public PassWrapper<TestFlattenVectorTransferPatterns,
OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestFlattenVectorTransferPatterns)
StringRef getArgument() const final {
return "test-vector-transfer-flatten-patterns";
}
StringRef getDescription() const final {
return "Test patterns to rewrite contiguous row-major N-dimensional "
"vector.transfer_{read,write} ops into 1D transfers";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateFlattenVectorTransferPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
struct TestVectorScanLowering
: public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
StringRef getArgument() const final { return "test-vector-scan-lowering"; }
StringRef getDescription() const final {
return "Test lowering patterns that lower the scan op in the vector "
"dialect";
}
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorScanLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
/// Allocate shared memory for a single warp to test lowering of
/// WarpExecuteOnLane0Op.
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
WarpExecuteOnLane0Op warpOp,
Type type) {
static constexpr int64_t kSharedMemorySpace = 3;
// Compute type of shared memory buffer.
MemRefType memrefType;
if (auto vectorType = type.dyn_cast<VectorType>()) {
memrefType =
MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
kSharedMemorySpace);
} else {
memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
}
// Get symbol table holding all shared memory globals.
ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
SymbolTable symbolTable(moduleOp);
// Create a pretty name.
SmallString<64> buf;
llvm::raw_svector_ostream os(buf);
interleave(memrefType.getShape(), os, "x");
os << "x" << memrefType.getElementType();
std::string symbolName = (Twine("__shared_") + os.str()).str();
auto ip = builder.saveInsertionPoint();
builder.setInsertionPoint(moduleOp);
auto global = builder.create<memref::GlobalOp>(
loc,
/*sym_name=*/symbolName,
/*sym_visibility=*/builder.getStringAttr("private"),
/*type=*/memrefType,
/*initial_value=*/Attribute(),
/*constant=*/false,
/*alignment=*/IntegerAttr());
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
global->moveBefore(&moduleOp.front());
builder.restoreInsertionPoint(ip);
return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
}
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
CombiningKind kind, uint32_t size) {
Value laneVal = input;
// Parallel reduction using butterfly shuffles.
for (uint64_t i = 1; i < size; i <<= 1) {
Value shuffled = builder
.create<gpu::ShuffleOp>(loc, laneVal, i,
/*width=*/size,
/*mode=*/gpu::ShuffleMode::XOR)
.getShuffleResult();
laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
}
return laneVal;
}
struct TestVectorDistribution
: public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
AffineDialect>();
}
StringRef getArgument() const final { return "test-vector-warp-distribute"; }
StringRef getDescription() const final {
return "Test vector warp distribute transformation and lowering patterns";
}
TestVectorDistribution() = default;
TestVectorDistribution(const TestVectorDistribution &pass)
: PassWrapper(pass) {}
Option<bool> warpOpToSCF{
*this, "rewrite-warp-ops-to-scf-if",
llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
llvm::cl::init(false)};
Option<bool> distributeTransferWriteOps{
*this, "distribute-transfer-write",
llvm::cl::desc("Test distribution of transfer write"),
llvm::cl::init(false)};
Option<bool> hoistUniform{*this, "hoist-uniform",
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
Option<bool> propagateDistribution{
*this, "propagate-distribution",
llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
getOperation().walk([&](Operation *op) {
if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
if (hoistUniform) {
moveScalarUniformCode(warpOp);
}
WalkResult::interrupt();
}
});
MLIRContext *ctx = &getContext();
if (distributeTransferWriteOps) {
auto distributionFn = [](vector::TransferWriteOp writeOp) {
// Create a map (d0, d1) -> (d1) to distribute along the inner
// dimension. Once we support n-d distribution we can add more
// complex cases.
int64_t vecRank = writeOp.getVectorType().getRank();
OpBuilder builder(writeOp.getContext());
auto map =
AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
return map;
};
RewritePatternSet patterns(ctx);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
if (propagateDistribution) {
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(patterns);
vector::populateDistributeReduction(patterns, warpReduction);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
WarpExecuteOnLane0LoweringOptions options;
options.warpAllocationFn = allocateGlobalSharedMemory;
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
WarpExecuteOnLane0Op warpOp) {
builder.create<gpu::BarrierOp>(loc);
};
// Test on one pattern in isolation.
if (warpOpToSCF) {
populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
return;
}
}
};
} // namespace
namespace mlir {
namespace test {
void registerTestVectorLowerings() {
PassRegistration<TestVectorToVectorLowering>();
PassRegistration<TestVectorContractionLowering>();
PassRegistration<TestVectorTransposeLowering>();
PassRegistration<TestVectorUnrollingPatterns>();
PassRegistration<TestVectorTransferUnrollingPatterns>();
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
PassRegistration<TestVectorTransferOpt>();
PassRegistration<TestVectorTransferLoweringPatterns>();
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
PassRegistration<TestFlattenVectorTransferPatterns>();
PassRegistration<TestVectorScanLowering>();
PassRegistration<TestVectorDistribution>();
}
} // namespace test
} // namespace mlir