
Context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785 Depends on D135200 Differential Revision: https://reviews.llvm.org/D135222
822 lines
31 KiB
C++
822 lines
31 KiB
C++
//===- TestVectorTransforms.cpp - Test Vector transforms and lowerings ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include <type_traits>
|
|
|
|
#include "mlir/Analysis/SliceAnalysis.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/Vector/IR/VectorOps.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorDistribution.h"
|
|
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
using namespace mlir::vector;
|
|
|
|
namespace {
|
|
|
|
struct TestVectorToVectorLowering
|
|
: public PassWrapper<TestVectorToVectorLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorToVectorLowering)
|
|
|
|
TestVectorToVectorLowering() = default;
|
|
TestVectorToVectorLowering(const TestVectorToVectorLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-to-vector-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns between ops in the vector dialect";
|
|
}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect>();
|
|
}
|
|
|
|
Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
auto *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
if (unroll) {
|
|
populateVectorUnrollPatterns(
|
|
patterns,
|
|
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
|
|
filter));
|
|
}
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
populateBubbleVectorBitCastOpPatterns(patterns);
|
|
populateCastAwayVectorLeadingOneDimPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
private:
|
|
// Return the target shape based on op type.
|
|
static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
|
|
if (isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp>(op))
|
|
return SmallVector<int64_t, 4>(2, 2);
|
|
if (isa<vector::ContractionOp>(op))
|
|
return SmallVector<int64_t, 4>(3, 2);
|
|
// For transfer ops, just propagate the shape coming from
|
|
// InsertStridedSlices/ExtractStridedSlices.
|
|
if (auto readOp = dyn_cast<vector::TransferReadOp>(op)) {
|
|
VectorType dstVec;
|
|
for (Operation *users : readOp->getUsers()) {
|
|
auto extract = dyn_cast<ExtractStridedSliceOp>(users);
|
|
if (!extract)
|
|
return llvm::None;
|
|
auto vecType = extract.getResult().getType().cast<VectorType>();
|
|
if (dstVec && dstVec != vecType)
|
|
return llvm::None;
|
|
dstVec = vecType;
|
|
}
|
|
return SmallVector<int64_t, 4>(dstVec.getShape().begin(),
|
|
dstVec.getShape().end());
|
|
}
|
|
if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op)) {
|
|
auto insert = writeOp.getVector().getDefiningOp<InsertStridedSliceOp>();
|
|
if (!insert)
|
|
return llvm::None;
|
|
ArrayRef<int64_t> shape = insert.getSourceVectorType().getShape();
|
|
return SmallVector<int64_t, 4>(shape.begin(), shape.end());
|
|
}
|
|
return llvm::None;
|
|
}
|
|
|
|
static LogicalResult filter(Operation *op) {
|
|
return success(isa<arith::AddFOp, arith::SelectOp, arith::CmpFOp,
|
|
ContractionOp, TransferReadOp, TransferWriteOp>(op));
|
|
}
|
|
};
|
|
|
|
struct TestVectorContractionLowering
|
|
: public PassWrapper<TestVectorContractionLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorContractionLowering)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-contraction-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorContractionLowering() = default;
|
|
TestVectorContractionLowering(const TestVectorContractionLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> lowerToFlatMatrix{
|
|
*this, "vector-lower-matrix-intrinsics",
|
|
llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToOuterProduct{
|
|
*this, "vector-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFilterOuterProduct{
|
|
*this, "vector-filter-outerproduct",
|
|
llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for "
|
|
"vectors of size 4."),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToParallelArith{
|
|
*this, "vector-parallel-arith",
|
|
llvm::cl::desc("Lower vector.contract to elementwise vector ops."),
|
|
llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
// Test on one pattern in isolation.
|
|
if (lowerToOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.add<ContractionOpToOuterProductOpLowering>(options,
|
|
&getContext());
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
// Test on one pattern in isolation.
|
|
if (lowerToFilterOuterProduct) {
|
|
VectorContractLowering lowering = VectorContractLowering::OuterProduct;
|
|
VectorTransformsOptions options{lowering};
|
|
patterns.add<ContractionOpToOuterProductOpLowering>(
|
|
options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) {
|
|
// Only lowers vector.contract where the lhs as a type vector<MxNx?>
|
|
// where M is not 4.
|
|
if (op.getRhsType().getShape()[0] == 4)
|
|
return failure();
|
|
return success();
|
|
});
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
if (lowerToParallelArith) {
|
|
vector::populateVectorContractLoweringPatterns(
|
|
patterns,
|
|
vector::VectorTransformsOptions().setVectorTransformsOptions(
|
|
vector::VectorContractLowering::ParallelArith));
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
|
|
// Test on all contract lowering patterns.
|
|
VectorContractLowering contractLowering = VectorContractLowering::Dot;
|
|
if (lowerToFlatMatrix)
|
|
contractLowering = VectorContractLowering::Matmul;
|
|
VectorMultiReductionLowering vectorMultiReductionLowering =
|
|
VectorMultiReductionLowering::InnerParallel;
|
|
VectorTransformsOptions options{contractLowering,
|
|
vectorMultiReductionLowering,
|
|
VectorTransposeLowering()};
|
|
populateVectorBroadcastLoweringPatterns(patterns);
|
|
populateVectorContractLoweringPatterns(patterns, options);
|
|
populateVectorMaskOpLoweringPatterns(patterns);
|
|
populateVectorShapeCastLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransposeLowering
|
|
: public PassWrapper<TestVectorTransposeLowering,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransposeLowering)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transpose-lowering";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorTransposeLowering() = default;
|
|
TestVectorTransposeLowering(const TestVectorTransposeLowering &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> lowerToEltwise{
|
|
*this, "eltwise",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to eltwise insert/extract"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToFlatTranspose{
|
|
*this, "flat",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to vector.flat_transpose"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToShuffleTranspose{
|
|
*this, "shuffle",
|
|
llvm::cl::desc("Lower 2-D vector.transpose to shape_cast + shuffle"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> lowerToAvx2{
|
|
*this, "avx2",
|
|
llvm::cl::desc("Lower vector.transpose to avx2-specific patterns"),
|
|
llvm::cl::init(false)};
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<LLVM::LLVMDialect>();
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
func::FuncOp funcOp = getOperation();
|
|
MLIRContext *context = funcOp.getContext();
|
|
RewritePatternSet patterns(context);
|
|
|
|
vector::VectorTransformsOptions vectorTransformOptions;
|
|
if (lowerToEltwise) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::EltWise);
|
|
}
|
|
if (lowerToFlatTranspose) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::Flat);
|
|
}
|
|
if (lowerToShuffleTranspose) {
|
|
vectorTransformOptions =
|
|
vectorTransformOptions.setVectorTransposeLowering(
|
|
VectorTransposeLowering::Shuffle);
|
|
}
|
|
vector::populateVectorTransposeLoweringPatterns(patterns,
|
|
vectorTransformOptions);
|
|
|
|
if (lowerToAvx2) {
|
|
auto avx2LoweringOptions =
|
|
x86vector::avx2::LoweringOptions().setTransposeOptions(
|
|
x86vector::avx2::TransposeLoweringOptions()
|
|
.lower4x8xf32()
|
|
.lower8x8xf32());
|
|
x86vector::avx2::populateSpecializedTransposeLoweringPatterns(
|
|
patterns, avx2LoweringOptions, /*benefit=*/10);
|
|
}
|
|
|
|
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
|
|
return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
struct TestVectorUnrollingPatterns
|
|
: public PassWrapper<TestVectorUnrollingPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorUnrollingPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-unrolling-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to unroll contract ops in the vector "
|
|
"dialect";
|
|
}
|
|
TestVectorUnrollingPatterns() = default;
|
|
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{2, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<arith::AddFOp, vector::FMAOp,
|
|
vector::MultiDimReductionOp>(op));
|
|
}));
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<vector::ReductionOp>(op));
|
|
}));
|
|
populateVectorUnrollPatterns(
|
|
patterns, UnrollVectorOptions()
|
|
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<vector::TransposeOp>(op));
|
|
}));
|
|
|
|
if (unrollBasedOnType) {
|
|
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
|
|
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
|
|
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
|
|
SmallVector<int64_t, 4> nativeShape(
|
|
contractOp.getIteratorTypes().size(), 4);
|
|
Type lhsType = contractOp.getLhsType().getElementType();
|
|
nativeShape[nativeShape.size() - 1] = lhsType.isF16() ? 4 : 2;
|
|
return nativeShape;
|
|
};
|
|
|
|
UnrollVectorOptions opts;
|
|
opts.setNativeShapeFn(nativeShapeFn)
|
|
.setFilterConstraint(
|
|
[](Operation *op) { return success(isa<ContractionOp>(op)); });
|
|
|
|
if (!unrollOrder.empty()) {
|
|
opts.setUnrollTraversalOrderFn([this](Operation *op)
|
|
-> Optional<SmallVector<int64_t>> {
|
|
vector::ContractionOp contractOp = cast<vector::ContractionOp>(op);
|
|
if (contractOp.getIteratorTypes().size() == unrollOrder.size())
|
|
return SmallVector<int64_t>(unrollOrder.begin(), unrollOrder.end());
|
|
return None;
|
|
});
|
|
}
|
|
populateVectorUnrollPatterns(patterns, opts);
|
|
} else {
|
|
auto nativeShapeFn =
|
|
[](Operation *op) -> Optional<SmallVector<int64_t, 4>> {
|
|
auto contractOp = dyn_cast<ContractionOp>(op);
|
|
if (!contractOp)
|
|
return None;
|
|
return SmallVector<int64_t, 4>(contractOp.getIteratorTypes().size(), 2);
|
|
};
|
|
populateVectorUnrollPatterns(patterns,
|
|
UnrollVectorOptions()
|
|
.setNativeShapeFn(nativeShapeFn)
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(isa<ContractionOp>(op));
|
|
}));
|
|
}
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
ListOption<int64_t> unrollOrder{*this, "unroll-order",
|
|
llvm::cl::desc("set the unroll order")};
|
|
|
|
Option<bool> unrollBasedOnType{
|
|
*this, "unroll-based-on-type",
|
|
llvm::cl::desc("Set the unroll factor based on type of the operation"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
struct TestVectorTransferUnrollingPatterns
|
|
: public PassWrapper<TestVectorTransferUnrollingPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferUnrollingPatterns)
|
|
|
|
TestVectorTransferUnrollingPatterns() = default;
|
|
TestVectorTransferUnrollingPatterns(
|
|
const TestVectorTransferUnrollingPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-unrolling-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to unroll transfer ops in the vector "
|
|
"dialect";
|
|
}
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
UnrollVectorOptions opts;
|
|
opts.setNativeShape(ArrayRef<int64_t>{2, 2})
|
|
.setFilterConstraint([](Operation *op) {
|
|
return success(
|
|
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
|
|
});
|
|
if (reverseUnrollOrder.getValue()) {
|
|
opts.setUnrollTraversalOrderFn(
|
|
[](Operation *op) -> Optional<SmallVector<int64_t>> {
|
|
int64_t numLoops = 0;
|
|
if (auto readOp = dyn_cast<vector::TransferReadOp>(op))
|
|
numLoops = readOp.getVectorType().getRank();
|
|
else if (auto writeOp = dyn_cast<vector::TransferWriteOp>(op))
|
|
numLoops = writeOp.getVectorType().getRank();
|
|
else
|
|
return None;
|
|
auto order = llvm::reverse(llvm::seq<int64_t>(0, numLoops));
|
|
return llvm::to_vector(order);
|
|
});
|
|
}
|
|
populateVectorUnrollPatterns(patterns, opts);
|
|
populateVectorToVectorCanonicalizationPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
|
|
Option<bool> reverseUnrollOrder{
|
|
*this, "reverse-unroll-order",
|
|
llvm::cl::desc(
|
|
"reverse the order of unrolling of vector transfer operations"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
|
|
struct TestVectorTransferFullPartialSplitPatterns
|
|
: public PassWrapper<TestVectorTransferFullPartialSplitPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferFullPartialSplitPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-full-partial-split";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to split "
|
|
"transfer ops via scf.if + linalg ops";
|
|
}
|
|
TestVectorTransferFullPartialSplitPatterns() = default;
|
|
TestVectorTransferFullPartialSplitPatterns(
|
|
const TestVectorTransferFullPartialSplitPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
|
|
scf::SCFDialect>();
|
|
}
|
|
|
|
Option<bool> useLinalgOps{
|
|
*this, "use-memref-copy",
|
|
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
|
|
"memref.copy operations."),
|
|
llvm::cl::init(false)};
|
|
void runOnOperation() override {
|
|
MLIRContext *ctx = &getContext();
|
|
RewritePatternSet patterns(ctx);
|
|
VectorTransformsOptions options;
|
|
if (useLinalgOps)
|
|
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
|
|
else
|
|
options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer);
|
|
patterns.add<VectorTransferFullPartialRewriter>(ctx, options);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferOpt
|
|
: public PassWrapper<TestVectorTransferOpt, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferOpt)
|
|
|
|
StringRef getArgument() const final { return "test-vector-transferop-opt"; }
|
|
StringRef getDescription() const final {
|
|
return "Test optimization transformations for transfer ops";
|
|
}
|
|
void runOnOperation() override { transferOpflowOpt(getOperation()); }
|
|
};
|
|
|
|
struct TestVectorTransferLoweringPatterns
|
|
: public PassWrapper<TestVectorTransferLoweringPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferLoweringPatterns)
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<tensor::TensorDialect, memref::MemRefDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-lowering-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to lower transfer ops to other vector ops";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferLoweringPatterns(patterns);
|
|
populateVectorTransferPermutationMapLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorMultiReductionLoweringPatterns
|
|
: public PassWrapper<TestVectorMultiReductionLoweringPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorMultiReductionLoweringPatterns)
|
|
|
|
TestVectorMultiReductionLoweringPatterns() = default;
|
|
TestVectorMultiReductionLoweringPatterns(
|
|
const TestVectorMultiReductionLoweringPatterns &pass)
|
|
: PassWrapper(pass) {}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
StringRef getArgument() const final {
|
|
return "test-vector-multi-reduction-lowering-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns to lower vector.multi_reduction to other "
|
|
"vector ops";
|
|
}
|
|
Option<bool> useOuterReductions{
|
|
*this, "use-outer-reductions",
|
|
llvm::cl::desc("Move reductions to outer most dimensions"),
|
|
llvm::cl::init(false)};
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorMultiReductionLoweringPatterns(
|
|
patterns, useOuterReductions
|
|
? vector::VectorMultiReductionLowering::InnerParallel
|
|
: vector::VectorMultiReductionLowering::InnerReduction);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferCollapseInnerMostContiguousDims
|
|
: public PassWrapper<TestVectorTransferCollapseInnerMostContiguousDims,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferCollapseInnerMostContiguousDims)
|
|
|
|
TestVectorTransferCollapseInnerMostContiguousDims() = default;
|
|
TestVectorTransferCollapseInnerMostContiguousDims(
|
|
const TestVectorTransferCollapseInnerMostContiguousDims &pass) = default;
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect, AffineDialect>();
|
|
}
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-collapse-inner-most-dims";
|
|
}
|
|
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that reducedes the rank of the vector "
|
|
"transfer memory and vector operands.";
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorReduceToContractPatternsPatterns
|
|
: public PassWrapper<TestVectorReduceToContractPatternsPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorReduceToContractPatternsPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-reduction-to-contract-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test patterns to convert multireduce op to contract and combine "
|
|
"broadcast/transpose to contract";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorReductionToContractPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorTransferDropUnitDimsPatterns
|
|
: public PassWrapper<TestVectorTransferDropUnitDimsPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestVectorTransferDropUnitDimsPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-drop-unit-dims-patterns";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorTransferDropUnitDimsPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestFlattenVectorTransferPatterns
|
|
: public PassWrapper<TestFlattenVectorTransferPatterns,
|
|
OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
|
|
TestFlattenVectorTransferPatterns)
|
|
|
|
StringRef getArgument() const final {
|
|
return "test-vector-transfer-flatten-patterns";
|
|
}
|
|
StringRef getDescription() const final {
|
|
return "Test patterns to rewrite contiguous row-major N-dimensional "
|
|
"vector.transfer_{read,write} ops into 1D transfers";
|
|
}
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<memref::MemRefDialect>();
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateFlattenVectorTransferPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
struct TestVectorScanLowering
|
|
: public PassWrapper<TestVectorScanLowering, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorScanLowering)
|
|
|
|
StringRef getArgument() const final { return "test-vector-scan-lowering"; }
|
|
StringRef getDescription() const final {
|
|
return "Test lowering patterns that lower the scan op in the vector "
|
|
"dialect";
|
|
}
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
populateVectorScanLoweringPatterns(patterns);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
};
|
|
|
|
/// Allocate shared memory for a single warp to test lowering of
|
|
/// WarpExecuteOnLane0Op.
|
|
static Value allocateGlobalSharedMemory(Location loc, OpBuilder &builder,
|
|
WarpExecuteOnLane0Op warpOp,
|
|
Type type) {
|
|
static constexpr int64_t kSharedMemorySpace = 3;
|
|
// Compute type of shared memory buffer.
|
|
MemRefType memrefType;
|
|
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
|
memrefType =
|
|
MemRefType::get(vectorType.getShape(), vectorType.getElementType(), {},
|
|
kSharedMemorySpace);
|
|
} else {
|
|
memrefType = MemRefType::get({1}, type, {}, kSharedMemorySpace);
|
|
}
|
|
|
|
// Get symbol table holding all shared memory globals.
|
|
ModuleOp moduleOp = warpOp->getParentOfType<ModuleOp>();
|
|
SymbolTable symbolTable(moduleOp);
|
|
|
|
// Create a pretty name.
|
|
SmallString<64> buf;
|
|
llvm::raw_svector_ostream os(buf);
|
|
interleave(memrefType.getShape(), os, "x");
|
|
os << "x" << memrefType.getElementType();
|
|
std::string symbolName = (Twine("__shared_") + os.str()).str();
|
|
|
|
auto ip = builder.saveInsertionPoint();
|
|
builder.setInsertionPoint(moduleOp);
|
|
auto global = builder.create<memref::GlobalOp>(
|
|
loc,
|
|
/*sym_name=*/symbolName,
|
|
/*sym_visibility=*/builder.getStringAttr("private"),
|
|
/*type=*/memrefType,
|
|
/*initial_value=*/Attribute(),
|
|
/*constant=*/false,
|
|
/*alignment=*/IntegerAttr());
|
|
symbolTable.insert(global);
|
|
// The symbol table inserts at the end of the module, but globals are a bit
|
|
// nicer if they are at the beginning.
|
|
global->moveBefore(&moduleOp.front());
|
|
|
|
builder.restoreInsertionPoint(ip);
|
|
return builder.create<memref::GetGlobalOp>(loc, memrefType, symbolName);
|
|
}
|
|
|
|
/// Emit a reduction of `input` across `size` lanes: for each power-of-two
/// offset below `size`, combine the running value with the value obtained
/// from a gpu.shuffle in XOR mode at that offset, using `kind` as the
/// combining operation. Returns the final combined value.
static Value warpReduction(Location loc, OpBuilder &builder, Value input,
                           CombiningKind kind, uint32_t size) {
  Value accumulated = input;
  // Parallel reduction using butterfly shuffles.
  uint64_t offset = 1;
  while (offset < size) {
    auto shuffleOp = builder.create<gpu::ShuffleOp>(
        loc, accumulated, offset,
        /*width=*/size,
        /*mode=*/gpu::ShuffleMode::XOR);
    Value partner = shuffleOp.getShuffleResult();
    accumulated = makeArithReduction(builder, loc, kind, accumulated, partner);
    offset <<= 1;
  }
  return accumulated;
}
|
|
|
|
struct TestVectorDistribution
|
|
: public PassWrapper<TestVectorDistribution, OperationPass<func::FuncOp>> {
|
|
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorDistribution)
|
|
|
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
|
registry.insert<scf::SCFDialect, memref::MemRefDialect, gpu::GPUDialect,
|
|
AffineDialect>();
|
|
}
|
|
|
|
StringRef getArgument() const final { return "test-vector-warp-distribute"; }
|
|
StringRef getDescription() const final {
|
|
return "Test vector warp distribute transformation and lowering patterns";
|
|
}
|
|
TestVectorDistribution() = default;
|
|
TestVectorDistribution(const TestVectorDistribution &pass)
|
|
: PassWrapper(pass) {}
|
|
|
|
Option<bool> warpOpToSCF{
|
|
*this, "rewrite-warp-ops-to-scf-if",
|
|
llvm::cl::desc("Lower vector.warp_execute_on_lane0 to scf.if op"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> distributeTransferWriteOps{
|
|
*this, "distribute-transfer-write",
|
|
llvm::cl::desc("Test distribution of transfer write"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> hoistUniform{*this, "hoist-uniform",
|
|
llvm::cl::desc("Test hoist uniform"),
|
|
llvm::cl::init(false)};
|
|
|
|
Option<bool> propagateDistribution{
|
|
*this, "propagate-distribution",
|
|
llvm::cl::desc("Test distribution propgation"), llvm::cl::init(false)};
|
|
|
|
void runOnOperation() override {
|
|
RewritePatternSet patterns(&getContext());
|
|
|
|
getOperation().walk([&](Operation *op) {
|
|
if (auto warpOp = dyn_cast<WarpExecuteOnLane0Op>(op)) {
|
|
if (hoistUniform) {
|
|
moveScalarUniformCode(warpOp);
|
|
}
|
|
WalkResult::interrupt();
|
|
}
|
|
});
|
|
MLIRContext *ctx = &getContext();
|
|
if (distributeTransferWriteOps) {
|
|
auto distributionFn = [](vector::TransferWriteOp writeOp) {
|
|
// Create a map (d0, d1) -> (d1) to distribute along the inner
|
|
// dimension. Once we support n-d distribution we can add more
|
|
// complex cases.
|
|
int64_t vecRank = writeOp.getVectorType().getRank();
|
|
OpBuilder builder(writeOp.getContext());
|
|
auto map =
|
|
AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
|
|
return map;
|
|
};
|
|
RewritePatternSet patterns(ctx);
|
|
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
if (propagateDistribution) {
|
|
RewritePatternSet patterns(ctx);
|
|
vector::populatePropagateWarpVectorDistributionPatterns(patterns);
|
|
vector::populateDistributeReduction(patterns, warpReduction);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
}
|
|
WarpExecuteOnLane0LoweringOptions options;
|
|
options.warpAllocationFn = allocateGlobalSharedMemory;
|
|
options.warpSyncronizationFn = [](Location loc, OpBuilder &builder,
|
|
WarpExecuteOnLane0Op warpOp) {
|
|
builder.create<gpu::BarrierOp>(loc);
|
|
};
|
|
// Test on one pattern in isolation.
|
|
if (warpOpToSCF) {
|
|
populateWarpExecuteOnLane0OpToScfForPattern(patterns, options);
|
|
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
|
return;
|
|
}
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
namespace mlir {
|
|
namespace test {
|
|
void registerTestVectorLowerings() {
|
|
PassRegistration<TestVectorToVectorLowering>();
|
|
|
|
PassRegistration<TestVectorContractionLowering>();
|
|
|
|
PassRegistration<TestVectorTransposeLowering>();
|
|
|
|
PassRegistration<TestVectorUnrollingPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferUnrollingPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferFullPartialSplitPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferOpt>();
|
|
|
|
PassRegistration<TestVectorTransferLoweringPatterns>();
|
|
|
|
PassRegistration<TestVectorMultiReductionLoweringPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferCollapseInnerMostContiguousDims>();
|
|
|
|
PassRegistration<TestVectorReduceToContractPatternsPatterns>();
|
|
|
|
PassRegistration<TestVectorTransferDropUnitDimsPatterns>();
|
|
|
|
PassRegistration<TestFlattenVectorTransferPatterns>();
|
|
|
|
PassRegistration<TestVectorScanLowering>();
|
|
|
|
PassRegistration<TestVectorDistribution>();
|
|
}
|
|
} // namespace test
|
|
} // namespace mlir
|