Mehdi Amini ff86be21de
[MLIR][MemRef] Fix AllocOp/AllocaOp flattening domination violation (#188980)
The generic MemRefRewritePattern handles AllocOp/AllocaOp by calling
getFlattenMemrefAndOffset with the op's own result as the source memref.
This inserts ExtractStridedMetadataOp and ReinterpretCastOp that consume
op.result before the alloc op itself in the block. After
replaceOpWithNewOp, op.result is RAUW'd to the new ReinterpretCastOp
result, leaving those earlier ops with forward references — a domination
violation caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.

Replace the AllocOp/AllocaOp cases in MemRefRewritePattern with a
dedicated AllocLikeFlattenPattern that never touches op.result until the
final replaceOpWithNewOp:
- sizes come from op.getMixedSizes() (operands, not the result)
- strides come from getStridesAndOffset on the MemRefType
- the flat allocation size is computed via
getLinearizedMemRefOffsetAndSize plus the static base offset so the
buffer covers [0, offset+extent)
- castAllocResult is simplified to take the pre-computed sizes and
strides rather than inserting an ExtractStridedMetadataOp on the
original op
- non-zero static base offsets are now correctly preserved in the
reinterpret_cast (the old code hardcoded offset=0, which was a verifier
error for layouts with offset != 0)
- dynamic offsets or strides bail out via notifyMatchFailure

Also remove the now-dead AllocOp/AllocaOp branches from replaceOp() and
the constexpr specialisation in getIndices().

Assisted-by: Claude Code
2026-04-03 11:21:00 +02:00

360 lines
14 KiB
C++

//===----- FlattenMemRefs.cpp - MemRef ops flattener pass ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns for flattening multi-rank memref-related ops
// into ops on 1-D memrefs.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
namespace mlir {
namespace memref {
#define GEN_PASS_DEF_FLATTENMEMREFSPASS
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
} // namespace memref
} // namespace mlir
using namespace mlir;
/// Materializes `in` as an SSA value. An attribute (which must be an
/// IntegerAttr) becomes an index-typed `arith.constant`; a Value is returned
/// unchanged.
static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc,
                                      OpFoldResult in) {
  auto constAttr = dyn_cast<Attribute>(in);
  if (!constAttr)
    return cast<Value>(in);
  int64_t constValue = cast<IntegerAttr>(constAttr).getInt();
  return arith::ConstantIndexOp::create(rewriter, loc, constValue);
}
/// Returns a collapsed memref and the linearized index to access the element
/// at the specified indices.
///
/// `source` is expected to have an identity or strided layout (callers check
/// this with `checkLayout`). The collapsed memref is a rank-1, unit-stride
/// `memref.reinterpret_cast` of `source`. Its offset and size come from
/// linearizing `source`'s strided metadata. The returned index is the
/// linearized position of `indices` in that collapsed view.
static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter,
                                                         Location loc,
                                                         Value source,
                                                         ValueRange indices) {
  auto sourceType = cast<MemRefType>(source.getType());
#ifndef NDEBUG
  {
    int64_t sourceOffset;
    SmallVector<int64_t, 4> sourceStrides;
    assert(succeeded(sourceType.getStridesAndOffset(sourceStrides,
                                                    sourceOffset)) &&
           "expected a memref with an identity or strided layout");
  }
#endif
  memref::ExtractStridedMetadataOp stridedMetadata =
      memref::ExtractStridedMetadataOp::create(rewriter, loc, source);
  // The linearization helper takes source and destination element bit widths
  // (it also serves narrow-type emulation). Flattening keeps the element type,
  // so both widths are equal and no rescaling is intended. Element types
  // without an int/float bit width (e.g. `index`, vectors) would assert in
  // getIntOrFloatBitWidth(), so a placeholder width is used for them.
  // NOTE(review): relies on the helper only using the srcBits/dstBits ratio.
  Type elementType = sourceType.getElementType();
  unsigned elementBits =
      elementType.isIntOrFloat() ? elementType.getIntOrFloatBitWidth() : 8;
  OpFoldResult linearizedIndices;
  memref::LinearizedMemRefInfo linearizedInfo;
  std::tie(linearizedInfo, linearizedIndices) =
      memref::getLinearizedMemRefOffsetAndSize(
          rewriter, loc, elementBits, elementBits,
          stridedMetadata.getConstifiedMixedOffset(),
          stridedMetadata.getConstifiedMixedSizes(),
          stridedMetadata.getConstifiedMixedStrides(),
          getAsOpFoldResult(indices));
  return std::make_pair(
      memref::ReinterpretCastOp::create(
          rewriter, loc, source,
          /* offset = */ linearizedInfo.linearizedOffset,
          /* shapes = */
          ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize},
          /* strides = */
          ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}),
      getValueFromOpFoldResult(rewriter, loc, linearizedIndices));
}
/// Returns true when the memref-typed `val` has rank 2 or higher, i.e. there
/// is something to flatten.
static bool needFlattening(Value val) {
  const int64_t rank = cast<MemRefType>(val.getType()).getRank();
  return rank >= 2;
}
/// Returns true when the memref-typed `val` uses either an explicit
/// `strided<...>` layout or the identity layout. Any other layout is rejected.
static bool checkLayout(Value val) {
  auto layout = cast<MemRefType>(val.getType()).getLayout();
  if (isa<StridedLayoutAttr>(layout))
    return true;
  return layout.isIdentity();
}
namespace {
/// Returns the memref accessed by `op`, or a null Value if `op` is not one of
/// the memory-access ops handled by MemRefRewritePattern.
///
/// memref.alloc/memref.alloca are intentionally not listed. They are flattened
/// by AllocLikeFlattenPattern, which never instantiates MemRefRewritePattern.
static Value getTargetMemref(Operation *op) {
  return llvm::TypeSwitch<Operation *, Value>(op)
      .Case<memref::LoadOp, memref::StoreOp>(
          [](auto op) { return op.getMemref(); })
      .Case<vector::LoadOp, vector::StoreOp, vector::MaskedLoadOp,
            vector::MaskedStoreOp, vector::TransferReadOp,
            vector::TransferWriteOp>([](auto op) { return op.getBase(); })
      .Default(nullptr);
}
/// Replaces `op` with the equivalent op that accesses the flattened memref
/// `flatMemref` at the single linearized index `offset`.
///
/// Load/store and masked load/store replacements copy every attribute of the
/// original op. For transfer_read/transfer_write, the builder infers the
/// permutation map for the 1-D memref, and only the original `in_bounds`
/// values are forwarded. Their remaining attributes describe the original
/// multi-dimensional access. Ops not listed here only get an "unimplemented"
/// error and are not replaced.
template <typename T>
static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref,
                      Value offset) {
  Location loc = op->getLoc();
  llvm::TypeSwitch<Operation *>(op.getOperation())
      .Case([&](memref::LoadOp op) {
        auto newLoad =
            memref::LoadOp::create(rewriter, loc, op->getResultTypes(),
                                   flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .Case([&](memref::StoreOp op) {
        auto newStore =
            memref::StoreOp::create(rewriter, loc, op->getOperands().front(),
                                    flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .Case([&](vector::LoadOp op) {
        auto newLoad =
            vector::LoadOp::create(rewriter, loc, op->getResultTypes(),
                                   flatMemref, ValueRange{offset});
        newLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newLoad.getResult());
      })
      .Case([&](vector::StoreOp op) {
        auto newStore =
            vector::StoreOp::create(rewriter, loc, op->getOperands().front(),
                                    flatMemref, ValueRange{offset});
        newStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newStore);
      })
      .Case([&](vector::MaskedLoadOp op) {
        auto newMaskedLoad = vector::MaskedLoadOp::create(
            rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
            op.getMask(), op.getPassThru());
        newMaskedLoad->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedLoad.getResult());
      })
      .Case([&](vector::MaskedStoreOp op) {
        auto newMaskedStore = vector::MaskedStoreOp::create(
            rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(),
            op.getValueToStore());
        newMaskedStore->setAttrs(op->getAttrs());
        rewriter.replaceOp(op, newMaskedStore);
      })
      .Case([&](vector::TransferReadOp op) {
        // The flattened access touches exactly the same elements, so the
        // per-dimension in_bounds facts stay valid. Without them, the builder
        // presumably defaults every dimension to "maybe out-of-bounds".
        SmallVector<bool> inBounds = op.getInBoundsValues();
        auto newTransferRead = vector::TransferReadOp::create(
            rewriter, loc, op.getType(), flatMemref, ValueRange{offset},
            op.getPadding(), ArrayRef<bool>(inBounds));
        rewriter.replaceOp(op, newTransferRead.getResult());
      })
      .Case([&](vector::TransferWriteOp op) {
        // Same reasoning as for transfer_read: keep the in_bounds facts.
        SmallVector<bool> inBounds = op.getInBoundsValues();
        auto newTransferWrite = vector::TransferWriteOp::create(
            rewriter, loc, op.getVector(), flatMemref, ValueRange{offset},
            ArrayRef<bool>(inBounds));
        rewriter.replaceOp(op, newTransferWrite);
      })
      .Default([&](auto op) {
        op->emitOpError("unimplemented: do not know how to replace op.");
      });
}
/// Returns the index operands of the indexed memory-access op `op`.
template <typename OpTy>
static ValueRange getIndices(OpTy op) {
  ValueRange indices = op.getIndices();
  return indices;
}
/// Returns success if `op` can be rewritten to access a flattened memref, or
/// a match failure (with a diagnostic) otherwise.
///
/// Only vector.transfer_read/transfer_write have extra requirements. Every
/// other op is accepted here. A transfer op must:
///   1. use an identity or minor-identity permutation map;
///   2. be unmasked. replaceOp rebuilds transfers without a mask operand, so
///      accepting a masked transfer would silently drop its mask;
///   3. move a 0-D or 1-D vector. The rebuilt op lets the builder infer a
///      minor-identity map over the 1-D memref, which has fewer dims than a
///      multi-dimensional vector has results;
///   4. have every dimension marked in_bounds.
template <typename T>
static LogicalResult canBeFlattened(T op, PatternRewriter &rewriter) {
  return llvm::TypeSwitch<Operation *, LogicalResult>(op.getOperation())
      .template Case<vector::TransferReadOp, vector::TransferWriteOp>(
          [&](auto oper) {
            auto permutationMap = oper.getPermutationMap();
            if (!permutationMap.isIdentity() &&
                !permutationMap.isMinorIdentity()) {
              return rewriter.notifyMatchFailure(
                  oper, "only identity permutation map is supported");
            }
            if (oper.getMask()) {
              return rewriter.notifyMatchFailure(
                  oper, "masked transfers are not supported");
            }
            if (oper.getVectorType().getRank() > 1) {
              return rewriter.notifyMatchFailure(
                  oper, "only 0-D and 1-D vector transfers are supported");
            }
            mlir::ArrayAttr inbounds = oper.getInBounds();
            if (llvm::any_of(inbounds, [](Attribute attr) {
                  return !cast<BoolAttr>(attr).getValue();
                })) {
              return rewriter.notifyMatchFailure(oper,
                                                 "only inbounds are supported");
            }
            return success();
          })
      .Default([&](auto op) { return success(); });
}
// Pattern for memref::AllocOp and memref::AllocaOp.
//
// The "source" memref for these ops IS the op's own result, so the generic
// MemRefRewritePattern cannot be used: getFlattenMemrefAndOffset would insert
// ExtractStridedMetadataOp and ReinterpretCastOp that use op.result BEFORE op
// in the block. After replaceOpWithNewOp the original result is RAUW'd to the
// new ReinterpretCastOp, leaving the earlier ops with forward references
// (domination violations) caught by MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS.
//
// Instead, sizes and strides are computed from the op's operands and type
// (which all dominate the op), avoiding any reference to op.result until the
// final replaceOpWithNewOp.
//
// The rewrite creates a rank-1 op of the same kind (alloc or alloca) with the
// same element type, memory space and alignment. A memref.reinterpret_cast
// then restores the original type from the original sizes and the layout's
// static strides and offset. Layouts with a dynamic offset or dynamic strides
// are rejected.
template <typename AllocLikeOp>
struct AllocLikeFlattenPattern : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AllocLikeOp op,
                                PatternRewriter &rewriter) const override {
    if (!needFlattening(op.getMemref()) || !checkLayout(op.getMemref()))
      return failure();
    Location loc = op->getLoc();
    auto memrefType = cast<MemRefType>(op.getType());
    // Flattening keeps the element type, so the two bit widths passed to the
    // linearization helper are equal and no rescaling is intended. Element
    // types without an int/float bit width (e.g. `index`, vectors) get a
    // placeholder width instead of being refused, matching
    // getFlattenMemrefAndOffset.
    Type elemType = memrefType.getElementType();
    unsigned elemBitWidth =
        elemType.isIntOrFloat() ? elemType.getIntOrFloatBitWidth() : 8;
    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
    int64_t staticOffset;
    SmallVector<int64_t> staticStrides;
    if (failed(memrefType.getStridesAndOffset(staticStrides, staticOffset)))
      return failure();
    if (staticOffset == ShapedType::kDynamic)
      return rewriter.notifyMatchFailure(op, "dynamic offset not supported");
    SmallVector<OpFoldResult> strides;
    strides.reserve(staticStrides.size());
    for (int64_t stride : staticStrides) {
      if (stride == ShapedType::kDynamic)
        return rewriter.notifyMatchFailure(op,
                                           "dynamic stride cannot be computed");
      strides.push_back(rewriter.getIndexAttr(stride));
    }
    // Compute the linearized flat extent from sizes and strides (no SSA ops
    // referencing op.result are created here). The linearized index result is
    // meaningless for an allocation and is discarded.
    // NOTE(review): the flat buffer must cover every element the strided
    // layout addresses, i.e. at least 1 + sum_i (size_i - 1) * stride_i
    // elements. If the helper computes max_i(stride_i * size_i), that holds
    // for nested layouts such as the identity layout. It may not hold for
    // unusual non-nested or overlapping strides. Confirm.
    memref::LinearizedMemRefInfo linearizedInfo;
    OpFoldResult linearizedOffset;
    std::tie(linearizedInfo, linearizedOffset) =
        memref::getLinearizedMemRefOffsetAndSize(
            rewriter, loc, elemBitWidth, elemBitWidth, rewriter.getIndexAttr(0),
            sizes, strides);
    (void)linearizedOffset;
    // The total allocation must cover [0, staticOffset + linearizedExtent).
    // When the offset is non-zero, add it to the computed extent so that the
    // buffer is large enough for elements accessed at positions
    // [staticOffset, staticOffset + linearizedExtent).
    OpFoldResult flatSizeOfr = linearizedInfo.linearizedSize;
    if (staticOffset != 0) {
      AffineExpr s0;
      bindSymbols(rewriter.getContext(), s0);
      flatSizeOfr = affine::makeComposedFoldedAffineApply(
          rewriter, loc, s0 + staticOffset, {flatSizeOfr});
    }
    // Build the flat 1-D MemRefType. The linearized size may be static or
    // dynamic (OpFoldResult of either IntegerAttr or a Value).
    int64_t flatDimSize = ShapedType::kDynamic;
    if (auto attr = dyn_cast<Attribute>(flatSizeOfr))
      if (auto intAttr = dyn_cast<IntegerAttr>(attr))
        flatDimSize = intAttr.getInt();
    auto flatMemrefType =
        MemRefType::get({flatDimSize}, memrefType.getElementType(),
                        StridedLayoutAttr::get(rewriter.getContext(), 0, {1}),
                        memrefType.getMemorySpace());
    // Collect the flat dynamic-size operand (empty for fully-static case).
    SmallVector<Value, 1> dynSizes;
    if (flatDimSize == ShapedType::kDynamic)
      dynSizes.push_back(getValueFromOpFoldResult(rewriter, loc, flatSizeOfr));
    auto newOp = AllocLikeOp::create(rewriter, loc, flatMemrefType, dynSizes,
                                     op.getAlignmentAttr());
    // Cast back to the original type so all existing users keep working.
    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op, memrefType, newOp, rewriter.getIndexAttr(staticOffset), sizes,
        strides);
    return success();
  }
};
/// Generic pattern that flattens the memref accessed by an indexed load/store
/// style op `T`. It rewrites the op to use a 1-D view of the memref and a
/// single linearized index. It only fires for memrefs of rank >= 2 with an
/// identity or strided layout, and only when canBeFlattened accepts the op.
template <typename T>
struct MemRefRewritePattern : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;
  LogicalResult matchAndRewrite(T op,
                                PatternRewriter &rewriter) const override {
    if (LogicalResult status = canBeFlattened(op, rewriter); failed(status))
      return status;
    Value target = getTargetMemref(op);
    bool applicable = needFlattening(target) && checkLayout(target);
    if (!applicable)
      return failure();
    std::pair<Value, Value> flattened = getFlattenMemrefAndOffset(
        rewriter, op->getLoc(), target, getIndices<T>(op));
    replaceOp<T>(op, rewriter, flattened.first, flattened.second);
    return success();
  }
};
/// Pass that greedily applies every memref/vector flattening pattern from
/// memref::populateFlattenMemrefsPatterns to the operation it runs on.
struct FlattenMemrefsPass
    : public mlir::memref::impl::FlattenMemrefsPassBase<FlattenMemrefsPass> {
  using Base::Base;

  // The patterns create ops from these dialects, so they must be loaded.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, arith::ArithDialect,
                    memref::MemRefDialect, vector::VectorDialect>();
  }

  void runOnOperation() override {
    RewritePatternSet flattenPatterns(&getContext());
    memref::populateFlattenMemrefsPatterns(flattenPatterns);
    LogicalResult converged =
        applyPatternsGreedily(getOperation(), std::move(flattenPatterns));
    if (failed(converged))
      signalPassFailure();
  }
};
} // namespace
/// Adds the patterns that flatten the memref accessed by vector.load,
/// vector.store, vector.transfer_read/transfer_write and
/// vector.maskedload/maskedstore.
void memref::populateFlattenVectorOpsOnMemrefPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MemRefRewritePattern<vector::LoadOp>,
               MemRefRewritePattern<vector::StoreOp>,
               MemRefRewritePattern<vector::TransferReadOp>,
               MemRefRewritePattern<vector::TransferWriteOp>,
               MemRefRewritePattern<vector::MaskedLoadOp>,
               MemRefRewritePattern<vector::MaskedStoreOp>>(ctx);
}
/// Adds the patterns that flatten memref.load/memref.store accesses and
/// memref.alloc/memref.alloca allocations.
void memref::populateFlattenMemrefOpsPatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<MemRefRewritePattern<memref::LoadOp>,
               MemRefRewritePattern<memref::StoreOp>,
               AllocLikeFlattenPattern<memref::AllocOp>,
               AllocLikeFlattenPattern<memref::AllocaOp>>(ctx);
}
/// Adds every flattening pattern: memref-op patterns first, then the
/// vector-op-on-memref patterns.
void memref::populateFlattenMemrefsPatterns(RewritePatternSet &patterns) {
  memref::populateFlattenMemrefOpsPatterns(patterns);
  memref::populateFlattenVectorOpsOnMemrefPatterns(patterns);
}