[mlir][memref]: Allow collapse dummy strided unit dim (#103719)
Dimensions of size 1 should be skipped, because their strides are meaningless and could have any arbitrary value.
This commit is contained in:
parent
65281570af
commit
76c0798425
@ -2448,6 +2448,11 @@ computeCollapsedLayoutMap(MemRefType srcType,
|
||||
if (strict && (stride.saturated || srcStride.saturated))
|
||||
return failure();
|
||||
|
||||
// Dimensions of size 1 should be skipped, because their strides are
|
||||
// meaningless and could have any arbitrary value.
|
||||
if (srcShape[idx - 1] == 1)
|
||||
continue;
|
||||
|
||||
if (!stride.saturated && !srcStride.saturated && stride != srcStride)
|
||||
return failure();
|
||||
}
|
||||
|
@ -99,7 +99,9 @@ func.func @expand_collapse_shape_static(
|
||||
%arg4: memref<1x5xf32, strided<[5, 1], offset: ?>>,
|
||||
%arg5: memref<f32>,
|
||||
%arg6: memref<3x4x5xf32, strided<[240, 60, 10], offset: 0>>,
|
||||
%arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>) {
|
||||
%arg7: memref<1x2049xi64, strided<[?, ?], offset: ?>>,
|
||||
%arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
|
||||
%arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
|
||||
// Reshapes that collapse and expand back a contiguous buffer.
|
||||
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1], [2]]
|
||||
// CHECK-SAME: memref<3x4x5xf32> into memref<12x5xf32>
|
||||
@ -163,6 +165,19 @@ func.func @expand_collapse_shape_static(
|
||||
memref<1x2049xi64, strided<[?, ?], offset: ?>> into
|
||||
memref<2049xi64, strided<[?], offset: ?>>
|
||||
|
||||
// %arg8: memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>>,
|
||||
// %arg9: memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>>) {
|
||||
|
||||
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0, 1, 2]]
|
||||
%r8 = memref.collapse_shape %arg8 [[0, 1, 2]] :
|
||||
memref<1x1x1024xi8, strided<[40960, 4096, 1], offset: 0>> into
|
||||
memref<1024xi8, strided<[1], offset: 0>>
|
||||
|
||||
// CHECK: memref.collapse_shape {{.*}} {{\[}}[0], [1, 2, 3]]
|
||||
%r9 = memref.collapse_shape %arg9 [[0], [1, 2, 3]] :
|
||||
memref<24x1x1x1024xi8, strided<[40960, 40960, 4096, 1], offset: 0>> into
|
||||
memref<24x1024xi8, strided<[40960, 1], offset: 0>>
|
||||
|
||||
// Reshapes that expand and collapse back a contiguous buffer with some 1's.
|
||||
// CHECK: memref.expand_shape {{.*}} {{\[}}[0, 1], [2], [3, 4]] output_shape [1, 3, 4, 1, 5]
|
||||
// CHECK-SAME: memref<3x4x5xf32> into memref<1x3x4x1x5xf32>
|
||||
|
Loading…
x
Reference in New Issue
Block a user