[mlir][linalg] Move transpose_matmul to targeted transform op (#89717)
This is more targeted than a blanket "apply everywhere" pattern. Follow-up to #89075, addressing @ftynse's feedback.
This commit is contained in:
parent
719112c2f6
commit
be1c72d2ba
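
For context, the new op is driven from a transform script rather than applied blanket-style. A minimal sequence, condensed from the tests updated in this commit, that matches matmul ops and rewrites them in place:

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
      transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
      transform.yield
    }
  }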
@@ -73,23 +73,6 @@ def ApplyTilingCanonicalizationPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
-def ApplyTransposeMatmulPatternsOp : Op<Transform_Dialect,
-    "apply_patterns.linalg.transpose_matmul",
-    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
-  let description = [{
-    Collects patterns to convert Linalg matmul ops to transposed variants.
-
-    By default the LHS matrix is transposed. Set `inputToTranspose=<rhs>` to
-    instead transpose RHS matrix.
-  }];
-
-  let arguments = (ins
-    DefaultValuedAttr<TransposeMatmulInput,
-                      "TransposeMatmulInput::lhs">:$inputToTranspose);
-
-  let assemblyFormat = "(`<` $inputToTranspose^ `>`)? attr-dict";
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -2429,6 +2412,52 @@ def TransposeConv2DOp : Op<Transform_Dialect,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+def TransposeMatmulOp : Op<Transform_Dialect,
+    "structured.transpose_matmul",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Convert Linalg matmul ops to transposed variants.
+
+    By default the LHS matrix is transposed. Specify `<rhs>` to instead
+    transpose RHS matrix.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported, i.e., not a
+    `linalg.matmul` or `linalg.batch_matmul`. Otherwise, the operation succeeds
+    and returns a handle to the transposed matmul op.
+  }];
+
+  let arguments = (ins
+    TransformHandleTypeInterface:$target,
+    DefaultValuedAttr<TransposeMatmulInput,
+                      "TransposeMatmulInput::lhs">:$inputToTranspose);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = [{
+    $target (`<` $inputToTranspose^ `>`)?
+    attr-dict `:` functional-type($target, results)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
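
Per the assembly format above, the matrix to transpose is selected with an optional angle-bracketed attribute; the default is the LHS, and the RHS variant (exercised by the second test below) reads:

  transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)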
@@ -1244,6 +1244,14 @@ FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter,
                                        linalg::Conv2DNhwcFhwcQOp op);
 
+/// Convert Linalg matmul ops to transposed variants.
+FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter,
+                                       linalg::MatmulOp op,
+                                       bool transposeLHS = true);
+FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
+                                            linalg::BatchMatmulOp op,
+                                            bool transposeLHS = true);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
@@ -199,12 +199,6 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns(
   linalg::populateLinalgTilingCanonicalizationPatterns(patterns);
 }
 
-void transform::ApplyTransposeMatmulPatternsOp::populatePatterns(
-    RewritePatternSet &patterns) {
-  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
-  linalg::populateTransposeMatmulPatterns(patterns, transposeLHS);
-}
-
 //===----------------------------------------------------------------------===//
 // BufferizeToAllocationOp
 //===----------------------------------------------------------------------===//
@@ -3422,6 +3416,32 @@ DiagnosedSilenceableFailure transform::TransposeConv2DOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TransposeMatmulOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::TransposeMatmulOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  bool transposeLHS = getInputToTranspose() == TransposeMatmulInput::lhs;
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::MatmulOp op) {
+            return transposeMatmul(rewriter, op, transposeLHS);
+          })
+          .Case([&](linalg::BatchMatmulOp op) {
+            return transposeBatchMatmul(rewriter, op, transposeLHS);
+          })
+          .Default([&](Operation *op) { return failure(); });
+  if (failed(maybeTransformed))
+    return emitSilenceableFailure(target->getLoc()) << "not supported";
+  // Handle to the new matmul operation with transposed inputs.
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
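
To illustrate the return mode, a hypothetical script that would hit the `.Default` case above — the handle points at an op that is neither `linalg.matmul` nor `linalg.batch_matmul`, so the transform emits the silenceable "not supported" diagnostic rather than a hard error:

  %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
  // Silenceable failure: "not supported", reported at the target op's location.
  transform.structured.transpose_matmul %generic : (!transform.any_op) -> (!transform.any_op)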
@@ -18,7 +18,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-namespace {
 /// Pattern to replace
 ///
 /// linalg.matmul(a, b)
@@ -29,50 +28,44 @@ namespace {
 ///
 /// By default the LHS is transposed. Set `transposeLHS=false` to
 /// transpose RHS instead.
-struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
-  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
-      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
-
-  LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp,
-                                PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(matmulOp))
-      return rewriter.notifyMatchFailure(
-          matmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = matmulOp.getLoc();
-    Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
-        dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{1, 0});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeAOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
-          matmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::MatmulTransposeBOp>(
-          matmulOp, matmulOp.getResultTypes(),
-          ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
-          matmulOp.getOutputs());
-    }
-
-    return success();
-  }
-
-private:
-  bool transposeLHS;
-};
+FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
+                                                     linalg::MatmulOp matmulOp,
+                                                     bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(matmulOp))
+    return rewriter.notifyMatchFailure(
+        matmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = matmulOp.getLoc();
+  Value input = matmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(),
+      dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{1, 0});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
+        matmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>(
+        loc, matmulOp.getResultTypes(),
+        ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
+        matmulOp.getOutputs());
+  }
+  rewriter.replaceOp(matmulOp, newMatmulOp);
+  return newMatmulOp;
+}
 
 /// Pattern to replace
 ///
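
For reference, a sketch of what the rewrite above produces for a static-shaped `linalg.matmul` (shapes and value names are illustrative, not taken from this commit):

  // Before:
  %0 = linalg.matmul ins(%a, %b : tensor<16x8xf32>, tensor<8x16xf32>)
                     outs(%c : tensor<16x16xf32>) -> tensor<16x16xf32>

  // After (default, transposeLHS = true):
  %empty = tensor.empty() : tensor<8x16xf32>
  %at = linalg.transpose ins(%a : tensor<16x8xf32>)
            outs(%empty : tensor<8x16xf32>) permutation = [1, 0]
  %1 = linalg.matmul_transpose_a
           ins(%at, %b : tensor<8x16xf32>, tensor<8x16xf32>)
           outs(%c : tensor<16x16xf32>) -> tensor<16x16xf32>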
@@ -84,47 +77,75 @@ private:
 ///
 /// Only the non-batch dimensions are transposed. By default the LHS is
 /// transposed. Set `transposeLHS=false` to transpose RHS instead.
+FailureOr<Operation *>
+mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
+                                   linalg::BatchMatmulOp batchMatmulOp,
+                                   bool transposeLHS) {
+  if (!bufferization::hasTensorSemantics(batchMatmulOp))
+    return rewriter.notifyMatchFailure(
+        batchMatmulOp, "only matmul ops with tensors are supported");
+
+  Location loc = batchMatmulOp.getLoc();
+  Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
+  auto type = cast<ShapedType>(input.getType());
+
+  SmallVector<Value> dynamicDims;
+  if (type.isDynamicDim(0))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+  if (type.isDynamicDim(2))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
+  if (type.isDynamicDim(1))
+    dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
+
+  ArrayRef<int64_t> shape = type.getShape();
+  Value empty = rewriter.create<tensor::EmptyOp>(
+      loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
+      type.getElementType(), dynamicDims);
+  auto transposeOp = rewriter.create<linalg::TransposeOp>(
+      loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
+  Operation *newMatmulOp;
+  if (transposeLHS) {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
+        batchMatmulOp.getOutputs());
+  } else {
+    newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>(
+        loc, batchMatmulOp.getResultTypes(),
+        ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
+        batchMatmulOp.getOutputs());
+  }
+  rewriter.replaceOp(batchMatmulOp, newMatmulOp);
+  return newMatmulOp;
+}
+
+namespace {
+struct TransposeMatmul final : public OpRewritePattern<linalg::MatmulOp> {
+  TransposeMatmul(MLIRContext *ctx, bool transposeLHS)
+      : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
+
+  LogicalResult matchAndRewrite(linalg::MatmulOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(transposeMatmul(rewriter, op, transposeLHS))) {
+      return failure();
+    }
+    return success();
+  }
+
+private:
+  bool transposeLHS;
+};
+
 struct TransposeBatchMatmul final
     : public OpRewritePattern<linalg::BatchMatmulOp> {
   TransposeBatchMatmul(MLIRContext *ctx, bool transposeLHS)
       : OpRewritePattern(ctx), transposeLHS(transposeLHS) {}
 
-  LogicalResult matchAndRewrite(linalg::BatchMatmulOp batchMatmulOp,
+  LogicalResult matchAndRewrite(linalg::BatchMatmulOp op,
                                 PatternRewriter &rewriter) const override {
-    if (!bufferization::hasTensorSemantics(batchMatmulOp))
-      return rewriter.notifyMatchFailure(
-          batchMatmulOp, "only matmul ops with tensors are supported");
-
-    Location loc = batchMatmulOp.getLoc();
-    Value input = batchMatmulOp.getInputs()[transposeLHS ? 0 : 1];
-    auto type = cast<ShapedType>(input.getType());
-
-    SmallVector<Value> dynamicDims;
-    if (type.isDynamicDim(0))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
-    if (type.isDynamicDim(2))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2));
-    if (type.isDynamicDim(1))
-      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1));
-
-    ArrayRef<int64_t> shape = type.getShape();
-    Value empty = rewriter.create<tensor::EmptyOp>(
-        loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]},
-        type.getElementType(), dynamicDims);
-    auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, input, empty, ArrayRef<int64_t>{0, 2, 1});
-    if (transposeLHS) {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeAOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
-          batchMatmulOp.getOutputs());
-    } else {
-      rewriter.replaceOpWithNewOp<linalg::BatchMatmulTransposeBOp>(
-          batchMatmulOp, batchMatmulOp.getResultTypes(),
-          ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
-          batchMatmulOp.getOutputs());
-    }
-
+    if (failed(transposeBatchMatmul(rewriter, op, transposeLHS))) {
+      return failure();
+    }
     return success();
   }
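
The batch variant only permutes the two trailing (non-batch) dimensions, as the `[0, 2, 1]` permutation above shows. A sketch with illustrative static shapes (names and shapes are assumptions, not from this commit):

  // Before:
  %0 = linalg.batch_matmul ins(%a, %b : tensor<2x16x8xf32>, tensor<2x8x16xf32>)
                           outs(%c : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>

  // After (default, transposeLHS = true):
  %empty = tensor.empty() : tensor<2x8x16xf32>
  %at = linalg.transpose ins(%a : tensor<2x16x8xf32>)
            outs(%empty : tensor<2x8x16xf32>) permutation = [0, 2, 1]
  %1 = linalg.batch_matmul_transpose_a
           ins(%at, %b : tensor<2x8x16xf32>, tensor<2x8x16xf32>)
           outs(%c : tensor<2x16x16xf32>) -> tensor<2x16x16xf32>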
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization
@@ -2,10 +2,9 @@
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul", "linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.transpose_matmul %matmul <rhs> : (!transform.any_op) -> (!transform.any_op)
     %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_patterns to %0 {
-      transform.apply_patterns.linalg.transpose_matmul <rhs>
-    } : !transform.any_op
     transform.apply_cse to %0 : !transform.any_op
     transform.apply_patterns to %0 {
       transform.apply_patterns.canonicalization