[mlir][Transforms][NFC] Dialect Conversion: Store materialization metadata separately (#148415)
Store metadata about unresolved materializations in a separate data structure. This is in preparation of the One-Shot Dialect Conversion refactoring, which no longer maintains a stack of `IRRewrite` objects. Therefore, metadata about unresolved materializations can no longer be retrieved from `UnresolvedMaterializationRewrite` objects. This commit also removes a pointer indirection and may slightly improve the performance of the existing driver.
This commit is contained in:
parent
cbdc18542c
commit
8ee32c7b36
@ -789,26 +789,13 @@ enum MaterializationKind {
|
||||
Source
|
||||
};
|
||||
|
||||
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
|
||||
/// op. Unresolved materializations are erased at the end of the dialect
|
||||
/// conversion.
|
||||
class UnresolvedMaterializationRewrite : public OperationRewrite {
|
||||
/// Helper class that stores metadata about an unresolved materialization.
|
||||
class UnresolvedMaterializationInfo {
|
||||
public:
|
||||
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
|
||||
UnrealizedConversionCastOp op,
|
||||
const TypeConverter *converter,
|
||||
MaterializationKind kind, Type originalType,
|
||||
ValueVector mappedValues);
|
||||
|
||||
static bool classof(const IRRewrite *rewrite) {
|
||||
return rewrite->getKind() == Kind::UnresolvedMaterialization;
|
||||
}
|
||||
|
||||
void rollback() override;
|
||||
|
||||
UnrealizedConversionCastOp getOperation() const {
|
||||
return cast<UnrealizedConversionCastOp>(op);
|
||||
}
|
||||
UnresolvedMaterializationInfo() = default;
|
||||
UnresolvedMaterializationInfo(const TypeConverter *converter,
|
||||
MaterializationKind kind, Type originalType)
|
||||
: converterAndKind(converter, kind), originalType(originalType) {}
|
||||
|
||||
/// Return the type converter of this materialization (which may be null).
|
||||
const TypeConverter *getConverter() const {
|
||||
@ -832,7 +819,30 @@ private:
|
||||
/// The original type of the SSA value. Only used for target
|
||||
/// materializations.
|
||||
Type originalType;
|
||||
};
|
||||
|
||||
/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
|
||||
/// op. Unresolved materializations fold away or are replaced with
|
||||
/// source/target materializations at the end of the dialect conversion.
|
||||
class UnresolvedMaterializationRewrite : public OperationRewrite {
|
||||
public:
|
||||
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
|
||||
UnrealizedConversionCastOp op,
|
||||
ValueVector mappedValues)
|
||||
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
|
||||
mappedValues(std::move(mappedValues)) {}
|
||||
|
||||
static bool classof(const IRRewrite *rewrite) {
|
||||
return rewrite->getKind() == Kind::UnresolvedMaterialization;
|
||||
}
|
||||
|
||||
void rollback() override;
|
||||
|
||||
UnrealizedConversionCastOp getOperation() const {
|
||||
return cast<UnrealizedConversionCastOp>(op);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The values in the conversion value mapping that are being replaced by the
|
||||
/// results of this unresolved materialization.
|
||||
ValueVector mappedValues;
|
||||
@ -1088,9 +1098,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
/// by the current pattern.
|
||||
SetVector<Block *> patternInsertedBlocks;
|
||||
|
||||
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
|
||||
/// to the corresponding rewrite objects.
|
||||
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
|
||||
/// A mapping for looking up metadata of unresolved materializations.
|
||||
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
|
||||
unresolvedMaterializations;
|
||||
|
||||
/// The current type converter, or nullptr if no type converter is currently
|
||||
@ -1210,18 +1219,6 @@ void CreateOperationRewrite::rollback() {
|
||||
op->erase();
|
||||
}
|
||||
|
||||
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
|
||||
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
|
||||
const TypeConverter *converter, MaterializationKind kind, Type originalType,
|
||||
ValueVector mappedValues)
|
||||
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
|
||||
converterAndKind(converter, kind), originalType(originalType),
|
||||
mappedValues(std::move(mappedValues)) {
|
||||
assert((!originalType || kind == MaterializationKind::Target) &&
|
||||
"original type is valid only for target materializations");
|
||||
rewriterImpl.unresolvedMaterializations[op] = this;
|
||||
}
|
||||
|
||||
void UnresolvedMaterializationRewrite::rollback() {
|
||||
if (!mappedValues.empty())
|
||||
rewriterImpl.mapping.erase(mappedValues);
|
||||
@ -1510,8 +1507,10 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
|
||||
mapping.map(valuesToMap, convertOp.getResults());
|
||||
if (castOp)
|
||||
*castOp = convertOp;
|
||||
appendRewrite<UnresolvedMaterializationRewrite>(
|
||||
convertOp, converter, kind, originalType, std::move(valuesToMap));
|
||||
unresolvedMaterializations[convertOp] =
|
||||
UnresolvedMaterializationInfo(converter, kind, originalType);
|
||||
appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
|
||||
std::move(valuesToMap));
|
||||
return convertOp.getResults();
|
||||
}
|
||||
|
||||
@ -2678,21 +2677,21 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
|
||||
|
||||
static LogicalResult
|
||||
legalizeUnresolvedMaterialization(RewriterBase &rewriter,
|
||||
UnresolvedMaterializationRewrite *rewrite) {
|
||||
UnrealizedConversionCastOp op = rewrite->getOperation();
|
||||
UnrealizedConversionCastOp op,
|
||||
const UnresolvedMaterializationInfo &info) {
|
||||
assert(!op.use_empty() &&
|
||||
"expected that dead materializations have already been DCE'd");
|
||||
Operation::operand_range inputOperands = op.getOperands();
|
||||
|
||||
// Try to materialize the conversion.
|
||||
if (const TypeConverter *converter = rewrite->getConverter()) {
|
||||
if (const TypeConverter *converter = info.getConverter()) {
|
||||
rewriter.setInsertionPoint(op);
|
||||
SmallVector<Value> newMaterialization;
|
||||
switch (rewrite->getMaterializationKind()) {
|
||||
switch (info.getMaterializationKind()) {
|
||||
case MaterializationKind::Target:
|
||||
newMaterialization = converter->materializeTargetConversion(
|
||||
rewriter, op->getLoc(), op.getResultTypes(), inputOperands,
|
||||
rewrite->getOriginalType());
|
||||
info.getOriginalType());
|
||||
break;
|
||||
case MaterializationKind::Source:
|
||||
assert(op->getNumResults() == 1 && "expected single result");
|
||||
@ -2767,7 +2766,7 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
|
||||
|
||||
// Gather all unresolved materializations.
|
||||
SmallVector<UnrealizedConversionCastOp> allCastOps;
|
||||
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
|
||||
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
|
||||
&materializations = rewriterImpl.unresolvedMaterializations;
|
||||
for (auto it : materializations)
|
||||
allCastOps.push_back(it.first);
|
||||
@ -2784,7 +2783,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
|
||||
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
|
||||
auto it = materializations.find(castOp);
|
||||
assert(it != materializations.end() && "inconsistent state");
|
||||
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
|
||||
if (failed(
|
||||
legalizeUnresolvedMaterialization(rewriter, castOp, it->second)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user