Cullen Rhodes be1c72d2ba
[mlir][linalg] Move transpose_matmul to targeted transform op (#89717)
More targeted than a blanket "apply everywhere" pattern. Follow up to
#89075 to address @ftynse's feedback.
2024-04-23 10:52:50 +01:00

162 lines
5.9 KiB
C++

//===- TransposeMatmul.cpp - Convert Linalg matmul to transposed variants -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This is intended to be a simple high-level (target-agnostic) matmul
// transposition transformation.
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "linalg-transpose-matmul"
using namespace mlir;
using namespace mlir::linalg;
/// Rewrites
///
///   linalg.matmul(a, b)
///
/// into
///
///   linalg.matmul_transpose_a(linalg.transpose(a), b)
///
/// when `transposeLHS` is true, or into
///
///   linalg.matmul_transpose_b(a, linalg.transpose(b))
///
/// when `transposeLHS` is false.
///
/// The transposed operand is materialized into a fresh `tensor.empty` whose
/// two dimensions are swapped relative to the original operand. On success the
/// original op is replaced and the newly created matmul op is returned. Ops
/// without tensor semantics are rejected with a match failure.
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
                                                     linalg::MatmulOp matmulOp,
                                                     bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(matmulOp))
    return rewriter.notifyMatchFailure(
        matmulOp, "only matmul ops with tensors are supported");

  Location loc = matmulOp.getLoc();
  Value operand = matmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto operandType = cast<ShapedType>(operand.getType());

  // Gather the run-time sizes in the order of the *transposed* shape
  // (original dim 1 first, then dim 0) so they line up with the dynamic
  // entries of the `tensor.empty` result type below.
  SmallVector<Value> dynSizes;
  for (unsigned dim : {1u, 0u})
    if (operandType.isDynamicDim(dim))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, dim));

  // Destination tensor for the transpose: same element type, swapped extents.
  ArrayRef<int64_t> operandShape = operandType.getShape();
  Value init = rewriter.create<tensor::EmptyOp>(
      loc, ArrayRef<int64_t>{operandShape[1], operandShape[0]},
      operandType.getElementType(), dynSizes);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, operand, init, ArrayRef<int64_t>{1, 0});

  // Build the transposed matmul variant, substituting the transposed value
  // for whichever operand was selected above.
  Value transposed = transposeOp->getResult(0);
  auto inputs = matmulOp.getInputs();
  Operation *replacement;
  if (transposeLHS)
    replacement = rewriter.create<linalg::MatmulTransposeAOp>(
        loc, matmulOp.getResultTypes(), ValueRange{transposed, inputs[1]},
        matmulOp.getOutputs());
  else
    replacement = rewriter.create<linalg::MatmulTransposeBOp>(
        loc, matmulOp.getResultTypes(), ValueRange{inputs[0], transposed},
        matmulOp.getOutputs());

  rewriter.replaceOp(matmulOp, replacement);
  return replacement;
}
/// Rewrites
///
///   linalg.batch_matmul(a, b)
///
/// into
///
///   linalg.batch_matmul_transpose_a(linalg.transpose(a), b)
///
/// when `transposeLHS` is true, or into
///
///   linalg.batch_matmul_transpose_b(a, linalg.transpose(b))
///
/// when `transposeLHS` is false.
///
/// Only the two non-batch dimensions are swapped (permutation [0, 2, 1]); the
/// leading batch dimension stays in place. On success the original op is
/// replaced and the newly created batch matmul op is returned. Ops without
/// tensor semantics are rejected with a match failure.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
                                   linalg::BatchMatmulOp batchMatmulOp,
                                   bool transposeLHS) {
  if (!bufferization::hasTensorSemantics(batchMatmulOp))
    return rewriter.notifyMatchFailure(
        batchMatmulOp, "only matmul ops with tensors are supported");

  Location loc = batchMatmulOp.getLoc();
  Value operand = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
  auto operandType = cast<ShapedType>(operand.getType());

  // Gather the run-time sizes in the order of the *transposed* shape
  // (original dims 0, 2, 1) so they line up with the dynamic entries of the
  // `tensor.empty` result type below.
  SmallVector<Value> dynSizes;
  for (unsigned dim : {0u, 2u, 1u})
    if (operandType.isDynamicDim(dim))
      dynSizes.push_back(rewriter.create<tensor::DimOp>(loc, operand, dim));

  // Destination tensor for the transpose: same element type, batch extent
  // first, then the two trailing extents swapped.
  ArrayRef<int64_t> operandShape = operandType.getShape();
  Value init = rewriter.create<tensor::EmptyOp>(
      loc,
      ArrayRef<int64_t>{operandShape[0], operandShape[2], operandShape[1]},
      operandType.getElementType(), dynSizes);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, operand, init, ArrayRef<int64_t>{0, 2, 1});

  // Build the transposed batch matmul variant, substituting the transposed
  // value for whichever operand was selected above.
  Value transposed = transposeOp->getResult(0);
  auto inputs = batchMatmulOp.getInputs();
  Operation *replacement;
  if (transposeLHS)
    replacement = rewriter.create<linalg::BatchMatmulTransposeAOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{transposed, inputs[1]}, batchMatmulOp.getOutputs());
  else
    replacement = rewriter.create<linalg::BatchMatmulTransposeBOp>(
        loc, batchMatmulOp.getResultTypes(),
        ValueRange{inputs[0], transposed}, batchMatmulOp.getOutputs());

  rewriter.replaceOp(batchMatmulOp, replacement);
  return replacement;
}
namespace {
/// Rewrite pattern wrapping `transposeMatmul` so it can be registered in a
/// `RewritePatternSet`. The operand to transpose is fixed at construction.
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHSOperand(transposeLHS) {}

  /// Forwards to `transposeMatmul` and reports whether the rewrite applied.
  LogicalResult matchAndRewrite(linalg::MatmulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> rewritten =
        transposeMatmul(rewriter, op, transposeLHSOperand);
    return failed(rewritten) ? failure() : success();
  }

private:
  /// True to transpose the LHS operand, false to transpose the RHS operand.
  bool transposeLHSOperand;
};
/// Rewrite pattern wrapping `transposeBatchMatmul` so it can be registered in
/// a `RewritePatternSet`. The operand to transpose is fixed at construction.
struct TransposeBatchMatmul final
    : public OpRewritePattern<linalg::BatchMatmulOp> {
  TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
      : OpRewritePattern(ctx), transposeLHSOperand(transposeLHS) {}

  /// Forwards to `transposeBatchMatmul` and reports whether the rewrite
  /// applied.
  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                PatternRewriter &rewriter) const override {
    FailureOr<Operation *> rewritten =
        transposeBatchMatmul(rewriter, op, transposeLHSOperand);
    return failed(rewritten) ? failure() : success();
  }

private:
  /// True to transpose the LHS operand, false to transpose the RHS operand.
  bool transposeLHSOperand;
};
} // namespace
/// Adds the matmul and batch-matmul transposition patterns to `patterns`.
/// `transposeLHS` is forwarded to both patterns and selects which operand
/// they transpose.
void mlir::linalg::populateTransposeMatmulPatterns(RewritePatternSet &patterns,
                                                   bool transposeLHS) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<TransposeMatmul>(ctx, transposeLHS);
  patterns.add<TransposeBatchMatmul>(ctx, transposeLHS);
}