[mlir] Dialect Conversion: Add support for post-order legalization order (#166292)

By default, the dialect conversion driver processes operations in
pre-order: the initial worklist is populated pre-order. (New/modified
operations are immediately legalized recursively.)

This commit adds a new API for selective post-order legalization.
Patterns can request an operation / region legalization via
`ConversionPatternRewriter::legalize`. They can call these helper
functions on nested regions before rewriting the operation itself.

Note: In rollback mode, a failed recursive legalization typically leads
to a conversion failure. Since recursive legalization is performed by
separate pattern applications, there is no way for the original pattern
to recover from such a failure.
This commit is contained in:
Matthias Springer 2025-11-05 21:04:32 +09:00 committed by GitHub
parent c1dc064ba0
commit a38e094240
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 199 additions and 37 deletions

View File

@ -981,6 +981,28 @@ public:
/// Return a reference to the internal implementation.
detail::ConversionPatternRewriterImpl &getImpl();
/// Attempt to legalize the given operation. This can be used within
/// conversion patterns to change the default pre-order legalization order.
/// Returns "success" if the operation was legalized, "failure" otherwise.
///
/// Note: In a partial conversion, this function returns "success" even if
/// the operation could not be legalized, as long as it was not explicitly
/// marked as illegal in the conversion target.
LogicalResult legalize(Operation *op);
/// Attempt to legalize the given region. This can be used within
/// conversion patterns to change the default pre-order legalization order.
/// Returns "success" if the region was legalized, "failure" otherwise.
///
/// If the current pattern runs with a type converter, the entry block
/// signature will be converted before legalizing the operations in the
/// region.
///
/// Note: In a partial conversion, this function returns "success" even if
/// an operation could not be legalized, as long as it was not explicitly
/// marked as illegal in the conversion target.
LogicalResult legalize(Region *r);
private:
// Allow OperationConverter to construct new rewriters.
friend struct OperationConverter;
@ -989,7 +1011,8 @@ private:
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
explicit ConversionPatternRewriter(MLIRContext *ctx,
const ConversionConfig &config);
const ConversionConfig &config,
OperationConverter &converter);
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;

View File

@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
return pt;
}
namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
/// illegal operations to co-exist in the IR.
Partial,
/// In this mode, all operations must be legal for the given target for the
/// conversion to succeed.
Full,
/// In this mode, operations are analyzed for legality. No actual rewrites are
/// applied to the operations on success.
Analysis,
};
} // namespace
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
@ -866,8 +882,9 @@ namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
: rewriter(rewriter), config(config),
const ConversionConfig &config,
OperationConverter &opConverter)
: rewriter(rewriter), config(config), opConverter(opConverter),
notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
/// The operation converter to use for recursive legalization.
OperationConverter &opConverter;
/// A set of erased operations. This set is utilized only if
/// `allowPatternRollback` is set to "false". Conceptually, this set is
/// similar to `replacedOps` (which is maintained when the flag is set to
@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
//===----------------------------------------------------------------------===//
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
MLIRContext *ctx, const ConversionConfig &config,
OperationConverter &opConverter)
: PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
*this, config, opConverter)) {
setListener(impl.get());
}
@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
return success();
}
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
// Fast path: If the region is empty, there is nothing to legalize.
if (r->empty())
return success();
// Gather a list of all operations to legalize. This is done before
// converting the entry block signature because unrealized_conversion_cast
// ops should not be included.
SmallVector<Operation *> ops;
for (Block &b : *r)
for (Operation &op : b)
ops.push_back(&op);
// If the current pattern runs with a type converter, convert the entry block
// signature.
if (const TypeConverter *converter = impl->currentTypeConverter) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(&r->front());
if (!conversion)
return failure();
applySignatureConversion(&r->front(), *conversion, converter);
}
// Legalize all operations in the region.
for (Operation *op : ops)
if (failed(legalize(op)))
return failure();
return success();
}
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
Block::iterator before,
ValueRange argValues) {
@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
// OperationConverter
//===----------------------------------------------------------------------===//
namespace {
enum OpConversionMode {
/// In this mode, the conversion will ignore failed conversions to allow
/// illegal operations to co-exist in the IR.
Partial,
/// In this mode, all operations must be legal for the given target for the
/// conversion to succeed.
Full,
/// In this mode, operations are analyzed for legality. No actual rewrites are
/// applied to the operations on success.
Analysis,
};
} // namespace
namespace mlir {
// This class converts operations to a given conversion target via a set of
// rewrite patterns. The conversion behaves differently depending on the
@ -3217,16 +3253,20 @@ struct OperationConverter {
const FrozenRewritePatternSet &patterns,
const ConversionConfig &config,
OpConversionMode mode)
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
mode(mode) {}
/// Converts the given operations to the conversion target.
LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
LogicalResult convert(Operation *op);
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
/// conversion is a recursive legalization request, triggered from within a
/// pattern. In that case, do not emit errors because there will be another
/// attempt at legalizing the operation later (via the regular pre-order
/// legalization mechanism).
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
private:
/// The rewriter to use when converting operations.
ConversionPatternRewriter rewriter;
@ -3238,32 +3278,42 @@ private:
};
} // namespace mlir
LogicalResult OperationConverter::convert(Operation *op) {
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
}
LogicalResult OperationConverter::convert(Operation *op,
bool isRecursiveLegalization) {
const ConversionConfig &config = rewriter.getConfig();
// Legalize the given operation.
if (failed(opLegalizer.legalize(op))) {
// Handle the case of a failed conversion for each of the different modes.
// Full conversions expect all operations to be converted.
if (mode == OpConversionMode::Full)
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
if (mode == OpConversionMode::Full) {
if (!isRecursiveLegalization)
op->emitError() << "failed to legalize operation '" << op->getName()
<< "'";
return failure();
}
// Partial conversions allow conversions to fail iff the operation was not
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
// set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
if (config.unlegalizedOps)
if (opLegalizer.isIllegal(op)) {
if (!isRecursiveLegalization)
op->emitError() << "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
return failure();
}
if (config.unlegalizedOps && !isRecursiveLegalization)
config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
if (config.legalizableOps)
if (config.legalizableOps && !isRecursiveLegalization)
config.legalizableOps->insert(op);
}
return success();

View File

@ -72,3 +72,21 @@ builtin.module {
}
}
// -----
// The region of "test.post_order_legalization" is converted before the op.
// expected-remark@+1 {{applyFullConversion failed}}
builtin.module {
func.func @test_preorder_legalization() {
// expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}}
"test.post_order_legalization"() ({
^bb0(%arg0: i64):
// Not-explicitly-legal ops are not allowed to survive.
"test.remaining_consumer"(%arg0) : (i64) -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
return
}
}

View File

@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
"test.return"(%0) : (i32) -> ()
}
}
// -----
// CHECK-LABEL: func @test_failed_preorder_legalization
// CHECK: "test.post_order_legalization"() ({
// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
// CHECK: "test.return"(%[[r]]) : (i32) -> ()
// CHECK: }) : () -> ()
// expected-remark @+1 {{applyPartialConversion failed}}
module {
func.func @test_failed_preorder_legalization() {
// expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
"test.post_order_legalization"() ({
%0 = "test.illegal_op_g"() : () -> (i32)
"test.return"(%0) : (i32) -> ()
}) : () -> ()
return
}
}

