[mlir][tensor] Fold pack-unpack with unbalanced outer_dims_perm attr (#92234)
Extends pack/unpack perm attribute checker to account for cases when the optional outer_dims_perm attribute might be missing in one operation and the other one has explicit identity permutation. This enables canonicalizer to fold more unpack(pack(x)) variants.
This commit is contained in:
parent
8f711aa324
commit
dcd32bd65f
@ -4112,7 +4112,13 @@ Speculation::Speculatability PackOp::getSpeculatability() {
|
||||
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
|
||||
if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
|
||||
return false;
|
||||
return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
|
||||
if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
|
||||
return true;
|
||||
// Outer dims permutation is optional.
|
||||
// To compare unbalanced pack-unpack pair, treat no permutation as equal to
|
||||
// identity permutation.
|
||||
return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
|
||||
isIdentityPermutation(unPackOp.getOuterDimsPerm());
|
||||
}
|
||||
|
||||
// Return true if pack and unpack have the same tiles.
|
||||
|
@ -2252,6 +2252,32 @@ func.func @pack_unpack_dynamic_with_padding(%t: tensor<?x?x?x?xf32>, %dim1: inde
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @pack_outer_dims_unpack_no_outer_dims(
|
||||
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
|
||||
// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
|
||||
func.func @pack_outer_dims_unpack_no_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
|
||||
%tensor_empty = tensor.empty() : tensor<128x128xf32>
|
||||
%unpacked = tensor.unpack %t inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
|
||||
%tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
|
||||
%packed = tensor.pack %unpacked outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
|
||||
return %packed : tensor<16x16x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @pack_no_outer_dims_unpack_outer_dims(
|
||||
// CHECK-SAME: %[[T:.+]]: tensor<16x16x?x?xf32>,
|
||||
// CHECK: return %[[T]] : tensor<16x16x?x?xf32>
|
||||
func.func @pack_no_outer_dims_unpack_outer_dims(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index) -> tensor<16x16x?x?xf32> {
|
||||
%tensor_empty = tensor.empty() : tensor<128x128xf32>
|
||||
%unpacked = tensor.unpack %t outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty : tensor<16x16x?x?xf32> -> tensor<128x128xf32>
|
||||
%tensor_empty1 = tensor.empty(%tile1, %tile2) : tensor<16x16x?x?xf32>
|
||||
%packed = tensor.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [%tile1, %tile2] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x?x?xf32>
|
||||
return %packed : tensor<16x16x?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @invalid_empty_negative_size
|
||||
// CHECK: %[[IDX:.*]] = index.constant
|
||||
// CHECK: %[[T:.*]] = tensor.empty(%[[IDX]]) : tensor<4x5x?xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user