[MLIR][Tensor] Add Destination style RewritePattern for DimOp. (#65780)
This enables canonicalization to fold away unnecessary tensor.dim ops which in turn enables folding away of other operations, as can be seen in conv_tensors_dynamic where affine.min operations were folded away.
This commit is contained in:
parent
2744d9b649
commit
0aa459fc8a
@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Fold dim of a destination passing style op into the dim of the corresponding
|
||||
/// init.
|
||||
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
|
||||
using OpRewritePattern<DimOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(DimOp dimOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto source = dimOp.getSource();
|
||||
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
|
||||
if (!destOp)
|
||||
return failure();
|
||||
|
||||
auto resultIndex = source.cast<OpResult>().getResultNumber();
|
||||
auto initOperand = destOp.getDpsInitOperand(resultIndex);
|
||||
|
||||
rewriter.updateRootInPlace(
|
||||
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<DimOfCastOp>(context);
|
||||
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> {
|
||||
|
||||
// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
|
||||
// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
|
||||
// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
|
||||
// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
|
||||
// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
|
||||
// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
|
||||
// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
|
||||
@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
|
||||
ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
|
||||
return %arg0 : tensor<16x64x256xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
|
||||
// CHECK: tensor.dim
|
||||
// CHECK: tensor.dim
|
||||
// CHECK-NOT: tensor.dim
|
||||
// CHECK: return
|
||||
func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
|
||||
%1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
%dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
|
||||
%dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
|
||||
%2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
|
||||
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
return %3: tensor<?x?xf32>
|
||||
}
|
||||
|
@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
|
||||
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
|
||||
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
|
||||
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)>
|
||||
// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
|
||||
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
|
||||
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
|
||||
// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)>
|
||||
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)>
|
||||
|
||||
// CHECK: func @conv_tensors_dynamic
|
||||
@ -225,8 +223,6 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
|
||||
// CHECK-DAG: %[[FILTER_OC:.+]] = tensor.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
|
||||
// CHECK-DAG: %[[FILL_W:.+]] = tensor.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
|
||||
|
||||
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
|
||||
@ -234,14 +230,12 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
|
||||
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
|
||||
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
|
||||
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[FILL_H]], %[[FILTER_H]]]
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
|
||||
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]], %[[FILTER_H]]]
|
||||
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
|
||||
// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
|
||||
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
|
||||
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
|
||||
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]], %[[FILTER_W]]]
|
||||
// CHECK-NEXT: %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
|
||||
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
|
||||
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
|
||||
@ -253,7 +247,7 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
|
||||
// CHECK-NEXT: %[[ST_FILTER:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, %[[IV3]]]
|
||||
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
|
||||
// CHECK-NEXT: %[[ST_FILL:.+]] = tensor.extract_slice %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
|
||||
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_2]]]
|
||||
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC_2]]]
|
||||
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_nhwc_hwcf
|
||||
// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
|
||||
// CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
|
||||
|
@ -43,9 +43,7 @@ transform.sequence failures(propagate) {
|
||||
// CHECK: arith.addf
|
||||
// CHECK: linalg.yield
|
||||
// CHECK: } -> tensor<?x?xf32>
|
||||
// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
|
||||
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
|
||||
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
|
||||
// CHECK: }
|
||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
|
||||
@ -76,14 +74,16 @@ transform.sequence failures(propagate) {
|
||||
by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
|
||||
}
|
||||
|
||||
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
|
||||
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
|
||||
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
|
||||
// CHECK: func @reduction_tile_transpose
|
||||
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
|
||||
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
|
||||
// CHECK: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
|
||||
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
|
||||
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
|
||||
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
|
||||
// CHECK: }
|
||||
// CHECK: linalg.generic
|
||||
|
Loading…
x
Reference in New Issue
Block a user