[mlir][linalg] Move transpose_matmul to targeted transform op (#89717)

This is more targeted than a blanket "apply everywhere" pattern. Follow-up to
#89075 to address @ftynse's feedback.
Cullen Rhodes 2024-04-23 10:52:50 +01:00 committed by GitHub
parent 719112c2f6
commit be1c72d2ba
6 changed files with 180 additions and 104 deletions


@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}
def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
"apply_patterns.linalg.transpose_matmul",
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
let description = [{
Collects patterns to convert Linalg matmul ops to transposed variants.
By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
transpose the RHS matrix instead.
}];
let arguments = (ins
DefaultValuedAttr<TransposeMatmulInput,
"TransposeMatmulInput::lhs">:$inputToTranspose);
let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
}
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
}];
}
//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//
def TransposeMatmulOp : Op<Transform_Dialect,
"structured.transpose_matmul",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Convert Linalg matmul ops to transposed variants.
By default the LHS matrix is transposed. Specify `<rhs>` to transpose
the RHS matrix instead.
#### Return modes:
This operation fails if `target` is unsupported, i.e., not a
`linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
and returns a handle to the transposed matmul op.
}];
let arguments = (ins
TransformHandleTypeInterface:$target,
DefaultValuedAttr<TransposeMatmulInput,
"TransposeMatmulInput::lhs">:$inputToTranspose);
let results = (outs TransformHandleTypeInterface:$transformed);
let assemblyFormat = [{
$target (`<` $inputToTranspose^ `>`)?
attr-dict `:` functional-type($target, results)
}];
let builders = [
OpBuilder<(ins "Value":$target)>
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}
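For illustration, a transform script driving the new op could look as follows
(a minimal sketch mirroring the updated tests later in this commit; the match
step is taken from those tests):

  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    // Transpose the RHS instead of the default LHS.
    %transposed = transform.structured.transpose_matmul %matmul <rhs>
      : (!transform.any_op) -> (!transform.any_op)
    transform.yield
  }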
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//


@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcQOp op);
/// Convert Linalg matmul ops to transposed variants.
FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp op,
bool transposeLHS = true);
FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);
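As a sketch of the rewrite these entry points perform (value names and shapes
below are illustrative, not taken from this commit), transposeMatmul with the
default transposeLHS=true turns

  %0 = linalg.matmul ins(%a, %b : tensor<16x8xf32>, tensor<8x32xf32>)
                     outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>

into roughly

  %empty = tensor.empty() : tensor<8x16xf32>
  %at = linalg.transpose ins(%a : tensor<16x8xf32>)
                         outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
  %0 = linalg.matmul_transpose_a
         ins(%at, %b : tensor<8x16xf32>, tensor<8x32xf32>)
         outs(%c : tensor<16x32xf32>) -> tensor<16x32xf32>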
//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a


@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
}
void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
}
//===----------------------------------------------------------------------===//
// BufferizeToAllocationOp
//===----------------------------------------------------------------------===//
@@ -3422,6 +3416,32 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// TransposeMatmulOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
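// Only linalg.matmul and linalg.batch_matmul are handled; any other target
// falls through to the Default case below and is reported as unsupported.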
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case([&](linalg::MatmulOp op) {
return transposeMatmul(rewriter, op, transposeLHS);
})
.Case([&](linalg::BatchMatmulOp op) {
return transposeBatchMatmul(rewriter, op, transposeLHS);
})
.Default([&](Operation *op) { return failure(); });
if (failed(maybeTransformed))
return emitSilenceableFailure(target->getLoc()) << "not supported";
// Handle to the new matmul operation with the transposed input
results.push_back(*maybeTransformed);
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// InsertSliceToCopyOp
//===----------------------------------------------------------------------===//


@@ -18,7 +18,6 @@
using namespace mlir;
using namespace mlir::linalg;
namespace {
/// Pattern to replace
///
/// linalg.matmul(a, b)
@@ -29,50 +28,44 @@ namespace {
///
/// By default the LHS is transposed. Set `transposeLHS=false` to
/// transpose the RHS instead.
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
linalg::MatmulOp matmulOp,
bool transposeLHS) {
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");
LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(matmulOp))
return rewriter.notifyMatchFailure(
matmulOp, "only matmul ops with tensors are supported");
Location loc = matmulOp.getLoc();
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());
Location loc = matmulOp.getLoc();
Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
matmulOp, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
}
return success();
ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
}
private:
bool transposeLHS;
};
rewriter.replaceOp(matmulOp, newMatmulOp);
return newMatmulOp;
}
/// Pattern to replace
///
@@ -84,47 +77,75 @@ private:
///
/// Only the non-batch dimensions are transposed. By default the LHS is
/// transposed. Set `transposeLHS=false` to transpose the RHS instead.
FailureOr<Operation *>
mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp batchMatmulOp,
bool transposeLHS) {
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");
Location loc = batchMatmulOp.getLoc();
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (type.isDynamicDim(2))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
}
rewriter.replaceOp(batchMatmulOp, newMatmulOp);
return newMatmulOp;
}
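Likewise for the batch case (shapes again illustrative): with the default
transposeLHS=true, a linalg.batch_matmul on 2x16x8 and 2x8x32 operands
becomes roughly

  %empty = tensor.empty() : tensor<2x8x16xf32>
  %at = linalg.transpose ins(%a : tensor<2x16x8xf32>)
                         outs(%empty : tensor<2x8x16xf32>) permutation = [0, 2, 1]
  %0 = linalg.batch_matmul_transpose_a
         ins(%at, %b : tensor<2x8x16xf32>, tensor<2x8x32xf32>)
         outs(%c : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>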
namespace {
struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
LogicalResult matchAndRewrite(linalg::MatmulOp op,
PatternRewriter &rewriter) const override {
if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
return failure();
}
return success();
}
private:
bool transposeLHS;
};
struct TransposeBatchMatmul final
: public OpRewritePattern<linalg::BatchMatmulOp> {
TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
: OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
PatternRewriter &rewriter) const override {
if (!bufferization::hasTensorSemantics(batchMatmulOp))
return rewriter.notifyMatchFailure(
batchMatmulOp, "only matmul ops with tensors are supported");
Location loc = batchMatmulOp.getLoc();
Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
auto type = cast<ShapedType>(input.getType());
SmallVector<Value> dynamicDims;
if (type.isDynamicDim(0))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
if (type.isDynamicDim(2))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
if (type.isDynamicDim(1))
dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
ArrayRef<int64_t> shape = type.getShape();
Value empty = rewriter.create<tensor::EmptyOp>(
loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
type.getElementType(), dynamicDims);
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
if (transposeLHS) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
batchMatmulOp, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
return failure();
}
return success();
}


@@ -2,10 +2,9 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%transposed = transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization


@@ -2,10 +2,9 @@
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%transposed = transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
%0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.linalg.transpose_matmul <rhs>
} : !transform.any_op
transform.apply_cse to %0 : !transform.any_op
transform.apply_patterns to %0 {
transform.apply_patterns.canonicalization