335 lines
14 KiB
C++
335 lines
14 KiB
C++
//===- TestLinalgTransforms.cpp - Test Linalg transformation patterns -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements logic for testing Linalg transformations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
namespace {
|
|
struct TestLinalgTransforms
|
|
: public PassWrapper<TestLinalgTransforms, FunctionPass> {
|
|
TestLinalgTransforms() = default;
|
|
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
|
|
|
|
void runOnFunction() override;
|
|
|
|
Option<bool> testPatterns{*this, "test-patterns",
|
|
llvm::cl::desc("Test a mixed set of patterns"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testMatmulToVectorPatterns1dTiling{
|
|
*this, "test-matmul-to-vector-patterns-tile-1d",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
|
"1-d tiling"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testMatmulToVectorPatterns2dTiling{
|
|
*this, "test-matmul-to-vector-patterns-tile-2d",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that applies patterns from matmul to vectors via "
|
|
"2-d tiling"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
|
|
llvm::cl::desc("Test promotion options"),
|
|
llvm::cl::init(false)};
|
|
Option<bool> testVectorTransferForwardingPatterns{
|
|
*this, "test-vector-transfer-forwarding-patterns",
|
|
llvm::cl::desc(
|
|
"Test a fused pass that forwards linalg.copy to vector.transfer"),
|
|
llvm::cl::init(false)};
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Populates the mixed pattern set used by `-test-patterns`, applies it
/// greedily to `funcOp`, and then removes the Linalg transformation marker
/// attribute from every Linalg op in the function.
///
/// Patterns are chained through LinalgMarker: the first argument lists the
/// marker(s) a pattern matches and the second the marker it sets, e.g.
/// MEM -> L3 -> L2 -> L1 -> REG for the matmul tiling chain below.
/// Insertion order is preserved from the original list.
static void applyPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  // Uniques a marker name in this function's context.
  auto id = [ctx](StringRef name) { return Identifier::get(name, ctx); };
  OwningRewritePatternList patterns;

  //===--------------------------------------------------------------------===//
  // Linalg tiling patterns.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}),
      LinalgMarker(id("MEM"), id("L3")));
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}),
      LinalgMarker(id("L3"), id("L2")));
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgMarker(id("L2"), id("L1")));
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}),
      LinalgMarker(id("L1"), id("REG")));

  // Matvec tiled into parallel loops; matches with an empty marker list.
  patterns.insert<LinalgTilingPattern<MatvecOp>>(
      ctx,
      LinalgTilingOptions().setTileSizes({5, 6}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      LinalgMarker({}, id("L1")));

  // Dot is tiled once from any of MEM, L3 or L2 straight to REG.
  patterns.insert<LinalgTilingPattern<DotOp>>(
      ctx, LinalgTilingOptions().setTileSizes(8000),
      LinalgMarker(ArrayRef<Identifier>{id("MEM"), id("L3"), id("L2")},
                   id("REG")));

  //===--------------------------------------------------------------------===//
  // Linalg tiling and permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({2000, 3000, 4000})
          .setInterchange({1, 2, 0}),
      LinalgMarker(id("__with_perm__"), id("L2__with_perm__")));
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({200, 300, 400})
          .setInterchange({1, 0, 2}),
      LinalgMarker(id("L2__with_perm__"), id("L1__with_perm__")));
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}),
      LinalgMarker(id("L1__with_perm__"), id("REG__with_perm__")));

  patterns.insert<LinalgTilingPattern<MatvecOp>>(
      ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}),
      LinalgMarker(id("__with_perm__"), id("L1__with_perm__")));

  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx,
      LinalgTilingOptions()
          .setTileSizes({16, 8, 4})
          .setInterchange({1, 2, 0})
          .setLoopType(LinalgTilingLoopType::ParallelLoops),
      LinalgMarker(id("par__with_perm__"), id("after_par__with_perm__")));

  //===--------------------------------------------------------------------===//
  // Linalg to loops patterns.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgLoweringPattern<DotOp>>(
      ctx,
      /*loweringType=*/LinalgLoweringType::Loops, LinalgMarker(id("REG")));

  //===--------------------------------------------------------------------===//
  // Linalg to vector contraction patterns.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgVectorizationPattern<MatmulOp>,
                  LinalgVectorizationPattern<FillOp>,
                  LinalgVectorizationPattern<GenericOp>>(
      ctx, LinalgMarker(id("VECTORIZE")));

  //===--------------------------------------------------------------------===//
  // Linalg generic permutation patterns.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgInterchangePattern<GenericOp>>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgMarker({}, id("PERMUTED")));
  patterns.insert<LinalgInterchangePattern<IndexedGenericOp>>(
      ctx,
      /*interchangeVector=*/ArrayRef<unsigned>{1, 2, 0},
      LinalgMarker({}, id("PERMUTED")));

  //===--------------------------------------------------------------------===//
  // Linalg subview operands promotion.
  //===--------------------------------------------------------------------===//
  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
      ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
      LinalgMarker(id("_promote_views_"), id("_views_promoted_")));
  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffersByDefault(true),
      LinalgMarker(id("_promote_first_view_"), id("_first_view_promoted_")));
  patterns.insert<LinalgPromotionPattern<FillOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0})
          .setUseFullTileBuffers({true})
          .setAlignment(32),
      LinalgMarker(id("_promote_views_aligned_"),
                   id("_views_aligned_promoted_")));

  applyPatternsAndFoldGreedily(funcOp, patterns);

  // Strip the marker so it does not appear in the output IR.
  funcOp.walk([](LinalgOp op) {
    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
  });
}
|
|
|
|
/// Appends three pattern stages to `patternsVector` that take a matmul marked
/// `startMarker` down to vector form:
///   1. tile by {8, 12, 16} with interchange {1, 0, 2}: startMarker -> "L1";
///   2. promote subviews using full-tile buffers:      "L1" -> "VEC";
///   3. vectorize matmuls marked "VEC"; the same stage also receives fill and
///      copy vectorization patterns inserted without an explicit marker.
static void fillL1TilingAndMatmulToVectorPatterns(
    FuncOp funcOp, StringRef startMarker,
    SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
  MLIRContext *ctx = funcOp.getContext();
  Identifier startId = Identifier::get(startMarker, ctx);
  Identifier l1Id = Identifier::get("L1", ctx);
  Identifier vecId = Identifier::get("VEC", ctx);

  // Stage 1: L1 tiling.
  LinalgTilingOptions l1Tiling =
      LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2});
  patternsVector.emplace_back(LinalgTilingPattern<MatmulOp>(
      ctx, l1Tiling, LinalgMarker(startId, l1Id)));

  // Stage 2: promotion of the tiled subviews.
  LinalgPromotionOptions l1Promotion =
      LinalgPromotionOptions().setUseFullTileBuffersByDefault(true);
  patternsVector.emplace_back(LinalgPromotionPattern<MatmulOp>(
      ctx, l1Promotion, LinalgMarker(l1Id, vecId)));

  // Stage 3: vectorization.
  patternsVector.emplace_back(
      LinalgVectorizationPattern<MatmulOp>(ctx, LinalgMarker(vecId)));
  OwningRewritePatternList &vectorizeStage = patternsVector.back();
  vectorizeStage.insert<LinalgVectorizationPattern<FillOp>,
                        LinalgVectorizationPattern<CopyOp>>(ctx);
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Test promotion callbacks
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
// Allocation call back
|
|
static Optional<Value> allocCallBackFn(OpBuilder &b, SubViewOp subView,
|
|
ArrayRef<Value> boundingSubViewSize,
|
|
OperationFolder *folder) {
|
|
SmallVector<int64_t, 4> shape(boundingSubViewSize.size(), -1);
|
|
return b
|
|
.create<AllocOp>(subView.getLoc(),
|
|
MemRefType::get(shape,
|
|
subView.getType().getElementType(),
|
|
/*affineMapComposition =*/{}, 3),
|
|
boundingSubViewSize)
|
|
.getResult();
|
|
}
|
|
|
|
// Deallocation callback
|
|
static LogicalResult deallocCallBackFn(OpBuilder &b, Value buffer) {
|
|
b.create<DeallocOp>(buffer.getLoc(), buffer);
|
|
return success();
|
|
}
|
|
|
|
// Copy in call back
|
|
static LogicalResult copyCallBackFn(OpBuilder &b, Value src, Value dst,
|
|
bool isOutput) {
|
|
auto floatType = src.getType().cast<MemRefType>().getElementType();
|
|
if (!floatType.isa<FloatType>())
|
|
return failure();
|
|
if (!isOutput)
|
|
b.create<FillOp>(
|
|
src.getLoc(), dst,
|
|
b.create<ConstantOp>(src.getLoc(), FloatAttr::get(floatType, 42.0)));
|
|
b.create<CopyOp>(src.getLoc(), src, dst);
|
|
return success();
|
|
}
|
|
|
|
/// Populates `patterns` with the patterns exercised by
/// `-test-linalg-promotion-options`:
///   - tile matmuls marked "START" by {16, 16, 16} and re-mark them "PROMOTE";
///   - promote operands 0 and 2 of matmuls marked "PROMOTE" (neither with a
///     full-tile buffer) using allocCallBackFn / deallocCallBackFn for buffer
///     management and copyCallBackFn for copy-in and copy-out.
void fillPromotionCallBackPatterns(MLIRContext *ctx,
                                   OwningRewritePatternList &patterns) {
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}),
      LinalgMarker(Identifier::get("START", ctx),
                   Identifier::get("PROMOTE", ctx)));
  patterns.insert<LinalgPromotionPattern<MatmulOp>>(
      ctx,
      LinalgPromotionOptions()
          .setOperandsToPromote({0, 2})
          .setUseFullTileBuffers({false, false})
          .setAllocationDeallocationFns(allocCallBackFn, deallocCallBackFn)
          .setCopyInOutFns(
              // Forward the callback's status: copyCallBackFn fails for
              // non-float element types, and that failure must not be
              // reported as success.
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, /*isOutput=*/false);
              },
              [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
                return copyCallBackFn(b, src, dst, /*isOutput=*/true);
              }),
      LinalgMarker(Identifier::get("PROMOTE", ctx)));
}
|
|
|
|
static void
|
|
applyMatmulToVectorPatterns(FuncOp funcOp,
|
|
bool testMatmulToVectorPatterns1dTiling,
|
|
bool testMatmulToVectorPatterns2dTiling) {
|
|
MLIRContext *ctx = funcOp.getContext();
|
|
SmallVector<OwningRewritePatternList, 4> stage1Patterns;
|
|
if (testMatmulToVectorPatterns1dTiling) {
|
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
|
|
stage1Patterns);
|
|
} else if (testMatmulToVectorPatterns2dTiling) {
|
|
stage1Patterns.emplace_back(LinalgTilingPattern<MatmulOp>(
|
|
ctx,
|
|
LinalgTilingOptions()
|
|
.setTileSizes({768, 264, 768})
|
|
.setInterchange({1, 2, 0}),
|
|
LinalgMarker(Identifier::get("START", ctx),
|
|
Identifier::get("L2", ctx))));
|
|
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
|
|
stage1Patterns);
|
|
}
|
|
OwningRewritePatternList stage2Patterns =
|
|
getLinalgTilingCanonicalizationPatterns(ctx);
|
|
applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns);
|
|
}
|
|
|
|
/// Greedily applies the linalg.copy forwarding patterns
/// (LinalgCopyVTRForwardingPattern, then LinalgCopyVTWForwardingPattern) to
/// `funcOp`.
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
  MLIRContext *ctx = funcOp.getContext();
  OwningRewritePatternList forwardingPatterns;
  forwardingPatterns
      .insert<LinalgCopyVTRForwardingPattern, LinalgCopyVTWForwardingPattern>(
          ctx);
  applyPatternsAndFoldGreedily(funcOp, forwardingPatterns);
}
|
|
|
|
/// Apply transformations specified as patterns.
|
|
void TestLinalgTransforms::runOnFunction() {
|
|
auto lambda = [&](void *) {
|
|
getFunction().walk([](LinalgOp op) {
|
|
op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
|
|
});
|
|
};
|
|
std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
|
|
|
|
if (testPromotionOptions) {
|
|
OwningRewritePatternList patterns;
|
|
fillPromotionCallBackPatterns(&getContext(), patterns);
|
|
applyPatternsAndFoldGreedily(getFunction(), patterns);
|
|
return;
|
|
}
|
|
if (testPatterns)
|
|
return applyPatterns(getFunction());
|
|
if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
|
|
return applyMatmulToVectorPatterns(getFunction(),
|
|
testMatmulToVectorPatterns1dTiling,
|
|
testMatmulToVectorPatterns2dTiling);
|
|
if (testVectorTransferForwardingPatterns)
|
|
return applyVectorTransferForwardingPatterns(getFunction());
|
|
}
|
|
|
|
namespace mlir {
|
|
void registerTestLinalgTransforms() {
|
|
PassRegistration<TestLinalgTransforms> testTransformPatternsPass(
|
|
"test-linalg-transform-patterns",
|
|
"Test Linalg transformation patterns by applying them greedily.");
|
|
}
|
|
} // namespace mlir
|