From 9cc11b98a76c9b2f39b84f709566aac6f962f07a Mon Sep 17 00:00:00 2001 From: donald chen Date: Tue, 23 Jul 2024 12:52:25 +0800 Subject: [PATCH] [mlir] [linalg] Add pattern to swap transpose with broadcast (#97063) Add a pattern that implements: transpose(broadcast(input)) -> broadcast(transpose(input)) --- .../mlir/Dialect/Utils/IndexingUtils.h | 8 ++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 61 ++++++++++++++- mlir/lib/Dialect/Utils/IndexingUtils.cpp | 26 +++++++ mlir/test/Dialect/Linalg/canonicalize.mlir | 75 ++++++++++++++++++- 4 files changed, 168 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h index b774359552aa..7849782e5442 100644 --- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h +++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h @@ -243,6 +243,14 @@ SmallVector computePermutationVector(int64_t permSize, ArrayRef positions, ArrayRef desiredPositions); +/// Returns a permutation vector that drops the input dims in +/// dropPositions from inputPerm. +/// +/// For example, inputPerm = {2, 4, 0, 1, 3} and dropPositions = {1, 2} would +/// result in a {2, 0, 1} permutation vector. +SmallVector dropDims(ArrayRef inputPerm, + ArrayRef dropPositions); + /// Helper to return a subset of `arrayAttr` as a vector of int64_t. // TODO: Port everything relevant to DenseArrayAttr and drop this util.
SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index cefaad9b2265..d1db90bbe2d2 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1895,9 +1895,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern { } }; +/// This pattern canonicalizes transpose by swapping the order of +/// broadcast and transpose: +/// transpose(broadcast(input)) -> broadcast(transpose(input)) +struct SwapTransposeWithBroadcast : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + Value input = transposeOp.getInput(); + BroadcastOp broadcastOp = input.getDefiningOp(); + if (!input.hasOneUse() || !broadcastOp) + return failure(); + + ArrayRef dimensions = broadcastOp.getDimensions(); + ArrayRef perms = transposeOp.getPermutation(); + + // Get new perms and new dimensions. + SmallVector resultPerms = dropDims(perms, dimensions); + SmallVector invertPerm = invertPermutationVector(perms); + SmallVector resultDimensions; + unsigned dimensionSize = dimensions.size(); + for (unsigned i = 0; i < dimensionSize; ++i) + resultDimensions.push_back(invertPerm[dimensions[i]]); + + // Create transpose result.
+ Value broadcastInput = broadcastOp.getInput(); + Location loc = transposeOp.getLoc(); + MLIRContext *ctx = transposeOp.getContext(); + SmallVector dims; + auto broadcastInputTy = + mlir::cast(broadcastInput.getType()); + unsigned inputRank = broadcastInputTy.getRank(); + for (unsigned i = 0; i < inputRank; ++i) { + if (broadcastInputTy.isDynamicDim(i)) { + dims.push_back(rewriter.create(loc, broadcastInput, i) + ->getResult(0)); + } else { + dims.push_back(IntegerAttr::get(IndexType::get(ctx), + broadcastInputTy.getDimSize(i))); + } + } + SmallVector transposeResultShapes = + applyPermutation(dims, resultPerms); + Value transposeInit = rewriter.create( + transposeOp.getLoc(), transposeResultShapes, + broadcastInputTy.getElementType()); + + // Create broadcast(transpose(input)). + Value transposeResult = + rewriter + .create(loc, broadcastOp.getInput(), transposeInit, + resultPerms) + ->getResult(0); + rewriter.replaceOpWithNewOp( + transposeOp, transposeResult, transposeOp.getInit(), resultDimensions); + return success(); + } +}; + void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp index aba225be720c..108839a4d90e 100644 --- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp +++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp @@ -252,6 +252,32 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef positions, return res; } +SmallVector mlir::dropDims(ArrayRef inputPerm, + ArrayRef dropPositions) { + assert(inputPerm.size() >= dropPositions.size() && + "expect inputPerm size large than position to drop"); + SmallVector res; + unsigned permSize = inputPerm.size(); + for (unsigned inputIndex = 0; inputIndex < permSize; ++inputIndex) { + int64_t targetIndex = inputPerm[inputIndex]; + bool shouldDrop = 
false; + unsigned dropSize = dropPositions.size(); + for (unsigned dropIndex = 0; dropIndex < dropSize; dropIndex++) { + if (dropPositions[dropIndex] == inputPerm[inputIndex]) { + shouldDrop = true; + break; + } + if (dropPositions[dropIndex] < inputPerm[inputIndex]) { + targetIndex--; + } + } + if (!shouldDrop) { + res.push_back(targetIndex); + } + } + return res; +} + SmallVector mlir::getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront, unsigned dropBack) { diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 928030a81dc0..d34bc8c1c54f 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>) return %0 : tensor<2x3xf32> } -// ---- +// ----- func.func @transpose_1d(%input: tensor<16xf32>, %init: tensor<16xf32>) -> tensor<16xf32> { @@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>, func.return %transpose2 : tensor<3x4x5xf32> } +// ----- + +func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>, + %init1: tensor<1x2x3x4x5x6xf32>, + %init2: tensor<1x6x2x3x5x4xf32>) -> tensor<1x6x2x3x5x4xf32> { + // CHECK-LABEL: @broadcast_transpose_fold + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32> + // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32> + // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32> + // CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32> + // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1] + // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1] + // CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32> + %broadcast = linalg.broadcast + ins(%input : tensor<2x4x5xf32>) + 
outs(%init1 : tensor<1x2x3x4x5x6xf32>) + dimensions = [0, 2, 5] + %transpose = linalg.transpose + ins(%broadcast : tensor<1x2x3x4x5x6xf32>) + outs(%init2 : tensor<1x6x2x3x5x4xf32>) + permutation = [0, 5, 1, 2, 4, 3] + func.return %transpose : tensor<1x6x2x3x5x4xf32> +} + +// ----- + +func.func @broadcast_transpose_fold_dynamic(%input: tensor, + %init1: tensor<1x?x3x?x5x6xf32>, + %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> { + // CHECK-LABEL: @broadcast_transpose_fold_dynamic + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor + // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32> + // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32> + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + // CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor + // CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor + // CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor + // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor) outs(%[[TMP_INIT]] : tensor) permutation = [1, 2, 0] + // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3] + // CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32> + %broadcast = linalg.broadcast + ins(%input : tensor) + outs(%init1 : tensor<1x?x3x?x5x6xf32>) + dimensions = [0, 2, 5] + %transpose = linalg.transpose + ins(%broadcast : tensor<1x?x3x?x5x6xf32>) + outs(%init2 : tensor<1x3x?x6x5x?xf32>) + permutation = [0, 2, 3, 5, 4, 1] + func.return %transpose : tensor<1x3x?x6x5x?xf32> +} + +// ----- + +func.func @broadcast_transpose_fold_2dim(%input: tensor<2xf32>, + %init1: tensor<2x4xf32>, + %init2: tensor<4x2xf32>) -> tensor<4x2xf32> { + // CHECK-LABEL: @broadcast_transpose_fold_2dim + // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32> + // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32> + // CHECK-SAME: 
%[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32> + // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0] + // CHECK: return %[[BROADCAST]] : tensor<4x2xf32> + %broadcast = linalg.broadcast + ins(%input : tensor<2xf32>) + outs(%init1 : tensor<2x4xf32>) + dimensions = [1] + %transpose = linalg.transpose + ins(%broadcast : tensor<2x4xf32>) + outs(%init2 : tensor<4x2xf32>) + permutation = [1, 0] + func.return %transpose : tensor<4x2xf32> +}