[mlir][tensor] Fold unpadding collapse_shape into extract_slice (#93554)
This commit is contained in:
parent
189efb0fbb
commit
8f4d5a32ac
@ -48,6 +48,39 @@ struct FoldExpandOfRankReducingExtract
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Fold collapse_shape which only removes static dimensions of size `1`
|
||||||
|
/// into extract_slice.
///
/// The pattern matches a tensor.collapse_shape whose source is produced by a
/// tensor.extract_slice that has exactly one use, and replaces the
/// collapse_shape with a new extract_slice that yields the collapsed result
/// type directly. For example, an extract_slice to `tensor<1x?x1x?xf32>`
/// followed by a collapse_shape `[[0, 1], [2, 3]]` into `tensor<?x?xf32>` is
/// rewritten into one extract_slice to `tensor<?x?xf32>` with the same
/// offsets, sizes and strides.
///
/// The match fails when the source is not an extract_slice, when the
/// extract_slice has more than one use, or when `isRankReducedType` does not
/// report `Success` for the collapse's source and result types.
|
||||||
|
struct FoldUnPaddingCollapseIntoExtract
|
||||||
|
: public OpRewritePattern<tensor::CollapseShapeOp> {
|
||||||
|
using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
// Null when the collapse source is not defined by a tensor.extract_slice.
auto extractSliceOp =
|
||||||
|
collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
|
||||||
|
// Collapse cannot be folded away with multiple users of the extract slice
|
||||||
|
// and it is not necessarily beneficial to only convert the collapse into
|
||||||
|
// another extract slice.
|
||||||
|
if (!extractSliceOp || !extractSliceOp->hasOneUse())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Only fold away simple collapse where all removed dimensions have static
|
||||||
|
// size `1`.
|
||||||
|
SliceVerificationResult res = isRankReducedType(
|
||||||
|
collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
|
||||||
|
if (res != SliceVerificationResult::Success)
|
||||||
|
return rewriter.notifyMatchFailure(collapseShapeOp,
|
||||||
|
"expected unpadding collapse");
|
||||||
|
|
||||||
|
// Build a replacement extract_slice at the original slice's location. It
// reuses the original source, offsets, sizes and strides; only the result
// type changes to the collapse's result type.
Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
|
||||||
|
extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
|
||||||
|
extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
|
||||||
|
extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
|
||||||
|
// Only the collapse_shape is replaced here; the original extract_slice is
// not erased by this pattern and is left without its single use.
// NOTE(review): presumably it is removed later as dead code - confirm.
rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
|
/// Fold insert_slice(collapse_shape) ops that cancel itself out.
|
||||||
template <typename OpTy>
|
template <typename OpTy>
|
||||||
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
|
struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
|
||||||
@ -111,10 +144,11 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
|
|||||||
|
|
||||||
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
|
void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
|
||||||
RewritePatternSet &patterns) {
|
RewritePatternSet &patterns) {
|
||||||
patterns.add<FoldExpandOfRankReducingExtract,
|
patterns
|
||||||
FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
|
.add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
|
||||||
FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
|
FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
|
||||||
FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
|
FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
|
||||||
FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
|
FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
|
||||||
patterns.getContext());
|
FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
|
||||||
|
patterns.getContext());
|
||||||
}
|
}
|
||||||
|
@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
|
// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
|
||||||
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
|
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
|
||||||
// CHECK-DAG: %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
|
// CHECK-DAG: %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
|
||||||
// CHECK-DAG: %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
|
// CHECK-SAME: [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
|
||||||
|
// CHECK-DAG: %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
|
||||||
|
// CHECK-SAME: [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
|
||||||
// CHECK: return %[[extract1]], %[[extract2]]
|
// CHECK: return %[[extract1]], %[[extract2]]
|
||||||
func.func @expand_shape_of_rank_reducing_extract(
|
func.func @expand_shape_of_rank_reducing_extract(
|
||||||
%t: tensor<?x?x?x?xf32>, %idx: index)
|
%t: tensor<?x?x?x?xf32>, %idx: index)
|
||||||
@ -22,9 +24,82 @@ func.func @expand_shape_of_rank_reducing_extract(
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// Positive case: the extract_slice below yields tensor<1x?x1x?xf32> and the
// collapse_shape [[0, 1], [2, 3]] only drops the two static unit dimensions.
// The expectations require a single extract_slice that produces
// tensor<?x?xf32> directly and is returned as-is.
// CHECK-LABEL: func @unpadding_collapse_of_extract_slice(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
|
||||||
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
|
||||||
|
// CHECK-SAME: [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
|
||||||
|
// CHECK: return %[[extract]]
|
||||||
|
func.func @unpadding_collapse_of_extract_slice(
|
||||||
|
%t: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
||||||
|
-> tensor<?x?xf32> {
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c3 = arith.constant 3 : index
|
||||||
|
%sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
|
||||||
|
%sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
|
||||||
|
%0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
|
||||||
|
: tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
|
||||||
|
%1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
|
||||||
|
: tensor<1x?x1x?xf32> into tensor<?x?xf32>
|
||||||
|
return %1 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Negative case: the collapse_shape [[0], [1, 2]] merges two dynamic
// dimensions of tensor<?x?x?xf32>, so it is not a pure removal of static
// unit dimensions. The expectations require both the extract_slice and the
// collapse_shape to remain in the output.
// CHECK-LABEL: func @non_unpadding_collapse_of_extract_slice(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
|
||||||
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
|
||||||
|
// CHECK-SAME: [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||||
|
// CHECK: return %[[collapse]]
|
||||||
|
func.func @non_unpadding_collapse_of_extract_slice(
|
||||||
|
%t: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
|
||||||
|
-> tensor<?x?xf32> {
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%sz0 = tensor.dim %t, %c0 : tensor<?x?x?x?xf32>
|
||||||
|
%sz1 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
|
||||||
|
%0 = tensor.extract_slice %t[%x, %y, 0, 0] [%sz0, %sz1, %sz, 1] [1, 1, 1, 1]
|
||||||
|
: tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
|
||||||
|
%1 = tensor.collapse_shape %0 [[0], [1, 2]]
|
||||||
|
: tensor<?x?x?xf32> into tensor<?x?xf32>
|
||||||
|
return %1 : tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Negative case: the extract_slice result is both returned and fed to the
// collapse_shape, so it has two uses. Although the collapse only drops unit
// dimensions, the expectations require the extract_slice (still producing
// tensor<1x?x1x?xf32>) and the collapse_shape to remain in the output.
// CHECK-LABEL: func @unpadding_collapse_of_extract_slice_with_multiple_users(
|
||||||
|
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x?xf32>
|
||||||
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
|
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
|
||||||
|
// CHECK-SAME: [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
|
||||||
|
// CHECK: %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
|
||||||
|
// CHECK: return %[[extract]], %[[collapse]]
|
||||||
|
func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
|
||||||
|
%t: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
||||||
|
-> (tensor<1x?x1x?xf32>, tensor<?x?xf32>) {
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
%c3 = arith.constant 3 : index
|
||||||
|
%sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
|
||||||
|
%sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
|
||||||
|
%0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
|
||||||
|
: tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
|
||||||
|
%1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
|
||||||
|
: tensor<1x?x1x?xf32> into tensor<?x?xf32>
|
||||||
|
return %0, %1 : tensor<1x?x1x?xf32>, tensor<?x?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
|
// CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
|
||||||
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
|
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
|
||||||
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
|
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
|
||||||
|
// CHECK-SAME: [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
|
||||||
// CHECK: return %[[insert]]
|
// CHECK: return %[[insert]]
|
||||||
func.func @rank_reducing_insert_of_collapse_shape(
|
func.func @rank_reducing_insert_of_collapse_shape(
|
||||||
%t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
|
%t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
|
||||||
@ -40,7 +115,8 @@ func.func @rank_reducing_insert_of_collapse_shape(
|
|||||||
|
|
||||||
// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
|
// CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
|
||||||
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
|
// CHECK-SAME: %[[t:.*]]: tensor<?x1x1x5xf32>
|
||||||
// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
|
// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
|
||||||
|
// CHECK-SAME: [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
|
||||||
func.func @rank_reducing_parallel_insert_of_collapse_shape(
|
func.func @rank_reducing_parallel_insert_of_collapse_shape(
|
||||||
%t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
|
%t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
|
||||||
-> tensor<?x?x?x?xf32> {
|
-> tensor<?x?x?x?xf32> {
|
||||||
@ -62,7 +138,8 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
|
|||||||
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
|
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
|
||||||
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
|
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0]
|
||||||
|
// CHECK-SAME: [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
|
||||||
// CHECK: return %[[insert]]
|
// CHECK: return %[[insert]]
|
||||||
func.func @insert_of_padding_expand_shape(
|
func.func @insert_of_padding_expand_shape(
|
||||||
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
||||||
@ -86,8 +163,10 @@ func.func @insert_of_padding_expand_shape(
|
|||||||
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
|
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
|
||||||
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
|
// CHECK-SAME: output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
|
||||||
|
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0]
|
||||||
|
// CHECK-SAME: [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
|
||||||
// CHECK: return %[[insert]]
|
// CHECK: return %[[insert]]
|
||||||
func.func @insert_of_non_padding_expand_shape(
|
func.func @insert_of_non_padding_expand_shape(
|
||||||
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
|
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
|
||||||
@ -110,7 +189,8 @@ func.func @insert_of_non_padding_expand_shape(
|
|||||||
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
|
// CHECK-SAME: %[[d:.*]]: tensor<?x?x?x?xf32>
|
||||||
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
|
// CHECK: tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
|
||||||
|
// CHECK-SAME: [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
|
||||||
func.func @parallel_insert_of_padding_expand_shape(
|
func.func @parallel_insert_of_padding_expand_shape(
|
||||||
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
|
||||||
-> tensor<?x?x?x?xf32> {
|
-> tensor<?x?x?x?xf32> {
|
||||||
@ -137,8 +217,10 @@ func.func @parallel_insert_of_padding_expand_shape(
|
|||||||
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[x:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[y:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
|
// CHECK-SAME: %[[sz:[a-zA-Z0-9_]+]]: index
|
||||||
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
|
// CHECK: %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
|
||||||
// CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
|
// CHECK-SAME: output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
|
||||||
|
// CHECK: tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
|
||||||
|
// CHECK-SAME: [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
|
||||||
func.func @parallel_insert_of_non_padding_expand_shape(
|
func.func @parallel_insert_of_non_padding_expand_shape(
|
||||||
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
|
%t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
|
||||||
-> tensor<?x?x?x?xf32> {
|
-> tensor<?x?x?x?xf32> {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user