View File

@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
"test.type_consumer"(%arg0) : (f16) -> ()
"test.return"() : () -> ()
}
// -----
// The region of "test.post_order_legalization" is converted before the op.
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
// CHECK: notifyOperationInserted: test.invalid
// CHECK: notifyBlockErased
// CHECK: notifyOperationInserted: test.valid, was unlinked
// CHECK: notifyOperationReplaced: test.invalid
// CHECK: notifyOperationErased: test.invalid
// CHECK: notifyOperationModified: test.post_order_legalization
// CHECK-LABEL: func @test_preorder_legalization
// CHECK: "test.post_order_legalization"() ({
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
// Note: The survival of a not-explicitly-invalid operation does *not* cause
// a conversion failure in when applying a partial conversion.
// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
// CHECK: }) {is_legal} : () -> ()
func.func @test_preorder_legalization() {
"test.post_order_legalization"() ({
^bb0(%arg0: i64):
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
"test.remaining_consumer"(%arg0) : (i64) -> ()
"test.invalid"(%arg0) : (i64) -> ()
}) : () -> ()
// expected-remark @+1 {{'func.return' is not legalizable}}
return
}

View File

@ -1418,6 +1418,22 @@ public:
}
};
class TestPostOrderLegalization : public ConversionPattern {
public:
TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
: ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
ConversionPatternRewriter &rewriter) const final {
for (Region &r : op->getRegions())
if (failed(rewriter.legalize(&r)))
return failure();
rewriter.modifyOpInPlace(
op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
return success();
}
};
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
/// function is just to trigger compiler errors. It is never executed.
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
TestValueReplace, TestReplaceWithValidConsumer,
TestTypeConsumerOpPattern>(&getContext(), converter);
TestTypeConsumerOpPattern, TestPostOrderLegalization>(
&getContext(), converter);
patterns.add<TestConvertBlockArgs>(converter, &getContext());
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
converter);
@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
target.addDynamicallyLegalOp(
OperationName("test.value_replace", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
target.addDynamicallyLegalOp(
OperationName("test.post_order_legalization", &getContext()),
[](Operation *op) { return op->hasAttr("is_legal"); });
// TestCreateUnregisteredOp creates `arith.constant` operation,
// which was not added to target intentionally to test