
More targeted than a blanket "apply everywhere" pattern. Follow-up to #89075 to address @ftynse's feedback.
//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
|
|
#define DEBUG_TYPE "linalg-transpose-matmul"
|
|
|
|
using namespace mlir;
using namespace mlir::linalg;

/// Pattern to replace
///
///   linalg.matmul(a, b)
///
/// with
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose RHS instead.
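///
/// For example (illustrative IR only; value names and static shapes are
/// placeholders, not taken from a test):
///
///   %0 = linalg.matmul ins(%a, %b : tensor<16x8xf32>, tensor<8x32xf32>)
///                      outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>
///
/// becomes, with `transposeLHS=true`:
///
///   %empty = tensor.empty() : tensor<8x16xf32>
///   %at = linalg.transpose ins(%a : tensor<16x8xf32>)
///                          outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
///   %0 = linalg.matmul_transpose_a
///            ins(%at, %b : tensor<8x16xf32>, tensor<8x32xf32>)
///            outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>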
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

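  // Collect sizes for any dynamic dimensions in the order of the transposed
  // shape ({dim 1, dim 0}), since the tensor.empty created below expects its
  // dynamic sizes in result-shape order.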
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
      dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{1, 0});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
        matmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(),
        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
        matmulOp.getOutputs());
  }
  rewriter.replaceOp(matmulOp, newMatmulOp);
  return newMatmulOp;
}

/// Pattern to replace
///
///   linalg.batch_matmul(a, b)
///
/// with
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose RHS instead.
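///
/// For example (illustrative IR only; value names and static shapes are
/// placeholders):
///
///   %0 = linalg.batch_matmul
///            ins(%a, %b : tensor<4x16x8xf32>, tensor<4x8x32xf32>)
///            outs(%c : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>
///
/// becomes, with `transposeLHS=true`:
///
///   %empty = tensor.empty() : tensor<4x8x16xf32>
///   %at = linalg.transpose ins(%a : tensor<4x16x8xf32>)
///                          outs(%empty : tensor<4x8x16xf32>)
///                          permutation = [0, 2, 1]
///   %0 = linalg.batch_matmul_transpose_a
///            ins(%at, %b : tensor<4x8x16xf32>, tensor<4x8x32xf32>)
///            outs(%c : tensor<4x16x32xf32>) -> tensor<4x16x32xf32>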
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto type = cast<ShapedType>(input.getType());

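  // As above, collect sizes for any dynamic dimensions in the order of the
  // transposed shape ({dim 0, dim 2, dim 1}), matching the tensor.empty
  // created below.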
  SmallVector<Value> dynamicDims;
  if (type.isDynamicDim(0))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
  if (type.isDynamicDim(2))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
  if (type.isDynamicDim(1))
    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));

  ArrayRef<int64_t> shape = type.getShape();
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
      type.getElementType(), dynamicDims);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
  Operation *newMatmulOp;
  if (transposeLHS) {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
        batchMatmulOp.getOutputs());
  } else {
    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
        batchMatmulOp.getOutputs());
  }
  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
  return newMatmulOp;
}

namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};

struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}

  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
      return failure();
    }
    return success();
  }

private:
  bool transposeLHS;
};
} // namespace

void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  patterns.add<TransposeMatmul, TransposeBatchMatmul>(patterns.getContext(),
                                                      transposeLHS);
}
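
// A minimal usage sketch (illustrative only, not part of the transformation
// itself): inside a pass's runOnOperation(), the patterns might be applied
// with the greedy driver; `transposeLHS` would typically come from a pass
// option.
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateTransposeMatmulPatterns(patterns, /*transposeLHS=*/true);
//   if (failed(
//           applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();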