[MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (#188980)
The generic MemRefRewritePattern handles AllocOp/AllocaOp by calling getFlattenMemrefAndOffset with the op's own result as the source memref. This inserts ExtractStridedMetadataOp and ReinterpretCastOp that consume op.result before the alloc op itself in the block. After replaceOpWithNewOp, op.result is RAUW'd to the new ReinterpretCastOp result, leaving those earlier ops with forward references — a domination violation caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS. Replace the AllocOp/AllocaOp cases in MemRefRewritePattern with a dedicated AllocLikeFlattenPattern that never touches op.result until the final replaceOpWithNewOp: - sizes come from op.getMixedSizes() (operands, not the result) - strides come from getStridesAndOffset on the MemRefType - the flat allocation size is computed via getLinearizedMemRefOffsetAndSize plus the static base offset so the buffer covers [0, offset+extent) - castAllocResult is simplified to take the pre-computed sizes and strides rather than inserting an ExtractStridedMetadataOp on the original op - non-zero static base offsets are now correctly preserved in the reinterpret_cast (the old code hardcoded offset=0, which was a verifier error for layouts with offset \!= 0) - dynamic offsets or strides bail out via notifyMatchFailure Also remove the now-dead AllocOp/AllocaOp branches from replaceOp() and the constexpr specialisation in getIndices(). Assisted-by: Claude Code
This commit is contained in:
parent
7c1d91c435
commit
ff86be21de
@ -107,35 +107,11 @@ static Value getTargetMemref(Operation *op) {
|
||||
.Default(nullptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void castAllocResult(T oper, T newOper, Location loc,
|
||||
PatternRewriter &rewriter) {
|
||||
memref::ExtractStridedMetadataOp stridedMetadata =
|
||||
memref::ExtractStridedMetadataOp::create(rewriter, loc, oper);
|
||||
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
|
||||
oper, cast<MemRefType>(oper.getType()), newOper,
|
||||
/*offset=*/rewriter.getIndexAttr(0),
|
||||
stridedMetadata.getConstifiedMixedSizes(),
|
||||
stridedMetadata.getConstifiedMixedStrides());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
|
||||
Value offset) {
|
||||
Location loc = op->getLoc();
|
||||
llvm::TypeSwitch<Operation *>(op.getOperation())
|
||||
.Case([&](memref::AllocOp oper) {
|
||||
auto newAlloc = memref::AllocOp::create(
|
||||
rewriter, loc, cast<MemRefType>(flatMemref.getType()),
|
||||
oper.getAlignmentAttr());
|
||||
castAllocResult(oper, newAlloc, loc, rewriter);
|
||||
})
|
||||
.Case([&](memref::AllocaOp oper) {
|
||||
auto newAlloca = memref::AllocaOp::create(
|
||||
rewriter, loc, cast<MemRefType>(flatMemref.getType()),
|
||||
oper.getAlignmentAttr());
|
||||
castAllocResult(oper, newAlloca, loc, rewriter);
|
||||
})
|
||||
.Case([&](memref::LoadOp op) {
|
||||
auto newLoad =
|
||||
memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
|
||||
@ -196,12 +172,7 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
|
||||
|
||||
template <typename T>
|
||||
static ValueRange getIndices(T op) {
|
||||
if constexpr (std::is_same_v<T, memref::AllocaOp> ||
|
||||
std::is_same_v<T, memref::AllocOp>) {
|
||||
return ValueRange{};
|
||||
} else {
|
||||
return op.getIndices();
|
||||
}
|
||||
return op.getIndices();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -230,19 +201,111 @@ static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
|
||||
.Default([&](auto op) { return success(); });
|
||||
}
|
||||
|
||||
// Pattern for memref::AllocOp and memref::AllocaOp.
|
||||
//
|
||||
// The "source" memref for these ops IS the op's own result, so the generic
|
||||
// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
|
||||
// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
|
||||
// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
|
||||
// new ReinterpretCastOp, leaving the earlier ops with forward references
|
||||
// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
|
||||
//
|
||||
// Instead, sizes and strides are computed from the op's operands and type
|
||||
// (which all dominate the op), avoiding any reference to op.result until the
|
||||
// final replaceOpWithNewOp.
|
||||
template <typename AllocLikeOp>
|
||||
struct AllocLikeFlattenPattern : public OpRewritePattern<AllocLikeOp> {
|
||||
using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AllocLikeOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
|
||||
return failure();
|
||||
|
||||
Location loc = op->getLoc();
|
||||
auto memrefType = cast<MemRefType>(op.getType());
|
||||
auto elemType = memrefType.getElementType();
|
||||
if (!elemType.isIntOrFloat())
|
||||
return failure();
|
||||
unsigned elemBitWidth = elemType.getIntOrFloatBitWidth();
|
||||
|
||||
SmallVector<OpFoldResult> sizes = op.getMixedSizes();
|
||||
|
||||
int64_t staticOffset;
|
||||
SmallVector<int64_t> staticStrides;
|
||||
if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
|
||||
return failure();
|
||||
if (staticOffset == ShapedType::kDynamic)
|
||||
return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
|
||||
SmallVector<OpFoldResult> strides;
|
||||
strides.reserve(staticStrides.size());
|
||||
for (int64_t stride : staticStrides) {
|
||||
if (stride == ShapedType::kDynamic)
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"dynamic stride cannot be computed");
|
||||
strides.push_back(rewriter.getIndexAttr(stride));
|
||||
}
|
||||
|
||||
// Compute the linearized flat extent from sizes and strides (no SSA ops
|
||||
// referencing op.result are created here).
|
||||
memref::LinearizedMemRefInfo linearizedInfo;
|
||||
OpFoldResult linearizedOffset;
|
||||
std::tie(linearizedInfo, linearizedOffset) =
|
||||
memref::getLinearizedMemRefOffsetAndSize(
|
||||
rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
|
||||
sizes, strides);
|
||||
(void)linearizedOffset;
|
||||
|
||||
// The total allocation must cover [0, staticOffset + linearizedExtent).
|
||||
// When the offset is non-zero, add it to the computed extent so that the
|
||||
// buffer is large enough for elements accessed at positions
|
||||
// [staticOffset, staticOffset + linearizedExtent).
|
||||
OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
|
||||
if (staticOffset != 0) {
|
||||
AffineExpr s0;
|
||||
bindSymbols(rewriter.getContext(), s0);
|
||||
flatSizeOfr = affine::makeComposedFoldedAffineApply(
|
||||
rewriter, loc, s0 + staticOffset, {flatSizeOfr});
|
||||
}
|
||||
|
||||
// Build the flat 1-D MemRefType. The linearized size may be static or
|
||||
// dynamic (OpFoldResult of either IntegerAttr or a Value).
|
||||
int64_t flatDimSize = ShapedType::kDynamic;
|
||||
if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>(attr))
|
||||
flatDimSize = intAttr.getInt();
|
||||
|
||||
auto flatMemrefType =
|
||||
MemRefType::get({flatDimSize}, memrefType.getElementType(),
|
||||
StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
|
||||
memrefType.getMemorySpace());
|
||||
|
||||
// Collect the flat dynamic-size operand (empty for fully-static case).
|
||||
SmallVector<Value, 1> dynSizes;
|
||||
if (flatDimSize == ShapedType::kDynamic)
|
||||
dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
|
||||
|
||||
auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
|
||||
op.getAlignmentAttr());
|
||||
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
|
||||
op, cast<MemRefType>(op.getType()), newOp,
|
||||
rewriter.getIndexAttr(staticOffset), sizes, strides);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct MemRefRewritePattern : public OpRewritePattern<T> {
|
||||
using OpRewritePattern<T>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(T op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
LogicalResult canFlatten = canBeFlattened(op, rewriter);
|
||||
if (failed(canFlatten)) {
|
||||
if (failed(canFlatten))
|
||||
return canFlatten;
|
||||
}
|
||||
|
||||
Value memref = getTargetMemref(op);
|
||||
if (!needFlattening(memref) || !checkLayout(memref))
|
||||
return failure();
|
||||
|
||||
auto &&[flatMemref, offset] = getFlattenMemrefAndOffset(
|
||||
rewriter, op->getLoc(), memref, getIndices<T>(op));
|
||||
replaceOp<T>(op, rewriter, flatMemref, offset);
|
||||
@ -285,8 +348,8 @@ void memref::populateFlattenVectorOpsOnMemrefPatterns(
|
||||
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
|
||||
patterns.insert<MemRefRewritePattern<memref::LoadOp>,
|
||||
MemRefRewritePattern<memref::StoreOp>,
|
||||
MemRefRewritePattern<memref::AllocOp>,
|
||||
MemRefRewritePattern<memref::AllocaOp>>(
|
||||
AllocLikeFlattenPattern<memref::AllocOp>,
|
||||
AllocLikeFlattenPattern<memref::AllocaOp>>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
|
||||
@ -271,6 +271,79 @@ func.func @alloca() -> memref<4x8xf32> {
|
||||
|
||||
// -----
|
||||
|
||||
func.func @alloc_dynamic(%n: index) -> memref<?x4xf32> {
|
||||
%0 = memref.alloc(%n) : memref<?x4xf32>
|
||||
return %0 : memref<?x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @alloc_dynamic
|
||||
// CHECK-SAME: (%[[N:.*]]: index)
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) : memref<?xf32, strided<[1]>>
|
||||
// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
|
||||
|
||||
// -----
|
||||
|
||||
func.func @alloca_dynamic(%n: index) -> memref<?x4xf32> {
|
||||
%0 = memref.alloca(%n) : memref<?x4xf32>
|
||||
return %0 : memref<?x4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @alloca_dynamic
|
||||
// CHECK-SAME: (%[[N:.*]]: index)
|
||||
// CHECK: %[[ALLOCA:.*]] = memref.alloca(%{{.*}}) : memref<?xf32, strided<[1]>>
|
||||
// CHECK: memref.reinterpret_cast %[[ALLOCA]] to offset: [0], sizes: [%[[N]], 4], strides: [4, 1]
|
||||
|
||||
// -----
|
||||
|
||||
// Explicit row-major strides: same as the default layout, should flatten.
|
||||
func.func @flatten_alloc_strided_row_major() -> memref<4x8xf32, strided<[8, 1]>> {
|
||||
%0 = memref.alloc() : memref<4x8xf32, strided<[8, 1]>>
|
||||
return %0 : memref<4x8xf32, strided<[8, 1]>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flatten_alloc_strided_row_major
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32xf32, strided<[1]>>
|
||||
// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [8, 1] : memref<32xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1]>>
|
||||
|
||||
// -----
|
||||
|
||||
// Non-zero static offset: the flat allocation covers [0, offset+extent) = [0, 82)
|
||||
// and the reinterpret_cast restores the original offset in the result type.
|
||||
func.func @flatten_alloc_strided_offset() -> memref<4x8xf32, strided<[8, 1], offset: 50>> {
|
||||
%0 = memref.alloc() : memref<4x8xf32, strided<[8, 1], offset: 50>>
|
||||
return %0 : memref<4x8xf32, strided<[8, 1], offset: 50>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flatten_alloc_strided_offset
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<82xf32, strided<[1]>>
|
||||
// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [50], sizes: [4, 8], strides: [8, 1] : memref<82xf32, strided<[1]>> to memref<4x8xf32, strided<[8, 1], offset: 50>>
|
||||
|
||||
// -----
|
||||
|
||||
// Padded strides: flatten to the maximum extent (max(18*4, 2*8) = 72).
|
||||
func.func @flatten_alloc_strided_padded() -> memref<4x8xf32, strided<[18, 2]>> {
|
||||
%0 = memref.alloc() : memref<4x8xf32, strided<[18, 2]>>
|
||||
return %0 : memref<4x8xf32, strided<[18, 2]>>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @flatten_alloc_strided_padded
|
||||
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<72xf32, strided<[1]>>
|
||||
// CHECK: memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [4, 8], strides: [18, 2] : memref<72xf32, strided<[1]>> to memref<4x8xf32, strided<[18, 2]>>
|
||||
|
||||
// -----
|
||||
|
||||
// Multi-dynamic alloc: strides are dynamic so the pattern bails out.
|
||||
func.func @alloc_multi_dynamic(%m: index, %n: index) -> memref<?x?xf32> {
|
||||
%0 = memref.alloc(%m, %n) : memref<?x?xf32>
|
||||
return %0 : memref<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @alloc_multi_dynamic
|
||||
// CHECK: memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
|
||||
// CHECK-NOT: memref.reinterpret_cast
|
||||
|
||||
// -----
|
||||
|
||||
func.func @chained_alloc_load() -> vector<8xf32> {
|
||||
%c3 = arith.constant 3 : index
|
||||
%c6 = arith.constant 6 : index
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user