[mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (#189533)
When encountering an invalid offset or size, we only skip the current invalid value and continue attempting to fold other valid offsets or sizes.
This commit is contained in:
parent
461a1c51bf
commit
158f10fe24
@ -2331,6 +2331,24 @@ public:
|
||||
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
|
||||
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
|
||||
|
||||
// If the offset is a negative constant, we can't fold it because the
|
||||
// resulting memref type would be invalid. In that case, we keep the
|
||||
// original offset.
|
||||
if (auto cst = getConstantIntValue(offsets[0]))
|
||||
if (*cst < 0)
|
||||
offsets[0] = op.getMixedOffsets()[0];
|
||||
|
||||
// If the size is a negative constant, we can't fold it because the
|
||||
// resulting memref type would be invalid. In that case, we keep the
|
||||
// original size.
|
||||
for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
|
||||
auto &srcSizeOfr = std::get<0>(it);
|
||||
auto &sizeOfr = std::get<1>(it);
|
||||
if (auto cst = getConstantIntValue(sizeOfr))
|
||||
if (*cst < 0)
|
||||
sizeOfr = srcSizeOfr;
|
||||
}
|
||||
|
||||
// TODO: Using counting comparison instead of direct comparison because
|
||||
// getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
|
||||
// IntegerAttrs, while constifyIndexValues (and therefore
|
||||
@ -2340,21 +2358,6 @@ public:
|
||||
[](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
|
||||
return failure();
|
||||
|
||||
// Do not fold if the offset is a negative constant; ViewLikeInterface
|
||||
// verifies that static offsets are non-negative.
|
||||
if (auto cst = getConstantIntValue(offsets[0]))
|
||||
if (*cst < 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "negative constant offset is invalid");
|
||||
|
||||
// Do not fold if any size is a negative constant; MemRefType::get asserts
|
||||
// non-negative static sizes.
|
||||
for (OpFoldResult sizeOfr : sizes)
|
||||
if (auto cst = getConstantIntValue(sizeOfr))
|
||||
if (*cst < 0)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "negative constant size is invalid");
|
||||
|
||||
auto newReinterpretCast = ReinterpretCastOp::create(
|
||||
rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
|
||||
|
||||
|
||||
@ -1287,16 +1287,14 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me
|
||||
|
||||
// -----
|
||||
|
||||
// Check that reinterpret_cast with a negative constant size is not folded.
|
||||
// Check that reinterpret_cast with a negative constant size.
|
||||
// Folding would attempt to create a MemRefType with a negative static dimension,
|
||||
// which triggers an assertion in MemRefType::get (issue #188407).
|
||||
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size
|
||||
// CHECK-LABEL: func @reinterpret_cast_with_negative_size
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
|
||||
// CHECK: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[SZ:.*]] = arith.constant -1 : index
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]]
|
||||
func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1]
|
||||
func.func @reinterpret_cast_with_negative_size(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
%c0 = arith.constant 0 : index
|
||||
%c1 = arith.constant 1 : index
|
||||
%sz = arith.constant -1 : index
|
||||
@ -1308,16 +1306,14 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem
|
||||
|
||||
// -----
|
||||
|
||||
// Check that reinterpret_cast with a negative constant offset is not folded.
|
||||
// Check that reinterpret_cast with a negative constant offset.
|
||||
// Folding would create an op with a static negative offset, which violates the
|
||||
// ViewLikeInterface constraint that offsets must be non-negative.
|
||||
// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset
|
||||
// CHECK-LABEL: func @reinterpret_cast_with_negative_offset
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
|
||||
// CHECK: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK: %[[C2:.*]] = arith.constant 2 : index
|
||||
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
|
||||
func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1]
|
||||
func.func @reinterpret_cast_with_negative_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%neg = arith.constant -1 : index
|
||||
@ -1329,6 +1325,39 @@ func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> m
|
||||
|
||||
// -----
|
||||
|
||||
// Check that reinterpret_cast with a negative constant size and offset.
|
||||
// CHECK-LABEL: func @reinterpret_cast_with_negative_size_and_offset
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
|
||||
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, %[[NEG]]], strides: [2, 1]
|
||||
func.func @reinterpret_cast_with_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
%c1 = arith.constant 1 : index
|
||||
%c2 = arith.constant 2 : index
|
||||
%neg = arith.constant -1 : index
|
||||
%output = memref.reinterpret_cast %arg0 to
|
||||
offset: [%neg], sizes: [%c1, %neg], strides: [%c2, %c1]
|
||||
: memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that reinterpret_cast with all negative constant size and offset is not
|
||||
// folded.
|
||||
// CHECK-LABEL: func @reinterpret_cast_no_fold_with_all_negative_size_and_offset
|
||||
// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>)
|
||||
// CHECK: %[[NEG:.*]] = arith.constant -1 : index
|
||||
// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[NEG]], %[[NEG]]], strides: [2, 1]
|
||||
func.func @reinterpret_cast_no_fold_with_all_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
|
||||
%neg = arith.constant -1 : index
|
||||
%output = memref.reinterpret_cast %arg0 to
|
||||
offset: [%neg], sizes: [%neg, %neg], strides: [2, 1]
|
||||
: memref<2x3xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
return %output : memref<?x?xf32, strided<[?, ?], offset: ?>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Check that reinterpret_cast with a negative constant stride IS folded.
|
||||
// Negative strides are valid in MemRef layouts (e.g. reverse iteration),
|
||||
// and the ViewLikeInterface places no non-negativity constraint on strides.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user