[mlir] Dialect Conversion: Add support for post-order legalization order (#166292)
By default, the dialect conversion driver processes operations in pre-order: the initial worklist is populated pre-order. (New/modified operations are immediately legalized recursively.) This commit adds a new API for selective post-order legalization. Patterns can request an operation / region legalization via `ConversionPatternRewriter::legalize`. They can call these helper functions on nested regions before rewriting the operation itself. Note: In rollback mode, a failed recursive legalization typically leads to a conversion failure. Since recursive legalization is performed by separate pattern applications, there is no way for the original pattern to recover from such a failure.
This commit is contained in:
parent
c1dc064ba0
commit
a38e094240
@ -981,6 +981,28 @@ public:
|
||||
/// Return a reference to the internal implementation.
|
||||
detail::ConversionPatternRewriterImpl &getImpl();
|
||||
|
||||
/// Attempt to legalize the given operation. This can be used within
|
||||
/// conversion patterns to change the default pre-order legalization order.
|
||||
/// Returns "success" if the operation was legalized, "failure" otherwise.
|
||||
///
|
||||
/// Note: In a partial conversion, this function returns "success" even if
|
||||
/// the operation could not be legalized, as long as it was not explicitly
|
||||
/// marked as illegal in the conversion target.
|
||||
LogicalResult legalize(Operation *op);
|
||||
|
||||
/// Attempt to legalize the given region. This can be used within
|
||||
/// conversion patterns to change the default pre-order legalization order.
|
||||
/// Returns "success" if the region was legalized, "failure" otherwise.
|
||||
///
|
||||
/// If the current pattern runs with a type converter, the entry block
|
||||
/// signature will be converted before legalizing the operations in the
|
||||
/// region.
|
||||
///
|
||||
/// Note: In a partial conversion, this function returns "success" even if
|
||||
/// an operation could not be legalized, as long as it was not explicitly
|
||||
/// marked as illegal in the conversion target.
|
||||
LogicalResult legalize(Region *r);
|
||||
|
||||
private:
|
||||
// Allow OperationConverter to construct new rewriters.
|
||||
friend struct OperationConverter;
|
||||
@ -989,7 +1011,8 @@ private:
|
||||
/// conversions. They apply some IR rewrites in a delayed fashion and could
|
||||
/// bring the IR into an inconsistent state when used standalone.
|
||||
explicit ConversionPatternRewriter(MLIRContext *ctx,
|
||||
const ConversionConfig &config);
|
||||
const ConversionConfig &config,
|
||||
OperationConverter &converter);
|
||||
|
||||
// Hide unsupported pattern rewriter API.
|
||||
using OpBuilder::setListener;
|
||||
|
||||
@ -92,6 +92,22 @@ static OpBuilder::InsertPoint computeInsertPoint(ArrayRef<Value> vals) {
|
||||
return pt;
|
||||
}
|
||||
|
||||
namespace {
|
||||
enum OpConversionMode {
|
||||
/// In this mode, the conversion will ignore failed conversions to allow
|
||||
/// illegal operations to co-exist in the IR.
|
||||
Partial,
|
||||
|
||||
/// In this mode, all operations must be legal for the given target for the
|
||||
/// conversion to succeed.
|
||||
Full,
|
||||
|
||||
/// In this mode, operations are analyzed for legality. No actual rewrites are
|
||||
/// applied to the operations on success.
|
||||
Analysis,
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConversionValueMapping
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -866,8 +882,9 @@ namespace mlir {
|
||||
namespace detail {
|
||||
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
|
||||
const ConversionConfig &config)
|
||||
: rewriter(rewriter), config(config),
|
||||
const ConversionConfig &config,
|
||||
OperationConverter &opConverter)
|
||||
: rewriter(rewriter), config(config), opConverter(opConverter),
|
||||
notifyingRewriter(rewriter.getContext(), config.listener) {}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
@ -1124,6 +1141,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
|
||||
/// Dialect conversion configuration.
|
||||
const ConversionConfig &config;
|
||||
|
||||
/// The operation converter to use for recursive legalization.
|
||||
OperationConverter &opConverter;
|
||||
|
||||
/// A set of erased operations. This set is utilized only if
|
||||
/// `allowPatternRollback` is set to "false". Conceptually, this set is
|
||||
/// similar to `replacedOps` (which is maintained when the flag is set to
|
||||
@ -2084,9 +2104,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ConversionPatternRewriter::ConversionPatternRewriter(
|
||||
MLIRContext *ctx, const ConversionConfig &config)
|
||||
: PatternRewriter(ctx),
|
||||
impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
|
||||
MLIRContext *ctx, const ConversionConfig &config,
|
||||
OperationConverter &opConverter)
|
||||
: PatternRewriter(ctx), impl(new detail::ConversionPatternRewriterImpl(
|
||||
*this, config, opConverter)) {
|
||||
setListener(impl.get());
|
||||
}
|
||||
|
||||
@ -2207,6 +2228,37 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ConversionPatternRewriter::legalize(Region *r) {
|
||||
// Fast path: If the region is empty, there is nothing to legalize.
|
||||
if (r->empty())
|
||||
return success();
|
||||
|
||||
// Gather a list of all operations to legalize. This is done before
|
||||
// converting the entry block signature because unrealized_conversion_cast
|
||||
// ops should not be included.
|
||||
SmallVector<Operation *> ops;
|
||||
for (Block &b : *r)
|
||||
for (Operation &op : b)
|
||||
ops.push_back(&op);
|
||||
|
||||
// If the current pattern runs with a type converter, convert the entry block
|
||||
// signature.
|
||||
if (const TypeConverter *converter = impl->currentTypeConverter) {
|
||||
std::optional<TypeConverter::SignatureConversion> conversion =
|
||||
converter->convertBlockSignature(&r->front());
|
||||
if (!conversion)
|
||||
return failure();
|
||||
applySignatureConversion(&r->front(), *conversion, converter);
|
||||
}
|
||||
|
||||
// Legalize all operations in the region.
|
||||
for (Operation *op : ops)
|
||||
if (failed(legalize(op)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
|
||||
Block::iterator before,
|
||||
ValueRange argValues) {
|
||||
@ -3192,22 +3244,6 @@ static void reconcileUnrealizedCasts(
|
||||
// OperationConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
enum OpConversionMode {
|
||||
/// In this mode, the conversion will ignore failed conversions to allow
|
||||
/// illegal operations to co-exist in the IR.
|
||||
Partial,
|
||||
|
||||
/// In this mode, all operations must be legal for the given target for the
|
||||
/// conversion to succeed.
|
||||
Full,
|
||||
|
||||
/// In this mode, operations are analyzed for legality. No actual rewrites are
|
||||
/// applied to the operations on success.
|
||||
Analysis,
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
// This class converts operations to a given conversion target via a set of
|
||||
// rewrite patterns. The conversion behaves differently depending on the
|
||||
@ -3217,16 +3253,20 @@ struct OperationConverter {
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
const ConversionConfig &config,
|
||||
OpConversionMode mode)
|
||||
: rewriter(ctx, config), opLegalizer(rewriter, target, patterns),
|
||||
: rewriter(ctx, config, *this), opLegalizer(rewriter, target, patterns),
|
||||
mode(mode) {}
|
||||
|
||||
/// Converts the given operations to the conversion target.
|
||||
LogicalResult convertOperations(ArrayRef<Operation *> ops);
|
||||
|
||||
private:
|
||||
/// Converts an operation with the given rewriter.
|
||||
LogicalResult convert(Operation *op);
|
||||
/// Converts a single operation. If `isRecursiveLegalization` is "true", the
|
||||
/// conversion is a recursive legalization request, triggered from within a
|
||||
/// pattern. In that case, do not emit errors because there will be another
|
||||
/// attempt at legalizing the operation later (via the regular pre-order
|
||||
/// legalization mechanism).
|
||||
LogicalResult convert(Operation *op, bool isRecursiveLegalization = false);
|
||||
|
||||
private:
|
||||
/// The rewriter to use when converting operations.
|
||||
ConversionPatternRewriter rewriter;
|
||||
|
||||
@ -3238,32 +3278,42 @@ private:
|
||||
};
|
||||
} // namespace mlir
|
||||
|
||||
LogicalResult OperationConverter::convert(Operation *op) {
|
||||
LogicalResult ConversionPatternRewriter::legalize(Operation *op) {
|
||||
return impl->opConverter.convert(op, /*isRecursiveLegalization=*/true);
|
||||
}
|
||||
|
||||
LogicalResult OperationConverter::convert(Operation *op,
|
||||
bool isRecursiveLegalization) {
|
||||
const ConversionConfig &config = rewriter.getConfig();
|
||||
|
||||
// Legalize the given operation.
|
||||
if (failed(opLegalizer.legalize(op))) {
|
||||
// Handle the case of a failed conversion for each of the different modes.
|
||||
// Full conversions expect all operations to be converted.
|
||||
if (mode == OpConversionMode::Full)
|
||||
return op->emitError()
|
||||
<< "failed to legalize operation '" << op->getName() << "'";
|
||||
if (mode == OpConversionMode::Full) {
|
||||
if (!isRecursiveLegalization)
|
||||
op->emitError() << "failed to legalize operation '" << op->getName()
|
||||
<< "'";
|
||||
return failure();
|
||||
}
|
||||
// Partial conversions allow conversions to fail iff the operation was not
|
||||
// explicitly marked as illegal. If the user provided a `unlegalizedOps`
|
||||
// set, non-legalizable ops are added to that set.
|
||||
if (mode == OpConversionMode::Partial) {
|
||||
if (opLegalizer.isIllegal(op))
|
||||
return op->emitError()
|
||||
<< "failed to legalize operation '" << op->getName()
|
||||
<< "' that was explicitly marked illegal";
|
||||
if (config.unlegalizedOps)
|
||||
if (opLegalizer.isIllegal(op)) {
|
||||
if (!isRecursiveLegalization)
|
||||
op->emitError() << "failed to legalize operation '" << op->getName()
|
||||
<< "' that was explicitly marked illegal";
|
||||
return failure();
|
||||
}
|
||||
if (config.unlegalizedOps && !isRecursiveLegalization)
|
||||
config.unlegalizedOps->insert(op);
|
||||
}
|
||||
} else if (mode == OpConversionMode::Analysis) {
|
||||
// Analysis conversions don't fail if any operations fail to legalize,
|
||||
// they are only interested in the operations that were successfully
|
||||
// legalized.
|
||||
if (config.legalizableOps)
|
||||
if (config.legalizableOps && !isRecursiveLegalization)
|
||||
config.legalizableOps->insert(op);
|
||||
}
|
||||
return success();
|
||||
|
||||
@ -72,3 +72,21 @@ builtin.module {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The region of "test.post_order_legalization" is converted before the op.
|
||||
|
||||
// expected-remark@+1 {{applyFullConversion failed}}
|
||||
builtin.module {
|
||||
func.func @test_preorder_legalization() {
|
||||
// expected-error@+1 {{failed to legalize operation 'test.post_order_legalization'}}
|
||||
"test.post_order_legalization"() ({
|
||||
^bb0(%arg0: i64):
|
||||
// Not-explicitly-legal ops are not allowed to survive.
|
||||
"test.remaining_consumer"(%arg0) : (i64) -> ()
|
||||
"test.invalid"(%arg0) : (i64) -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -163,3 +163,22 @@ func.func @create_unregistered_op_in_pattern() -> i32 {
|
||||
"test.return"(%0) : (i32) -> ()
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @test_failed_preorder_legalization
|
||||
// CHECK: "test.post_order_legalization"() ({
|
||||
// CHECK: %[[r:.*]] = "test.illegal_op_g"() : () -> i32
|
||||
// CHECK: "test.return"(%[[r]]) : (i32) -> ()
|
||||
// CHECK: }) : () -> ()
|
||||
// expected-remark @+1 {{applyPartialConversion failed}}
|
||||
module {
|
||||
func.func @test_failed_preorder_legalization() {
|
||||
// expected-error @+1 {{failed to legalize operation 'test.post_order_legalization' that was explicitly marked illegal}}
|
||||
"test.post_order_legalization"() ({
|
||||
%0 = "test.illegal_op_g"() : () -> (i32)
|
||||
"test.return"(%0) : (i32) -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -448,3 +448,35 @@ func.func @test_working_1to1_pattern(%arg0: f16) {
|
||||
"test.type_consumer"(%arg0) : (f16) -> ()
|
||||
"test.return"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// The region of "test.post_order_legalization" is converted before the op.
|
||||
|
||||
// CHECK: notifyBlockInserted into test.post_order_legalization: was unlinked
|
||||
// CHECK: notifyOperationInserted: test.invalid
|
||||
// CHECK: notifyBlockErased
|
||||
// CHECK: notifyOperationInserted: test.valid, was unlinked
|
||||
// CHECK: notifyOperationReplaced: test.invalid
|
||||
// CHECK: notifyOperationErased: test.invalid
|
||||
// CHECK: notifyOperationModified: test.post_order_legalization
|
||||
|
||||
// CHECK-LABEL: func @test_preorder_legalization
|
||||
// CHECK: "test.post_order_legalization"() ({
|
||||
// CHECK: ^{{.*}}(%[[arg0:.*]]: f64):
|
||||
// Note: The survival of a not-explicitly-invalid operation does *not* cause
|
||||
// a conversion failure in when applying a partial conversion.
|
||||
// CHECK: %[[cast:.*]] = "test.cast"(%[[arg0]]) : (f64) -> i64
|
||||
// CHECK: "test.remaining_consumer"(%[[cast]]) : (i64) -> ()
|
||||
// CHECK: "test.valid"(%[[arg0]]) : (f64) -> ()
|
||||
// CHECK: }) {is_legal} : () -> ()
|
||||
func.func @test_preorder_legalization() {
|
||||
"test.post_order_legalization"() ({
|
||||
^bb0(%arg0: i64):
|
||||
// expected-remark @+1 {{'test.remaining_consumer' is not legalizable}}
|
||||
"test.remaining_consumer"(%arg0) : (i64) -> ()
|
||||
"test.invalid"(%arg0) : (i64) -> ()
|
||||
}) : () -> ()
|
||||
// expected-remark @+1 {{'func.return' is not legalizable}}
|
||||
return
|
||||
}
|
||||
|
||||
@ -1418,6 +1418,22 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class TestPostOrderLegalization : public ConversionPattern {
|
||||
public:
|
||||
TestPostOrderLegalization(MLIRContext *ctx, const TypeConverter &converter)
|
||||
: ConversionPattern(converter, "test.post_order_legalization", 1, ctx) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
for (Region &r : op->getRegions())
|
||||
if (failed(rewriter.legalize(&r)))
|
||||
return failure();
|
||||
rewriter.modifyOpInPlace(
|
||||
op, [&]() { op->setAttr("is_legal", rewriter.getUnitAttr()); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
|
||||
/// function is just to trigger compiler errors. It is never executed.
|
||||
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
|
||||
@ -1532,7 +1548,8 @@ struct TestLegalizePatternDriver
|
||||
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
|
||||
TestPassthroughInvalidOp, TestMultiple1ToNReplacement,
|
||||
TestValueReplace, TestReplaceWithValidConsumer,
|
||||
TestTypeConsumerOpPattern>(&getContext(), converter);
|
||||
TestTypeConsumerOpPattern, TestPostOrderLegalization>(
|
||||
&getContext(), converter);
|
||||
patterns.add<TestConvertBlockArgs>(converter, &getContext());
|
||||
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
|
||||
converter);
|
||||
@ -1560,6 +1577,9 @@ struct TestLegalizePatternDriver
|
||||
target.addDynamicallyLegalOp(
|
||||
OperationName("test.value_replace", &getContext()),
|
||||
[](Operation *op) { return op->hasAttr("is_legal"); });
|
||||
target.addDynamicallyLegalOp(
|
||||
OperationName("test.post_order_legalization", &getContext()),
|
||||
[](Operation *op) { return op->hasAttr("is_legal"); });
|
||||
|
||||
// TestCreateUnregisteredOp creates `arith.constant` operation,
|
||||
// which was not added to target intentionally to test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user