diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 4e651a048989..ea828480be4c 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -1161,6 +1161,16 @@ public: // ConversionConfig //===----------------------------------------------------------------------===// +/// An enum to control folding behavior during dialect conversion. +enum class DialectConversionFoldingMode { + /// Never attempt to fold. + Never, + /// Only attempt to fold not legal operations before applying patterns. + BeforePatterns, + /// Only attempt to fold not legal operations after applying patterns. + AfterPatterns, +}; + /// Dialect conversion configuration. struct ConversionConfig { /// An optional callback used to notify about match failure diagnostics during @@ -1243,6 +1253,10 @@ struct ConversionConfig { /// your patterns do not trigger any IR rollbacks. For details, see /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083. bool allowPatternRollback = true; + + /// The folding mode to use during conversion. + DialectConversionFoldingMode foldingMode = + DialectConversionFoldingMode::BeforePatterns; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0c26b4ed46b3..2470f2b122de 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2257,15 +2257,17 @@ OperationLegalizer::legalize(Operation *op, return success(); } - // If the operation isn't legal, try to fold it in-place. - // TODO: Should we always try to do this, even if the op is - // already legal? - if (succeeded(legalizeWithFold(op, rewriter))) { - LLVM_DEBUG({ - logSuccess(logger, "operation was folded"); - logger.startLine() << logLineComment; - }); - return success(); + // If the operation is not legal, try to fold it in-place if the folding mode + // is 'BeforePatterns'. 'Never' will skip this. + const ConversionConfig &config = rewriter.getConfig(); + if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) { + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } } // Otherwise, we need to apply a legalization pattern to this operation. @@ -2277,6 +2279,18 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation can't be legalized via patterns, try to fold it in-place + // if the folding mode is 'AfterPatterns'. + if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) { + if (succeeded(legalizeWithFold(op, rewriter))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } + } + LLVM_DEBUG({ logFailure(logger, "no matched legalization pattern"); logger.startLine() << logLineComment; diff --git a/mlir/test/Transforms/test-legalizer-fold-after.mlir b/mlir/test/Transforms/test-legalizer-fold-after.mlir new file mode 100644 index 000000000000..7f80252dc960 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-fold-after.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=after-patterns" | FileCheck %s + +// CHECK-LABEL: @fold_legalization +func.func @fold_legalization() -> i32 { + // CHECK-NOT: op_in_place_self_fold + // CHECK: 97 + %1 = "test.op_in_place_self_fold"() : () -> (i32) + "test.return"(%1) : (i32) -> () +} diff --git a/mlir/test/Transforms/test-legalizer-fold-before.mlir b/mlir/test/Transforms/test-legalizer-fold-before.mlir new file mode 100644 index 000000000000..fe6e29351a5d --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-fold-before.mlir @@ -0,0 +1,9 @@ +// RUN: mlir-opt %s -test-legalize-patterns="test-legalize-folding-mode=before-patterns" | FileCheck %s + +// CHECK-LABEL: @fold_legalization +func.func @fold_legalization() -> i32 { + // CHECK: op_in_place_self_fold + // CHECK-SAME: folded + %1 = "test.op_in_place_self_fold"() : () -> (i32) + "test.return"(%1) : (i32) -> () +} diff --git a/mlir/test/Transforms/test-legalizer-no-fold.mlir b/mlir/test/Transforms/test-legalizer-no-fold.mlir new file mode 100644 index 000000000000..720d17f41943 --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-fold.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -test-legalize-patterns="test-legalize-folding-mode=never" | FileCheck %s + +// CHECK-LABEL: @remove_foldable_op( +func.func @remove_foldable_op(%arg0 : i32) -> (i32) { + // Check that op was not folded. + // CHECK: "test.op_with_region_fold" + %0 = "test.op_with_region_fold"(%arg0) ({ + "foo.op_with_region_terminator"() : () -> () + }) : (i32) -> (i32) + "test.return"(%0) : (i32) -> () +} + diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 3cdd2f226687..231400ec9cd2 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1498,6 +1498,8 @@ def TestOpInPlaceSelfFold : TEST_Op<"op_in_place_self_fold"> { let results = (outs I32); let hasFolder = 1; } +def : Pat<(TestOpInPlaceSelfFold:$op $_), + (TestOpConstant ConstantAttr)>; // Test op that simply returns success. def TestOpInPlaceFoldSuccess : TEST_Op<"op_in_place_fold_success"> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index f92f0982f85b..ff958d9a3d2b 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1507,8 +1507,8 @@ struct TestLegalizePatternDriver ConversionTarget target(getContext()); target.addLegalOp(); target.addLegalOp(); + TerminatorOp, TestOpConstant, OneRegionOp, + TestValidProducerOp, TestValidConsumerOp>(); target.addLegalOp(OperationName("test.legal_op", &getContext())); target .addIllegalOp(); @@ -1563,6 +1563,7 @@ struct TestLegalizePatternDriver DumpNotifications dumpNotifications; config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; + config.foldingMode = foldingMode; if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), config))) { getOperation()->emitRemark() << "applyPartialConversion failed"; @@ -1582,6 +1583,7 @@ struct TestLegalizePatternDriver ConversionConfig config; DumpNotifications dumpNotifications; + config.foldingMode = foldingMode; config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, std::move(patterns), config))) { @@ -1596,6 +1598,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations. DenseSet legalizedOps; ConversionConfig config; + config.foldingMode = foldingMode; config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, std::move(patterns), config))) @@ -1616,6 +1619,21 @@ struct TestLegalizePatternDriver clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), clEnumValN(ConversionMode::Partial, "partial", "Perform a partial conversion"))}; + + Option foldingMode{ + *this, "test-legalize-folding-mode", + llvm::cl::desc("The folding mode to use with the test driver"), + llvm::cl::init(DialectConversionFoldingMode::BeforePatterns), + llvm::cl::values(clEnumValN(DialectConversionFoldingMode::Never, "never", + "Never attempt to fold"), + clEnumValN(DialectConversionFoldingMode::BeforePatterns, + "before-patterns", + "Only attempt to fold not legal operations " + "before applying patterns"), + clEnumValN(DialectConversionFoldingMode::AfterPatterns, + "after-patterns", + "Only attempt to fold not legal operations " + "after applying patterns"))}; }; } // namespace