[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (#148415)

Store metadata about unresolved materializations in a separate data
structure. This is in preparation of the One-Shot Dialect Conversion
refactoring, which no longer maintains a stack of `IRRewrite` objects.
Therefore, metadata about unresolved materializations can no longer be
retrieved from `UnresolvedMaterializationRewrite` objects.

This commit also removes a pointer indirection and may slightly improve
the performance of the existing driver.
This commit is contained in:
Matthias Springer 2025-07-15 10:14:37 +02:00 committed by GitHub
parent cbdc18542c
commit 8ee32c7b36
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -789,26 +789,13 @@ enum MaterializationKind {
Source
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
/// op. Unresolved materializations are erased at the end of the dialect
/// conversion.
class UnresolvedMaterializationRewrite : public OperationRewrite {
/// Helper class that stores metadata about an unresolved materialization.
class UnresolvedMaterializationInfo {
public:
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op,
const TypeConverter *converter,
MaterializationKind kind, Type originalType,
ValueVector mappedValues);
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
}
void rollback() override;
UnrealizedConversionCastOp getOperation() const {
return cast<UnrealizedConversionCastOp>(op);
}
UnresolvedMaterializationInfo() = default;
UnresolvedMaterializationInfo(const TypeConverter *converter,
MaterializationKind kind, Type originalType)
: converterAndKind(converter, kind), originalType(originalType) {}
/// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
@ -832,7 +819,30 @@ private:
/// The original type of the SSA value. Only used for target
/// materializations.
Type originalType;
};
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
/// op. Unresolved materializations fold away or are replaced with
/// source/target materializations at the end of the dialect conversion.
class UnresolvedMaterializationRewrite : public OperationRewrite {
public:
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
UnrealizedConversionCastOp op,
ValueVector mappedValues)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
mappedValues(std::move(mappedValues)) {}
static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::UnresolvedMaterialization;
}
void rollback() override;
UnrealizedConversionCastOp getOperation() const {
return cast<UnrealizedConversionCastOp>(op);
}
private:
/// The values in the conversion value mapping that are being replaced by the
/// results of this unresolved materialization.
ValueVector mappedValues;
@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
/// to the corresponding rewrite objects.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
/// The current type converter, or nullptr if no type converter is currently
@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
op->erase();
}
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
const TypeConverter *converter, MaterializationKind kind, Type originalType,
ValueVector mappedValues)
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
converterAndKind(converter, kind), originalType(originalType),
mappedValues(std::move(mappedValues)) {
assert((!originalType || kind == MaterializationKind::Target) &&
"original type is valid only for target materializations");
rewriterImpl.unresolvedMaterializations[op] = this;
}
void UnresolvedMaterializationRewrite::rollback() {
if (!mappedValues.empty())
rewriterImpl.mapping.erase(mappedValues);
@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
mapping.map(valuesToMap, convertOp.getResults());
if (castOp)
*castOp = convertOp;
appendRewrite<UnresolvedMaterializationRewrite>(
convertOp, converter, kind, originalType, std::move(valuesToMap));
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
std::move(valuesToMap));
return convertOp.getResults();
}
@ -2678,21 +2677,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
static LogicalResult
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
UnresolvedMaterializationRewrite *rewrite) {
UnrealizedConversionCastOp op = rewrite->getOperation();
UnrealizedConversionCastOp op,
const UnresolvedMaterializationInfo &info) {
assert(!op.use_empty() &&
"expected that dead materializations have already been DCE'd");
Operation::operand_range inputOperands = op.getOperands();
// Try to materialize the conversion.
if (const TypeConverter *converter = rewrite->getConverter()) {
if (const TypeConverter *converter = info.getConverter()) {
rewriter.setInsertionPoint(op);
SmallVector<Value> newMaterialization;
switch (rewrite->getMaterializationKind()) {
switch (info.getMaterializationKind()) {
case MaterializationKind::Target:
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
rewrite->getOriginalType());
info.getOriginalType());
break;
case MaterializationKind::Source:
assert(op->getNumResults() == 1 && "expected single result");
@ -2767,7 +2766,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Gather all unresolved materializations.
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
&materializations = rewriterImpl.unresolvedMaterializations;
for (auto it : materializations)
allCastOps.push_back(it.first);
@ -2784,7 +2783,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = materializations.find(castOp);
assert(it != materializations.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
if (failed(
legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
return failure();
}
}