From 158f10fe24a39208e45d6039dfc6d605967ade2a Mon Sep 17 00:00:00 2001 From: Ming Yan Date: Wed, 1 Apr 2026 21:30:12 +0800 Subject: [PATCH] [mlir][memref] Fold memref.reinterpret_cast operations with valid offset or size constants. (#189533) When encountering an invalid offset or size, we only skip the current invalid value and continue attempting to fold other valid offsets or sizes. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 33 ++++++++------ mlir/test/Dialect/MemRef/canonicalize.mlir | 53 +++++++++++++++++----- 2 files changed, 59 insertions(+), 27 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 31546f123b51..27c1649ee4ed 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -2331,6 +2331,24 @@ public: SmallVector sizes = op.getConstifiedMixedSizes(); SmallVector strides = op.getConstifiedMixedStrides(); + // If the offset is a negative constant, we can't fold it because the + // resulting memref type would be invalid. In that case, we keep the + // original offset. + if (auto cst = getConstantIntValue(offsets[0])) + if (*cst < 0) + offsets[0] = op.getMixedOffsets()[0]; + + // If the size is a negative constant, we can't fold it because the + // resulting memref type would be invalid. In that case, we keep the + // original size. + for (auto it : llvm::zip(op.getMixedSizes(), sizes)) { + auto &srcSizeOfr = std::get<0>(it); + auto &sizeOfr = std::get<1>(it); + if (auto cst = getConstantIntValue(sizeOfr)) + if (*cst < 0) + sizeOfr = srcSizeOfr; + } + // TODO: Using counting comparison instead of direct comparison because // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns // IntegerAttrs, while constifyIndexValues (and therefore @@ -2340,21 +2358,6 @@ public: [](OpFoldResult ofr) { return isa(ofr); })) return failure(); - // Do not fold if the offset is a negative constant; ViewLikeInterface - // verifies that static offsets are non-negative. - if (auto cst = getConstantIntValue(offsets[0])) - if (*cst < 0) - return rewriter.notifyMatchFailure( - op, "negative constant offset is invalid"); - - // Do not fold if any size is a negative constant; MemRefType::get asserts - // non-negative static sizes. - for (OpFoldResult sizeOfr : sizes) - if (auto cst = getConstantIntValue(sizeOfr)) - if (*cst < 0) - return rewriter.notifyMatchFailure( - op, "negative constant size is invalid"); - auto newReinterpretCast = ReinterpretCastOp::create( rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index fb1e7d00feb4..6c4fd6f8f58d 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -1287,16 +1287,14 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : me // ----- -// Check that reinterpret_cast with a negative constant size is not folded. +// Check that reinterpret_cast with a negative constant size. // Folding would attempt to create a MemRefType with a negative static dimension, // which triggers an assertion in MemRefType::get (issue #188407). -// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_size +// CHECK-LABEL: func @reinterpret_cast_with_negative_size // CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) -// CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[SZ:.*]] = arith.constant -1 : index -// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [%[[C1]], %[[SZ]]], strides: [%[[SZ]], %[[C1]]] -func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> memref> { +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [1, %[[SZ]]], strides: [-1, 1] +func.func @reinterpret_cast_with_negative_size(%arg0: memref<2x3xf32>) -> memref> { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %sz = arith.constant -1 : index @@ -1308,16 +1306,14 @@ func.func @reinterpret_cast_no_fold_negative_size(%arg0: memref<2x3xf32>) -> mem // ----- -// Check that reinterpret_cast with a negative constant offset is not folded. +// Check that reinterpret_cast with a negative constant offset. // Folding would create an op with a static negative offset, which violates the // ViewLikeInterface constraint that offsets must be non-negative. -// CHECK-LABEL: func @reinterpret_cast_no_fold_negative_offset +// CHECK-LABEL: func @reinterpret_cast_with_negative_offset // CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) -// CHECK: %[[C1:.*]] = arith.constant 1 : index -// CHECK: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[NEG:.*]] = arith.constant -1 : index -// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[C1]], %[[C2]]], strides: [%[[C2]], %[[C1]]] -func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> memref> { +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, 2], strides: [2, 1] +func.func @reinterpret_cast_with_negative_offset(%arg0: memref<2x3xf32>) -> memref> { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %neg = arith.constant -1 : index @@ -1329,6 +1325,39 @@ func.func @reinterpret_cast_no_fold_negative_offset(%arg0: memref<2x3xf32>) -> m // ----- +// Check that reinterpret_cast with a negative constant size and offset. +// CHECK-LABEL: func @reinterpret_cast_with_negative_size_and_offset +// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) +// CHECK: %[[NEG:.*]] = arith.constant -1 : index +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [1, %[[NEG]]], strides: [2, 1] +func.func @reinterpret_cast_with_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref> { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %neg = arith.constant -1 : index + %output = memref.reinterpret_cast %arg0 to + offset: [%neg], sizes: [%c1, %neg], strides: [%c2, %c1] + : memref<2x3xf32> to memref> + return %output : memref> +} + +// ----- + +// Check that reinterpret_cast with all negative constant size and offset is not +// folded. +// CHECK-LABEL: func @reinterpret_cast_no_fold_with_all_negative_size_and_offset +// CHECK-SAME: (%[[ARG:.*]]: memref<2x3xf32>) +// CHECK: %[[NEG:.*]] = arith.constant -1 : index +// CHECK: memref.reinterpret_cast %[[ARG]] to offset: [%[[NEG]]], sizes: [%[[NEG]], %[[NEG]]], strides: [2, 1] +func.func @reinterpret_cast_no_fold_with_all_negative_size_and_offset(%arg0: memref<2x3xf32>) -> memref> { + %neg = arith.constant -1 : index + %output = memref.reinterpret_cast %arg0 to + offset: [%neg], sizes: [%neg, %neg], strides: [2, 1] + : memref<2x3xf32> to memref> + return %output : memref> +} + +// ----- + // Check that reinterpret_cast with a negative constant stride IS folded. // Negative strides are valid in MemRef layouts (e.g. reverse iteration), // and the ViewLikeInterface places no non-negativity constraint on strides.