From 8ee32c7b367336f7281aa78c823f7bae4d5287c2 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 15 Jul 2025 10:14:37 +0200 Subject: [PATCH] [mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (#148415) Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of `IRRewrite` objects. Therefore, metadata about unresolved materializations can no longer be retrieved from `UnresolvedMaterializationRewrite` objects. This commit also removes a pointer indirection and may slightly improve the performance of the existing driver. --- .../Transforms/Utils/DialectConversion.cpp | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4d01a83d3716..4c4ce3cb41fd 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -789,26 +789,13 @@ enum MaterializationKind { Source }; -/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" -/// op. Unresolved materializations are erased at the end of the dialect -/// conversion. -class UnresolvedMaterializationRewrite : public OperationRewrite { +/// Helper class that stores metadata about an unresolved materialization. +class UnresolvedMaterializationInfo { public: - UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - UnrealizedConversionCastOp op, - const TypeConverter *converter, - MaterializationKind kind, Type originalType, - ValueVector mappedValues); - - static bool classof(const IRRewrite *rewrite) { - return rewrite->getKind() == Kind::UnresolvedMaterialization; - } - - void rollback() override; - - UnrealizedConversionCastOp getOperation() const { - return cast(op); - } + UnresolvedMaterializationInfo() = default; + UnresolvedMaterializationInfo(const TypeConverter *converter, + MaterializationKind kind, Type originalType) + : converterAndKind(converter, kind), originalType(originalType) {} /// Return the type converter of this materialization (which may be null). const TypeConverter *getConverter() const { @@ -832,7 +819,30 @@ private: /// The original type of the SSA value. Only used for target /// materializations. Type originalType; +}; +/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast" +/// op. Unresolved materializations fold away or are replaced with +/// source/target materializations at the end of the dialect conversion. +class UnresolvedMaterializationRewrite : public OperationRewrite { +public: + UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl, + UnrealizedConversionCastOp op, + ValueVector mappedValues) + : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), + mappedValues(std::move(mappedValues)) {} + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::UnresolvedMaterialization; + } + + void rollback() override; + + UnrealizedConversionCastOp getOperation() const { + return cast(op); + } + +private: /// The values in the conversion value mapping that are being replaced by the /// results of this unresolved materialization. ValueVector mappedValues; @@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector patternInsertedBlocks; - /// A mapping of all unresolved materializations (UnrealizedConversionCastOp) - /// to the corresponding rewrite objects. - DenseMap + /// A mapping for looking up metadata of unresolved materializations. + DenseMap unresolvedMaterializations; /// The current type converter, or nullptr if no type converter is currently @@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() { op->erase(); } -UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( - ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, - const TypeConverter *converter, MaterializationKind kind, Type originalType, - ValueVector mappedValues) - : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), - converterAndKind(converter, kind), originalType(originalType), - mappedValues(std::move(mappedValues)) { - assert((!originalType || kind == MaterializationKind::Target) && - "original type is valid only for target materializations"); - rewriterImpl.unresolvedMaterializations[op] = this; -} - void UnresolvedMaterializationRewrite::rollback() { if (!mappedValues.empty()) rewriterImpl.mapping.erase(mappedValues); @@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; - appendRewrite( - convertOp, converter, kind, originalType, std::move(valuesToMap)); + unresolvedMaterializations[convertOp] = + UnresolvedMaterializationInfo(converter, kind, originalType); + appendRewrite(convertOp, + std::move(valuesToMap)); return convertOp.getResults(); } @@ -2678,21 +2677,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, static LogicalResult legalizeUnresolvedMaterialization(RewriterBase &rewriter, - UnresolvedMaterializationRewrite *rewrite) { - UnrealizedConversionCastOp op = rewrite->getOperation(); + UnrealizedConversionCastOp op, + const UnresolvedMaterializationInfo &info) { assert(!op.use_empty() && "expected that dead materializations have already been DCE'd"); Operation::operand_range inputOperands = op.getOperands(); // Try to materialize the conversion. - if (const TypeConverter *converter = rewrite->getConverter()) { + if (const TypeConverter *converter = info.getConverter()) { rewriter.setInsertionPoint(op); SmallVector newMaterialization; - switch (rewrite->getMaterializationKind()) { + switch (info.getMaterializationKind()) { case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), op.getResultTypes(), inputOperands, - rewrite->getOriginalType()); + info.getOriginalType()); break; case MaterializationKind::Source: assert(op->getNumResults() == 1 && "expected single result"); @@ -2767,7 +2766,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { // Gather all unresolved materializations. SmallVector allCastOps; - const DenseMap + const DenseMap &materializations = rewriterImpl.unresolvedMaterializations; for (auto it : materializations) allCastOps.push_back(it.first); @@ -2784,7 +2783,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef ops) { for (UnrealizedConversionCastOp castOp : remainingCastOps) { auto it = materializations.find(castOp); assert(it != materializations.end() && "inconsistent state"); - if (failed(legalizeUnresolvedMaterialization(rewriter, it->second))) + if (failed( + legalizeUnresolvedMaterialization(rewriter, castOp, it->second))) return failure(); } }