From ff86be21de109403175caf6d906be856210df494 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 3 Apr 2026 11:21:00 +0200 Subject: [PATCH] [MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (#188980) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The generic MemRefRewritePattern handles AllocOp/AllocaOp by calling getFlattenMemrefAndOffset with the op's own result as the source memref. This inserts ExtractStridedMetadataOp and ReinterpretCastOp that consume op.result before the alloc op itself in the block. After replaceOpWithNewOp, op.result is RAUW'd to the new ReinterpretCastOp result, leaving those earlier ops with forward references — a domination violation caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS. Replace the AllocOp/AllocaOp cases in MemRefRewritePattern with a dedicated AllocLikeFlattenPattern that never touches op.result until the final replaceOpWithNewOp: - sizes come from op.getMixedSizes() (operands, not the result) - strides come from getStridesAndOffset on the MemRefType - the flat allocation size is computed via getLinearizedMemRefOffsetAndSize plus the static base offset so the buffer covers [0, offset+extent) - the castAllocResult helper is removed; the reinterpret_cast is now built directly from the pre-computed sizes and strides rather than from an ExtractStridedMetadataOp inserted on the original op - non-zero static base offsets are now correctly preserved in the reinterpret_cast (the old code hardcoded offset=0, which was a verifier error for layouts with offset != 0) - dynamic offsets or strides bail out via notifyMatchFailure Also remove the now-dead AllocOp/AllocaOp branches from replaceOp() and the constexpr specialisation in getIndices(). 
Assisted-by: Claude Code --- .../MemRef/Transforms/FlattenMemRefs.cpp | 131 +++++++++++++----- mlir/test/Dialect/MemRef/flatten_memref.mlir | 73 ++++++++++ 2 files changed, 170 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 32244728ff33..6b56ea3ff5ca 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -107,35 +107,11 @@ static Value getTargetMemref(Operation *op) { .Default(nullptr); } -template -static void castAllocResult(T oper, T newOper, Location loc, - PatternRewriter &rewriter) { - memref::ExtractStridedMetadataOp stridedMetadata = - memref::ExtractStridedMetadataOp::create(rewriter, loc, oper); - rewriter.replaceOpWithNewOp( - oper, cast(oper.getType()), newOper, - /*offset=*/rewriter.getIndexAttr(0), - stridedMetadata.getConstifiedMixedSizes(), - stridedMetadata.getConstifiedMixedStrides()); -} - template static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Value offset) { Location loc = op->getLoc(); llvm::TypeSwitch(op.getOperation()) - .Case([&](memref::AllocOp oper) { - auto newAlloc = memref::AllocOp::create( - rewriter, loc, cast(flatMemref.getType()), - oper.getAlignmentAttr()); - castAllocResult(oper, newAlloc, loc, rewriter); - }) - .Case([&](memref::AllocaOp oper) { - auto newAlloca = memref::AllocaOp::create( - rewriter, loc, cast(flatMemref.getType()), - oper.getAlignmentAttr()); - castAllocResult(oper, newAlloca, loc, rewriter); - }) .Case([&](memref::LoadOp op) { auto newLoad = memref::LoadOp::create(rewriter, loc, op->getResultTypes(), @@ -196,12 +172,7 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, template static ValueRange getIndices(T op) { - if constexpr (std::is_same_v || - std::is_same_v) { - return ValueRange{}; - } else { - return op.getIndices(); - } + return op.getIndices(); } template @@ 
-230,19 +201,111 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) { .Default([&](auto op) { return success(); }); } +// Pattern for memref::AllocOp and memref::AllocaOp. +// +// The "source" memref for these ops IS the op's own result, so the generic +// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert +// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op +// in the block. After replaceOpWithNewOp the original result is RAUW'd to the +// new ReinterpretCastOp, leaving the earlier ops with forward references +// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS. +// +// Instead, sizes and strides are computed from the op's operands and type +// (which all dominate the op), avoiding any reference to op.result until the +// final replaceOpWithNewOp. +template +struct AllocLikeFlattenPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AllocLikeOp op, + PatternRewriter &rewriter) const override { + if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref())) + return failure(); + + Location loc = op->getLoc(); + auto memrefType = cast(op.getType()); + auto elemType = memrefType.getElementType(); + if (!elemType.isIntOrFloat()) + return failure(); + unsigned elemBitWidth = elemType.getIntOrFloatBitWidth(); + + SmallVector sizes = op.getMixedSizes(); + + int64_t staticOffset; + SmallVector staticStrides; + if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset))) + return failure(); + if (staticOffset == ShapedType::kDynamic) + return rewriter.notifyMatchFailure(op, "dynamic offset not supported"); + SmallVector strides; + strides.reserve(staticStrides.size()); + for (int64_t stride : staticStrides) { + if (stride == ShapedType::kDynamic) + return rewriter.notifyMatchFailure(op, + "dynamic stride cannot be computed"); + strides.push_back(rewriter.getIndexAttr(stride)); + } + + // Compute the 
linearized flat extent from sizes and strides (no SSA ops + // referencing op.result are created here). + memref::LinearizedMemRefInfo linearizedInfo; + OpFoldResult linearizedOffset; + std::tie(linearizedInfo, linearizedOffset) = + memref::getLinearizedMemRefOffsetAndSize( + rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0), + sizes, strides); + (void)linearizedOffset; + + // The total allocation must cover [0, staticOffset + linearizedExtent). + // When the offset is non-zero, add it to the computed extent so that the + // buffer is large enough for elements accessed at positions + // [staticOffset, staticOffset + linearizedExtent). + OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize; + if (staticOffset != 0) { + AffineExpr s0; + bindSymbols(rewriter.getContext(), s0); + flatSizeOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, s0 + staticOffset, {flatSizeOfr}); + } + + // Build the flat 1-D MemRefType. The linearized size may be static or + // dynamic (OpFoldResult of either IntegerAttr or a Value). + int64_t flatDimSize = ShapedType::kDynamic; + if (auto attr = dyn_cast(flatSizeOfr)) + if (auto intAttr = dyn_cast(attr)) + flatDimSize = intAttr.getInt(); + + auto flatMemrefType = + MemRefType::get({flatDimSize}, memrefType.getElementType(), + StridedLayoutAttr::get(rewriter.getContext(), 0, {1}), + memrefType.getMemorySpace()); + + // Collect the flat dynamic-size operand (empty for fully-static case). 
+ SmallVector dynSizes; + if (flatDimSize == ShapedType::kDynamic) + dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr)); + + auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes, + op.getAlignmentAttr()); + rewriter.replaceOpWithNewOp( + op, cast(op.getType()), newOp, + rewriter.getIndexAttr(staticOffset), sizes, strides); + return success(); + } +}; + template struct MemRefRewritePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(T op, PatternRewriter &rewriter) const override { LogicalResult canFlatten = canBeFlattened(op, rewriter); - if (failed(canFlatten)) { + if (failed(canFlatten)) return canFlatten; - } Value memref = getTargetMemref(op); if (!needFlattening(memref) || !checkLayout(memref)) return failure(); + auto &&[flatMemref, offset] = getFlattenMemrefAndOffset( rewriter, op->getLoc(), memref, getIndices(op)); replaceOp(op, rewriter, flatMemref, offset); @@ -285,8 +348,8 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns( void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) { patterns.insert, MemRefRewritePattern, - MemRefRewritePattern, - MemRefRewritePattern>( + AllocLikeFlattenPattern, + AllocLikeFlattenPattern>( patterns.getContext()); } diff --git a/mlir/test/Dialect/MemRef/flatten_memref.mlir b/mlir/test/Dialect/MemRef/flatten_memref.mlir index e45a10ca0d43..c9166b11c8d1 100644 --- a/mlir/test/Dialect/MemRef/flatten_memref.mlir +++ b/mlir/test/Dialect/MemRef/flatten_memref.mlir @@ -271,6 +271,79 @@ func.func @alloca() -> memref<4x8xf32> { // ----- +func.func @alloc_dynamic(%n: index) -> memref { + %0 = memref.alloc(%n) : memref + return %0 : memref +} + +// CHECK-LABEL: func @alloc_dynamic +// CHECK-SAME: (%[[N:.*]]: index) +// CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref> +// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1] + +// ----- + +func.func @alloca_dynamic(%n: 
index) -> memref { + %0 = memref.alloca(%n) : memref + return %0 : memref +} + +// CHECK-LABEL: func @alloca_dynamic +// CHECK-SAME: (%[[N:.*]]: index) +// CHECK: %[[ALLOCA:.*]] = memref.alloca(%{{.*}}) : memref> +// CHECK: memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1] + +// ----- + +// Explicit row-major strides: same as the default layout, should flatten. +func.func @flatten_alloc_strided_row_major() -> memref<4x8xf32, strided<[8, 1]>> { + %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1]>> + return %0 : memref<4x8xf32, strided<[8, 1]>> +} + +// CHECK-LABEL: func @flatten_alloc_strided_row_major +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>> +// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1]>> + +// ----- + +// Non-zero static offset: the flat allocation covers [0, offset+extent) = [0, 82) +// and the reinterpret_cast restores the original offset in the result type. +func.func @flatten_alloc_strided_offset() -> memref<4x8xf32, strided<[8, 1], offset: 50>> { + %0 = memref.alloc() : memref<4x8xf32, strided<[8, 1], offset: 50>> + return %0 : memref<4x8xf32, strided<[8, 1], offset: 50>> +} + +// CHECK-LABEL: func @flatten_alloc_strided_offset +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<82xf32, strided<[1]>> +// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [50], sizes: [4, 8], strides: [8, 1] : memref<82xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1], offset: 50>> + +// ----- + +// Padded strides: flatten to the maximum extent (max(18*4, 2*8) = 72). 
+func.func @flatten_alloc_strided_padded() -> memref<4x8xf32, strided<[18, 2]>> { + %0 = memref.alloc() : memref<4x8xf32, strided<[18, 2]>> + return %0 : memref<4x8xf32, strided<[18, 2]>> +} + +// CHECK-LABEL: func @flatten_alloc_strided_padded +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<72xf32, strided<[1]>> +// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [18, 2] : memref<72xf32, strided<[1]>> to memref<4x8xf32, strided<[18, 2]>> + +// ----- + +// Multi-dynamic alloc: strides are dynamic so the pattern bails out. +func.func @alloc_multi_dynamic(%m: index, %n: index) -> memref { + %0 = memref.alloc(%m, %n) : memref + return %0 : memref +} + +// CHECK-LABEL: func @alloc_multi_dynamic +// CHECK: memref.alloc(%{{.*}}, %{{.*}}) : memref +// CHECK-NOT: memref.reinterpret_cast + +// ----- + func.func @chained_alloc_load() -> vector<8xf32> { %c3 = arith.constant 3 : index %c6 = arith.constant 6 : index