[mlir] [linalg] Add pattern to swap transpose with broadcast (#97063)
Add a pattern that implements: transpose(broadcast(input)) -> broadcast(transpose(input))
parent d7e8a7487c
commit 9cc11b98a7
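
For illustration, a minimal sketch of the rewrite this pattern performs, with hypothetical shapes and init names (the init tensors would come from tensor.empty; the commit's own tests below exercise larger cases). Before the rewrite, the transpose runs over the already-broadcast tensor:

  %bcast = linalg.broadcast
      ins(%input : tensor<2x3xf32>)
      outs(%bcast_init : tensor<2x4x3xf32>)
      dimensions = [1]
  %result = linalg.transpose
      ins(%bcast : tensor<2x4x3xf32>)
      outs(%init : tensor<3x4x2xf32>)
      permutation = [2, 1, 0]

After the rewrite, the transpose acts on the smaller pre-broadcast tensor and the broadcast dimensions are remapped:

  %transposed = linalg.transpose
      ins(%input : tensor<2x3xf32>)
      outs(%tr_init : tensor<3x2xf32>)
      permutation = [1, 0]
  %result = linalg.broadcast
      ins(%transposed : tensor<3x2xf32>)
      outs(%init : tensor<3x4x2xf32>)
      dimensions = [1]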
@@ -243,6 +243,14 @@ SmallVector<int64_t>
 computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                          ArrayRef<int64_t> desiredPositions);
 
+/// Returns a permutation vector that drops the input dims in
+/// dropPositions from inputPerm.
+///
+/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would
+/// result in a {2, 0, 1} permutation vector.
+SmallVector<int64_t> dropDims(ArrayRef<int64_t> inputPerm,
+                              ArrayRef<int64_t> dropPositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
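To make the example in the new doc comment concrete: dropDims({2, 4, 0, 1, 3}, {1, 2}) drops the entries whose value is in dropPositions (the 2 and the 1), then renumbers each survivor down by the number of dropped values below it: 4 becomes 2, 0 stays 0, and 3 becomes 1, giving {2, 0, 1}.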
@@ -1895,9 +1895,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern canonicalizes transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    SmallVector<int64_t> resultDimensions;
+    unsigned dimensionSize = dimensions.size();
+    for (unsigned i = 0; i < dimensionSize; ++i)
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+
+    // Create transpose result.
+    Value broadcastInput = broadcastOp.getInput();
+    Location loc = transposeOp.getLoc();
+    MLIRContext *ctx = transposeOp.getContext();
+    SmallVector<OpFoldResult> dims;
+    auto broadcastInputTy =
+        mlir::cast<RankedTensorType>(broadcastInput.getType());
+    unsigned inputRank = broadcastInputTy.getRank();
+    for (unsigned i = 0; i < inputRank; ++i) {
+      if (broadcastInputTy.isDynamicDim(i)) {
+        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+                           ->getResult(0));
+      } else {
+        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
+                                        broadcastInputTy.getDimSize(i)));
+      }
+    }
+    SmallVector<OpFoldResult> transposeResultShapes =
+        applyPermutation(dims, resultPerms);
+    Value transposeInit = rewriter.create<tensor::EmptyOp>(
+        transposeOp.getLoc(), transposeResultShapes,
+        broadcastInputTy.getElementType());
+
+    // Create broadcast(transpose(input)).
+    Value transposeResult =
+        rewriter
+            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+                                 resultPerms)
+            ->getResult(0);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldTransposeWithTranspose>(context);
+  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
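A worked trace of the two index computations above, using the first new test case from the last hunk below: perms = [0, 5, 1, 2, 4, 3] and dimensions = [0, 2, 5]. Inverting the permutation gives invertPerm = [0, 2, 3, 5, 4, 1], so resultDimensions = [invertPerm[0], invertPerm[2], invertPerm[5]] = [0, 3, 1]. dropDims(perms, dimensions) removes the broadcast positions 0, 2, and 5 from perms and renumbers the surviving entries 1, 4, 3 down to resultPerms = [0, 2, 1]. These are exactly the dimensions = [0, 3, 1] and permutation = [0, 2, 1] that the @broadcast_transpose_fold test expects.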
@@ -252,6 +252,32 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
   return res;
 }
 
+SmallVector<int64_t> mlir::dropDims(ArrayRef<int64_t> inputPerm,
+                                    ArrayRef<int64_t> dropPositions) {
+  assert(inputPerm.size() >= dropPositions.size() &&
+         "expect inputPerm size to be at least the number of positions to "
+         "drop");
+  SmallVector<int64_t> res;
+  unsigned permSize = inputPerm.size();
+  for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) {
+    int64_t targetIndex = inputPerm[inputIndex];
+    bool shouldDrop = false;
+    unsigned dropSize = dropPositions.size();
+    for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) {
+      if (dropPositions[dropIndex] == inputPerm[inputIndex]) {
+        shouldDrop = true;
+        break;
+      }
+      if (dropPositions[dropIndex] < inputPerm[inputIndex]) {
+        targetIndex--;
+      }
+    }
+    if (!shouldDrop) {
+      res.push_back(targetIndex);
+    }
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// ----
+// -----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+                                    %init1: tensor<1x2x3x4x5x6xf32>,
+                                    %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
+  // CHECK:       %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+  // CHECK:       %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
+  // CHECK:       return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x6x2x3x5x4xf32>)
+      permutation = [0, 5, 1, 2, 4, 3]
+  func.return %transpose : tensor<1x6x2x3x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+                                            %init1: tensor<1x?x3x?x5x6xf32>,
+                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  // CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  // CHECK:       %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+  // CHECK:       %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+  // CHECK:       %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+  // CHECK:       %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  // CHECK:       return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<?x?x5xf32>)
+      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      permutation = [0, 2, 3, 5, 4, 1]
+  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>,
+                                         %init1: tensor<2x4xf32>,
+                                         %init2: tensor<4x2xf32>) -> tensor<4x2xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_2dim
+  // CHECK-SAME:  %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
+  // CHECK-SAME:  %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
+  // CHECK-SAME:  %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
+  // CHECK:       %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
+  // CHECK:       return %[[BROADCAST]] : tensor<4x2xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2xf32>)
+      outs(%init1 : tensor<2x4xf32>)
+      dimensions = [1]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<2x4xf32>)
+      outs(%init2 : tensor<4x2xf32>)
+      permutation = [1, 0]
+  func.return %transpose : tensor<4x2xf32>
+}
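Note the @broadcast_transpose_fold_2dim case: with perms = [1, 0] and dimensions = [1], resultPerms = dropDims([1, 0], [1]) = [0] is the identity, so the intermediate transpose disappears entirely (identity transposes are folded away) and only the re-targeted broadcast with dimensions = [0] survives, as the CHECK lines show.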