From 09dfc5713d7e2342bea4c8447d1ed76c85eb8225 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 20 Dec 2024 08:15:48 -0800 Subject: [PATCH] [mlir] Enable decoupling two kinds of greedy behavior. (#104649) The greedy rewriter is used in many different flows and it has a lot of convenience (work list management, debugging actions, tracing, etc). But it combines two kinds of greedy behavior 1) how ops are matched, 2) folding wherever it can. These are independent forms of greedy and leads to inefficiency. E.g., cases where one need to create different phases in lowering and is required to applying patterns in specific order split across different passes. Using the driver one ends up needlessly retrying folding/having multiple rounds of folding attempts, where one final run would have sufficed. Of course folks can locally avoid this behavior by just building their own, but this is also a common requested feature that folks keep on working around locally in suboptimal ways. For downstream users, there should be no behavioral change. Updating from the deprecated should just be a find and replace (e.g., `find ./ -type f -exec sed -i 's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety) as the API arguments hasn't changed between the two. --- .../HLFIR/Transforms/InlineElementals.cpp | 2 +- .../HLFIR/Transforms/LowerHLFIRIntrinsics.cpp | 4 +- .../Transforms/OptimizedBufferization.cpp | 2 +- .../Transforms/SimplifyHLFIRIntrinsics.cpp | 2 +- .../Transforms/AlgebraicSimplification.cpp | 3 +- .../Transforms/AssumedRankOpConversion.cpp | 2 +- .../ConstantArgumentGlobalisation.cpp | 4 +- .../lib/Optimizer/Transforms/StackArrays.cpp | 4 +- mlir/docs/PatternRewriter.md | 2 +- .../lib/Standalone/StandalonePasses.cpp | 2 +- .../Transforms/GreedyPatternRewriteDriver.h | 72 ++++++++++++++----- mlir/lib/CAPI/Transforms/Rewrite.cpp | 3 +- .../ArithToAMDGPU/ArithToAMDGPU.cpp | 2 +- .../ArithToArmSME/ArithToArmSME.cpp | 3 +- .../ArmNeon2dToIntr/ArmNeon2dToIntr.cpp | 3 +- .../GPUToNVVM/LowerGpuOpsToNVVMOps.cpp | 2 +- .../GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 2 +- mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp | 3 +- .../ConvertShapeConstraints.cpp | 2 +- .../VectorToArmSME/VectorToArmSMEPass.cpp | 2 +- .../Conversion/VectorToGPU/VectorToGPU.cpp | 3 +- .../VectorToLLVM/ConvertVectorToLLVMPass.cpp | 2 +- .../Conversion/VectorToSCF/VectorToSCF.cpp | 6 +- .../VectorToXeGPU/VectorToXeGPU.cpp | 3 +- .../TransformOps/AffineTransformOps.cpp | 2 +- .../Transforms/AffineDataCopyGeneration.cpp | 2 +- .../Transforms/AffineExpandIndexOps.cpp | 3 +- .../AffineExpandIndexOpsAsAffine.cpp | 3 +- .../Transforms/SimplifyAffineStructures.cpp | 2 +- mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 4 +- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 6 +- .../Transforms/IntRangeOptimizations.cpp | 4 +- .../ArmSME/Transforms/OuterProductFusion.cpp | 3 +- .../Transforms/LegalizeVectorStorage.cpp | 3 +- .../Async/Transforms/AsyncParallelFor.cpp | 2 +- .../BufferDeallocationSimplification.cpp | 4 +- .../Transforms/EmptyTensorToAllocTensor.cpp | 2 +- .../EmitC/Transforms/FormExpressions.cpp | 2 +- .../GPU/Transforms/DecomposeMemRefs.cpp | 3 +- .../GPU/Transforms/EliminateBarriers.cpp | 2 +- .../LLVMIR/Transforms/OptimizeForNVVM.cpp | 2 +- .../TransformOps/LinalgTransformOps.cpp | 2 +- .../Linalg/Transforms/BlockPackMatmul.cpp | 2 +- .../Dialect/Linalg/Transforms/Detensorize.cpp | 3 +- .../Linalg/Transforms/DropUnitDims.cpp | 2 +- .../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 +- .../Linalg/Transforms/Generalization.cpp | 2 +- .../Transforms/InlineScalarOperands.cpp | 2 +- mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 2 +- .../Linalg/Transforms/NamedOpConversions.cpp | 2 +- .../Dialect/Linalg/Transforms/Specialize.cpp | 2 +- .../Dialect/Math/Transforms/UpliftToFMA.cpp | 3 +- .../Transforms/ExpandStridedMetadata.cpp | 2 +- .../MemRef/Transforms/FoldMemRefAliasOps.cpp | 2 +- .../ResolveShapedTypeResultDims.cpp | 4 +- .../lib/Dialect/SCF/Transforms/ForToWhile.cpp | 2 +- .../SCF/Transforms/LoopCanonicalization.cpp | 2 +- .../SCF/Transforms/LoopSpecialization.cpp | 2 +- .../SCF/Transforms/TileUsingInterface.cpp | 2 +- .../SPIRV/Transforms/CanonicalizeGLPass.cpp | 3 +- .../SPIRV/Transforms/SPIRVConversion.cpp | 8 +-- .../Transforms/SPIRVWebGPUTransforms.cpp | 3 +- .../Transforms/OutlineShapeComputation.cpp | 4 +- .../Transforms/RemoveShapeConstraints.cpp | 2 +- .../Transforms/SparseTensorPasses.cpp | 20 +++--- .../Tensor/Transforms/FoldTensorSubsetOps.cpp | 2 +- .../TosaLayerwiseConstantFoldPass.cpp | 2 +- .../Tosa/Transforms/TosaMakeBroadcastable.cpp | 2 +- .../Transforms/TosaOptionalDecompositions.cpp | 2 +- .../lib/Dialect/Transform/IR/TransformOps.cpp | 4 +- .../Vector/Transforms/LowerVectorMask.cpp | 2 +- .../Transforms/LowerVectorMultiReduction.cpp | 2 +- .../XeGPU/Transforms/XeGPUFoldAliasOps.cpp | 2 +- mlir/lib/Reducer/ReductionTreePass.cpp | 2 +- mlir/lib/Transforms/Canonicalizer.cpp | 2 +- .../Utils/GreedyPatternRewriteDriver.cpp | 16 ++--- .../Transforms/Utils/OneToNTypeConversion.cpp | 4 +- .../Transforms/test-operation-folder.mlir | 48 ++++++++++++- .../MathToVCIX/TestMathToVCIXConversion.cpp | 2 +- .../TestVectorReductionToSPIRVDotProd.cpp | 2 +- .../lib/Dialect/Affine/TestAffineDataCopy.cpp | 2 +- .../Dialect/ArmNeon/TestLowerToArmNeon.cpp | 2 +- mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp | 4 +- .../Linalg/TestDataLayoutPropagation.cpp | 3 +- .../Dialect/Linalg/TestLinalgDecomposeOps.cpp | 4 +- .../Linalg/TestLinalgElementwiseFusion.cpp | 28 ++++---- .../Linalg/TestLinalgFusionTransforms.cpp | 2 +- .../TestLinalgRankReduceContractionOps.cpp | 3 +- .../Dialect/Linalg/TestLinalgTransforms.cpp | 26 +++---- .../test/lib/Dialect/Linalg/TestPadFusion.cpp | 3 +- .../Math/TestAlgebraicSimplification.cpp | 2 +- mlir/test/lib/Dialect/Math/TestExpandMath.cpp | 2 +- .../Math/TestPolynomialApproximation.cpp | 2 +- .../lib/Dialect/MemRef/TestComposeSubView.cpp | 2 +- mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp | 8 +-- .../Mesh/TestReshardingSpmdization.cpp | 4 +- .../lib/Dialect/Mesh/TestSimplifications.cpp | 2 +- .../lib/Dialect/NVGPU/TestNVGPUTransforms.cpp | 2 +- mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp | 2 +- .../SCF/TestSCFWrapInZeroTripCheck.cpp | 2 +- .../lib/Dialect/SCF/TestUpliftWhileToFor.cpp | 2 +- .../Dialect/Tensor/TestTensorTransforms.cpp | 16 ++--- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 16 +++-- mlir/test/lib/Dialect/Test/TestTraits.cpp | 4 +- mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp | 2 +- .../Dialect/Vector/TestVectorTransforms.cpp | 42 +++++------ mlir/test/lib/Rewrite/TestPDLByteCode.cpp | 4 +- mlir/test/lib/Tools/PDLL/TestPDLL.cpp | 2 +- .../lib/Transforms/TestCommutativityUtils.cpp | 2 +- .../Transforms/TestMakeIsolatedFromAbove.cpp | 6 +- 110 files changed, 313 insertions(+), 246 deletions(-) diff --git a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp index 769e14b1316d..b68fe6ee0c74 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/InlineElementals.cpp @@ -125,7 +125,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.insert(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR elemental inlining"); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp index 36fae90c83fd..091ed7ed999d 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp @@ -520,8 +520,8 @@ public: config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - module, std::move(patterns), config))) { + if (mlir::failed( + mlir::applyPatternsGreedily(module, std::move(patterns), config))) { mlir::emitError(mlir::UnknownLoc::get(context), "failure in HLFIR intrinsic lowering"); signalPassFailure(); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp index c152c27c0a05..bf3cf861e46f 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp @@ -1372,7 +1372,7 @@ public: // patterns.insert>(context); // patterns.insert>(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR optimized bufferization"); diff --git a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp index 28325bc8e548..bf3d261e7e88 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp @@ -491,7 +491,7 @@ public: patterns.insert(context); patterns.insert(context); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( + if (mlir::failed(mlir::applyPatternsGreedily( getOperation(), std::move(patterns), config))) { mlir::emitError(getOperation()->getLoc(), "failure in HLFIR intrinsic simplification"); diff --git a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp index fd58375da618..fab1f0299ede 100644 --- a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp +++ b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp @@ -39,8 +39,7 @@ struct AlgebraicSimplification void AlgebraicSimplification::runOnOperation() { RewritePatternSet patterns(&getContext()); populateMathAlgebraicSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); } std::unique_ptr fir::createAlgebraicSimplificationPass() { diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp index 2c9c73e8a539..eb59045a5fde 100644 --- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp @@ -154,7 +154,7 @@ public: mlir::GreedyRewriteConfig config; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - (void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config); + (void)applyPatternsGreedily(mod, std::move(patterns), config); } }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp index eef6f047fc1b..562f3058f20f 100644 --- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp +++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp @@ -173,8 +173,8 @@ public: config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; patterns.insert(context, *di); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - mod, std::move(patterns), config))) { + if (mlir::failed( + mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { mlir::emitError(mod.getLoc(), "error in constant globalisation optimization\n"); signalPassFailure(); diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 0c474f463f09..f9281000d21f 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() { config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; patterns.insert(&context, *candidateOps); - if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, - std::move(patterns), config))) { + if (mlir::failed(mlir::applyOpPatternsGreedily( + opsToConvert, std::move(patterns), config))) { mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); signalPassFailure(); } diff --git a/mlir/docs/PatternRewriter.md b/mlir/docs/PatternRewriter.md index da392b828933..d15e7e5a8067 100644 --- a/mlir/docs/PatternRewriter.md +++ b/mlir/docs/PatternRewriter.md @@ -358,7 +358,7 @@ which point the driver finishes. This driver comes in two fashions: -* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to +* `applyPatternsGreedily` ("region-based driver") applies patterns to all ops in a given region or a given container op (but not the container op itself). I.e., the worklist is initialized with all containing ops. * `applyOpPatternsAndFold` ("op-based driver") applies patterns to the diff --git a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp index 8166aa238bf2..8c79a0753793 100644 --- a/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp +++ b/mlir/examples/standalone/lib/Standalone/StandalonePasses.cpp @@ -39,7 +39,7 @@ public: RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); FrozenRewritePatternSet patternSet(std::move(patterns)); - if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) + if (failed(applyPatternsGreedily(getOperation(), patternSet))) signalPassFailure(); } }; diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index eaff85804f6b..110b4f64856e 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -91,6 +91,13 @@ public: /// An optional listener that should be notified about IR modifications. RewriterBase::Listener *listener = nullptr; + + /// Whether this should fold while greedily rewriting. + bool fold = true; + + /// If set to "true", constants are CSE'd (even across multiple regions that + /// are in a parent-ancestor relationship). + bool cseConstants = true; }; //===----------------------------------------------------------------------===// @@ -104,8 +111,8 @@ public: /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// A region scope can be set in the configuration parameter. By default, the /// scope is set to the specified region. Only in-scope ops are added to the @@ -117,10 +124,20 @@ public: /// /// Note: This method does not apply patterns to the region's parent operation. LogicalResult +applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr); +/// Same as `applyPatternsAndGreedily` above with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily") +inline LogicalResult applyPatternsAndFoldGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr); + bool *changed = nullptr) { + config.fold = true; + return applyPatternsGreedily(region, patterns, config, changed); +} /// Rewrite ops nested under the given operation, which must be isolated from /// above, by repeatedly applying the highest benefit patterns in a greedy @@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region ®ion, /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// This overload runs a separate greedy rewrite for each region of the /// specified op. A region scope can be set in the configuration parameter. By @@ -147,23 +164,32 @@ applyPatternsAndFoldGreedily(Region ®ion, /// /// Note: This method does not apply patterns to the given operation itself. inline LogicalResult -applyPatternsAndFoldGreedily(Operation *op, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr) { +applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr) { bool anyRegionChanged = false; bool failed = false; for (Region ®ion : op->getRegions()) { bool regionChanged; - failed |= - applyPatternsAndFoldGreedily(region, patterns, config, ®ionChanged) - .failed(); + failed |= applyPatternsGreedily(region, patterns, config, ®ionChanged) + .failed(); anyRegionChanged |= regionChanged; } if (changed) *changed = anyRegionChanged; return failure(failed); } +/// Same as `applyPatternsGreedily` above with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily") +inline LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr) { + config.fold = true; + return applyPatternsGreedily(op, patterns, config, changed); +} /// Rewrite the specified ops by repeatedly applying the highest benefit /// patterns in a greedy worklist driven manner until a fixpoint is reached. @@ -171,8 +197,8 @@ applyPatternsAndFoldGreedily(Operation *op, /// The greedy rewrite may prematurely stop after a maximum number of /// iterations, which can be configured in the configuration parameter. /// -/// Also performs folding and simple dead-code elimination before attempting to -/// match any of the provided patterns. +/// Also performs simple dead-code elimination before attempting to match any of +/// the provided patterns. /// /// Newly created ops and other pre-existing ops that use results of rewritten /// ops or supply operands to such ops are also processed, unless such ops are @@ -180,24 +206,36 @@ applyPatternsAndFoldGreedily(Operation *op, /// regardless of `strictMode`). /// /// In addition to strictness, a region scope can be specified. Only ops within -/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`, +/// the scope are simplified. This is similar to `applyPatternsGreedily`, /// where only ops within the given region/op are simplified by default. If no /// scope is specified, it is assumed to be the first common enclosing region of /// the given ops. /// /// Note that ops in `ops` could be erased as result of folding, becoming dead, /// or via pattern rewrites. If more far reaching simplification is desired, -/// `applyPatternsAndFoldGreedily` should be used. +/// `applyPatternsGreedily` should be used. /// /// Returns "success" if the iterative process converged (i.e., fixpoint was /// reached) and no more patterns can be matched. `changed` is set to "true" if /// the IR was modified at all. `allOpsErased` is set to "true" if all ops in /// `ops` were erased. LogicalResult +applyOpPatternsGreedily(ArrayRef ops, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + bool *changed = nullptr, bool *allErased = nullptr); +/// Same as `applyOpPatternsGreedily` with folding. +/// FIXME: Remove this once transition to above is complieted. +LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead", + "applyOpPatternsGreedily") +inline LogicalResult applyOpPatternsAndFold(ArrayRef ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config = GreedyRewriteConfig(), - bool *changed = nullptr, bool *allErased = nullptr); + bool *changed = nullptr, bool *allErased = nullptr) { + config.fold = true; + return applyOpPatternsGreedily(ops, patterns, config, changed, allErased); +} } // namespace mlir diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 379f09cf5cc2..c4717ca61331 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -289,8 +289,7 @@ MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig) { - return wrap( - mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 6b9cbaf57676..a8283023afc5 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() { arith::populateArithToAMDGPUConversionPatterns( patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz, *maybeChipset); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp index 5aa2a098b176..cbe0b3fda341 100644 --- a/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp +++ b/mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp @@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final void runOnOperation() override { RewritePatternSet patterns(&getContext()); arith::populateArithToArmSMEConversionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index bdbf276d79b2..de8bfd6a1710 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr RewritePatternSet patterns(context); populateConvertArmNeon2dToIntrPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index b343cf71e3a2..e022d3ce6f63 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass { RewritePatternSet patterns(m.getContext()); populateGpuRewritePatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns)))) + if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index aa4d3b70329f..d52a86987b1c 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass RewritePatternSet patterns(ctx); populateGpuRewritePatterns(patterns); arith::populateExpandBFloat16Patterns(patterns); - (void)applyPatternsAndFoldGreedily(m, std::move(patterns)); + (void)applyPatternsGreedily(m, std::move(patterns)); } LLVMTypeConverter converter(ctx, options); diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp index 6dd89ecf4d5c..e1de125ccaed 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp @@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>( ctx); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)); + (void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp index 7df1407da6f9..d92027a5e3d4 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp @@ -62,7 +62,7 @@ class ConvertShapeConstraints RewritePatternSet patterns(context); populateConvertShapeConstraintsConversionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns)))) + if (failed(applyPatternsGreedily(func, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp index cc00bf4ca190..7419276651ae 100644 --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp @@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateVectorToArmSMEPatterns(patterns, getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr mlir::createConvertVectorToArmSMEPass() { diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp index 034f3e2d16e9..5b4414d67fda 100644 --- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp +++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp @@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue()); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); IRRewriter rewriter(&getContext()); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 2d94c2f2e85a..2c4c5ada9815 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorInsertExtractStridedSliceTransforms(patterns); populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } // Convert to the LLVM IR dialect. diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 3a4dc806efe9..01bc65c841e9 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass RewritePatternSet lowerTransferPatterns(&getContext()); mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( lowerTransferPatterns); - (void)applyPatternsAndFoldGreedily(getOperation(), - std::move(lowerTransferPatterns)); + (void)applyPatternsGreedily(getOperation(), + std::move(lowerTransferPatterns)); RewritePatternSet patterns(&getContext()); populateVectorToSCFConversionPatterns(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 1232d8795d4d..8041bdf7da19 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorToXeGPUConversionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index eb5229794072..9f7df7823d99 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. - if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { + if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) { auto diag = emitDefiniteFailure() << "affine.min/max simplification did not converge"; return diag; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp index 331b0f1b2c2b..9ffe54f61ebb 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp @@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() { FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(copyOps, frozenPatterns, config); + (void)applyOpPatternsGreedily(copyOps, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index d7b218225bc9..7e335ea929c4 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -198,8 +198,7 @@ public: MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateAffineExpandIndexOpsPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp index bfcc1ddf9165..16ba16d5c798 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp @@ -79,8 +79,7 @@ public: MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateAffineExpandIndexOpsAsAffinePatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp index 49618074ec22..31711ade3153 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -111,5 +111,5 @@ void SimplifyAffineStructures::runOnOperation() { }); GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config); + (void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index c5cc8bfeb0a6..0f2c889d4f39 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -318,8 +318,8 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp, GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; - (void)applyOpPatternsAndFold(res.getOperation(), std::move(patterns), - config, /*changed=*/nullptr, &erased); + (void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns), + config, /*changed=*/nullptr, &erased); if (!erased && !prologue) prologue = res; if (!erased) diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 07d399adae0c..4d3ead20fb5c 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -425,8 +425,8 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; bool erased; - (void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config, - /*changed=*/nullptr, &erased); + (void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config, + /*changed=*/nullptr, &erased); if (erased) { if (folded) *folded = true; @@ -454,7 +454,7 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) { // Canonicalize to remove dead else blocks (happens whenever an 'if' moves up // a sequence of affine.fors that are all perfectly nested). - (void)applyPatternsAndFoldGreedily( + (void)applyPatternsGreedily( hoistedIfOp->getParentWithTrait(), frozenPatterns); diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index b54a53f5ef70..5982f5f55549 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -489,7 +489,7 @@ struct IntRangeOptimizationsPass final GreedyRewriteConfig config; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) signalPassFailure(); } }; @@ -518,7 +518,7 @@ struct IntRangeNarrowingPass final config.useTopDownTraversal = false; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) + if (failed(applyPatternsGreedily(op, std::move(patterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index ee1e374b25b0..23f2c2bf65e4 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -523,8 +523,7 @@ struct OuterProductFusionPass RewritePatternSet patterns(&getContext()); populateOuterProductFusionPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index 8b4bacd72271..d2ac850a5f70 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -317,8 +317,7 @@ struct LegalizeVectorStorage void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateLegalizeVectorStoragePatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } ConversionTarget target(getContext()); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index 273101ce5f3e..1320523aa989 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -931,7 +931,7 @@ void AsyncParallelForPass::runOnOperation() { [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { return builder.create(minTaskSize); }); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 5227b22653ee..de3ae82f8708 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -470,8 +470,8 @@ struct BufferDeallocationSimplificationPass config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal; populateDeallocOpCanonicalizationPatterns(patterns, &getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp index 7670220dce77..d20c6966d4eb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorToAllocTensor.cpp @@ -60,7 +60,7 @@ void EmptyTensorToAllocTensor::runOnOperation() { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); populateEmptyTensorToAllocTensorPattern(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp index 82bd031430d3..338551437580 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/FormExpressions.cpp @@ -47,7 +47,7 @@ struct FormExpressionsPass RewritePatternSet patterns(context); populateExpressionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(rootOp, std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp index 004d73a77e53..a504101fb3f2 100644 --- a/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/DecomposeMemRefs.cpp @@ -227,8 +227,7 @@ struct GpuDecomposeMemrefsPass populateGpuDecomposeMemrefsPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 0ffd8131b893..2178555cb62f 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -630,7 +630,7 @@ class GpuEliminateBarriersPass auto funcOp = getOperation(); RewritePatternSet patterns(&getContext()); mlir::populateGpuEliminateBarriersPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp index 8c33148d1d2d..c1ec1df48e5b 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp @@ -96,7 +96,7 @@ void NVVMOptimizeForTarget::runOnOperation() { MLIRContext *ctx = getOperation()->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 18fd24da395b..221ca27b80fd 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3511,7 +3511,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne( TrackingListener listener(state, *this); GreedyRewriteConfig config; config.listener = &listener; - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config))) + if (failed(applyPatternsGreedily(target, std::move(patterns), config))) return emitDefaultDefiniteFailure(target); results.push_back(target); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 91d4efa3372b..57344f986480 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -301,7 +301,7 @@ struct LinalgBlockPackMatmul }; linalg::populateBlockPackMatmulPatterns(patterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index af3848529118..0e651f4cee4c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -563,8 +563,7 @@ struct LinalgDetensorize RewritePatternSet canonPatterns(context); tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(canonPatterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns)))) signalPassFailure(); // Get rid of the dummy entry block we created in the beginning to work diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bb5034759691..9b97865990bf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -831,7 +831,7 @@ struct LinalgFoldUnitExtentDimsPass } linalg::populateFoldUnitExtentDimsPatterns(patterns, options); populateMoveInitOperandsToInputPattern(patterns); - (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + (void)applyPatternsGreedily(op, std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index efc7934bc7d8..3a57f368d442 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2206,7 +2206,7 @@ struct LinalgElementwiseOpFusionPass // Use TopDownTraversal for compile time reasons GreedyRewriteConfig grc; grc.useTopDownTraversal = true; - (void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc); + (void)applyPatternsGreedily(op, std::move(patterns), grc); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 7ab3fef5dd03..78cee47c497e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -89,7 +89,7 @@ struct LinalgGeneralizeNamedOpsPass void LinalgGeneralizeNamedOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateLinalgNamedOpsGeneralizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns( diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp index 2a1445fb92fd..1f3336d2bfbb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp @@ -113,7 +113,7 @@ struct LinalgInlineScalarOperandsPass MLIRContext &ctx = getContext(); RewritePatternSet patterns(&ctx); populateInlineConstantOperandsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(op, std::move(patterns)); + (void)applyPatternsGreedily(op, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 20a99491b664..984f3f5a34ab 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) { affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); // Just apply the patterns greedily. - (void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns)); + (void)applyPatternsGreedily(enclosingOp, std::move(patterns)); } struct LowerToAffineLoops diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp index 84bde1bc0b84..bb1e97439187 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -152,7 +152,7 @@ struct LinalgNamedOpConversionPass Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); populateLinalgNamedOpConversionPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 748e2a137793..512fb7555a6b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -349,7 +349,7 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() { populateLinalgGenericOpsSpecializationPatterns(patterns); populateDecomposeProjectedPermutationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp index 6b0d0f5e7466..de950bac819c 100644 --- a/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp +++ b/mlir/lib/Dialect/Math/Transforms/UpliftToFMA.cpp @@ -66,8 +66,7 @@ struct MathUpliftToFMA final void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateUpliftToFMAPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 92592d2345d7..aa008f8407b5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -1223,7 +1223,7 @@ struct ExpandStridedMetadataPass final void ExpandStridedMetadataPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateExpandStridedMetadataPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr memref::createExpandStridedMetadataPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 96daf4c5972a..8e927a60087f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -857,7 +857,7 @@ struct FoldMemRefAliasOpsPass final void FoldMemRefAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateFoldMemRefAliasOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr memref::createFoldMemRefAliasOpsPass() { diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 792e72291830..dfcbaeb15ae5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -195,7 +195,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns( void ResolveRankedShapeTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } @@ -203,7 +203,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); memref::populateResolveShapedTypeResultDimsPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 9f8189ae15e6..3e93dc80b18e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -112,7 +112,7 @@ struct ForToWhileLoop : public impl::SCFForToWhileLoopBase { MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); + (void)applyPatternsGreedily(parentOp, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp index c6d024c462e8..4ebd90dbcc1d 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp @@ -167,7 +167,7 @@ struct SCFForLoopCanonicalization MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); scf::populateSCFForLoopCanonicalizationPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns)))) + if (failed(applyPatternsGreedily(parentOp, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 5104ad4b3a30..b71ec985fa6a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -331,7 +331,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase { MLIRContext *ctx = parentOp->getContext(); RewritePatternSet patterns(ctx); patterns.add(ctx, peelFront, skipPartial); - (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns)); + (void)applyPatternsGreedily(parentOp, std::move(patterns)); // Drop the markers. parentOp->walk([](Operation *op) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index ef5d4370e781..90db42d479a1 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1430,7 +1430,7 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { GreedyRewriteConfig config; config.listener = this; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - return applyOpPatternsAndFold(ops, patterns.value(), config); + return applyOpPatternsGreedily(ops, patterns.value(), config); } void SliceTrackingListener::notifyOperationInserted( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp index 374c205897c8..cc59c2116ed3 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/CanonicalizeGLPass.cpp @@ -29,8 +29,7 @@ public: void runOnOperation() override { RewritePatternSet patterns(&getContext()); spirv::populateSPIRVGLCanonicalizationPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 877ac87fb0fe..29f7e8afe077 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -1354,7 +1354,7 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) { // looking for newly created func ops. GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; - return applyPatternsAndFoldGreedily(op, std::move(patterns), config); + return applyPatternsGreedily(op, std::move(patterns), config); } LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { @@ -1366,7 +1366,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { auto options = vector::UnrollVectorOptions().setNativeShapeFn( [](auto op) { return mlir::spirv::getNativeVectorShape(op); }); populateVectorUnrollPatterns(patterns, options); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } @@ -1378,7 +1378,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { vector::VectorTransposeLowering::EltWise); vector::populateVectorTransposeLoweringPatterns(patterns, options); vector::populateVectorShapeCastLoweringPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } @@ -1403,7 +1403,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) { vector::BroadcastOp::getCanonicalizationPatterns(patterns, context); vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) return failure(); } return success(); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp index d75c8552c9ad..af1cf2a1373e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVWebGPUTransforms.cpp @@ -236,8 +236,7 @@ struct WebGPUPreparePass final populateSPIRVExpandExtendedMultiplicationPatterns(patterns); populateSPIRVExpandNonFiniteArithmeticPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp index 655555f88354..e56742d52e13 100644 --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -207,7 +207,7 @@ void OutlineShapeComputationPass::runOnOperation() { MLIRContext *context = funcOp.getContext(); RewritePatternSet prevPatterns(context); prevPatterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns)))) + if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns)))) return signalPassFailure(); // initialize class member `onlyUsedByWithShapes` @@ -254,7 +254,7 @@ void OutlineShapeComputationPass::runOnOperation() { } // Apply patterns, note this also performs DCE. - if (failed(applyPatternsAndFoldGreedily(funcOp, {}))) + if (failed(applyPatternsGreedily(funcOp, {}))) return signalPassFailure(); }); } diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp index e1cccd8fd5d6..d2b245f832e5 100644 --- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp @@ -55,7 +55,7 @@ class RemoveShapeConstraintsPass RewritePatternSet patterns(&ctx); populateRemoveShapeConstraintsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp index 8004bdb904b8..1cac949b68c7 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp @@ -57,7 +57,7 @@ struct SparseAssembler : public impl::SparseAssemblerBase { auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseAssembler(patterns, directOut); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -73,7 +73,7 @@ struct SparseReinterpretMap auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseReinterpretMap(patterns, scope); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -87,7 +87,7 @@ struct PreSparsificationRewritePass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populatePreSparsificationRewriting(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -110,7 +110,7 @@ struct SparsificationPass RewritePatternSet patterns(ctx); populateSparsificationPatterns(patterns, options); scf::ForOp::getCanonicalizationPatterns(patterns, ctx); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -122,7 +122,7 @@ struct StageSparseOperationsPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateStageSparseOperationsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -141,7 +141,7 @@ struct LowerSparseOpsToForeachPass RewritePatternSet patterns(ctx); populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary, enableConvert); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -154,7 +154,7 @@ struct LowerForeachToSCFPass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateLowerForeachToSCFPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -329,7 +329,7 @@ struct SparseBufferRewritePass auto *ctx = &getContext(); RewritePatternSet patterns(ctx); populateSparseBufferRewriting(patterns, enableBufferInitialization); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -351,7 +351,7 @@ struct SparseVectorizationPass populateSparseVectorizationPatterns( patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32); vector::populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -371,7 +371,7 @@ struct SparseGPUCodegenPass populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary); else populateSparseGPUCodegenPatterns(patterns, numThreads); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp index 0f5fa61879b7..998b0fb6eb4b 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp @@ -277,7 +277,7 @@ struct FoldTensorSubsetOpsPass final void FoldTensorSubsetOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); tensor::populateFoldTensorSubsetOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } std::unique_ptr tensor::createFoldTensorSubsetOpsPass() { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp index e1400f0c907b..9299db7e51a0 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp @@ -60,7 +60,7 @@ struct TosaLayerwiseConstantFoldPass aggressiveReduceConstant); populateTosaOpsCanonicalizationPatterns(ctx, patterns); - if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + if (applyPatternsGreedily(func, std::move(patterns)).failed()) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index 9c6ee4c62eee..2a990eed3f68 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -246,7 +246,7 @@ public: patterns.add>(ctx); patterns.add>(ctx); patterns.add>(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; } // namespace diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp index cef903a39e45..603185e48aa9 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaOptionalDecompositions.cpp @@ -42,7 +42,7 @@ struct TosaOptionalDecompositions mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns); mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns); - if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed()) + if (applyPatternsGreedily(func, std::move(patterns)).failed()) signalPassFailure(); } }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 1f0f183e29f9..106a79473509 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -417,7 +417,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( if (target->hasTrait()) { // Op is isolated from above. Apply patterns and also perform region // simplification. - result = applyPatternsAndFoldGreedily(target, frozenPatterns, config); + result = applyPatternsGreedily(target, frozenPatterns, config); } else { // Manually gather list of ops because the other // GreedyPatternRewriteDriver overloads only accepts ops that are isolated @@ -429,7 +429,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( if (target != nestedOp) ops.push_back(nestedOp); }); - result = applyOpPatternsAndFold(ops, frozenPatterns, config); + result = applyOpPatternsGreedily(ops, frozenPatterns, config); } // A failure typically indicates that the pattern application did not diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index bfc05c71f534..1f6cac2aa6f9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -286,7 +286,7 @@ struct LowerVectorMaskPass populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns); MaskOp::getCanonicalizationPatterns(loweringPatterns, context); - if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 72bf329daaa7..0cafc9cd3551 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -486,7 +486,7 @@ struct LowerVectorMultiReductionPass populateVectorMultiReductionLoweringPatterns(loweringPatterns, this->loweringStrategy); - if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns)))) + if (failed(applyPatternsGreedily(op, std::move(loweringPatterns)))) signalPassFailure(); } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp index 9307e8eb784b..e3082c55427f 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUFoldAliasOps.cpp @@ -78,5 +78,5 @@ struct XeGPUFoldAliasOpsPass final void XeGPUFoldAliasOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); xegpu::populateXeGPUFoldAliasOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Reducer/ReductionTreePass.cpp b/mlir/lib/Reducer/ReductionTreePass.cpp index b00045a3a41b..2d2744bfc273 100644 --- a/mlir/lib/Reducer/ReductionTreePass.cpp +++ b/mlir/lib/Reducer/ReductionTreePass.cpp @@ -65,7 +65,7 @@ static void applyPatterns(Region ®ion, // because we don't have expectation this reduction will be success or not. GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingOps; - (void)applyOpPatternsAndFold(op, patterns, config); + (void)applyOpPatternsGreedily(op, patterns, config); } if (eraseOpNotInRange) diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index d50019bd6aee..5f4696050703 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -60,7 +60,7 @@ struct Canonicalizer : public impl::CanonicalizerBase { } void runOnOperation() override { LogicalResult converged = - applyPatternsAndFoldGreedily(getOperation(), *patterns, config); + applyPatternsGreedily(getOperation(), *patterns, config); // Canonicalization is best-effort. Non-convergence is not a pass failure. if (testConvergence && failed(converged)) signalPassFailure(); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index e0d0acd122e2..99f3569b767b 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file implements mlir::applyPatternsAndFoldGreedily. +// This file implements mlir::applyPatternsGreedily. // //===----------------------------------------------------------------------===// @@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { // infinite folding loop, as every constant op would be folded to an // Attribute and then immediately be rematerialized as a constant op, which // is then put on the worklist. - if (!op->hasTrait()) { + if (config.fold && !op->hasTrait()) { SmallVector foldResults; if (succeeded(op->fold(foldResults))) { LLVM_DEBUG(logResultWithLine("success", "operation was folded")); @@ -852,13 +852,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { if (!config.useTopDownTraversal) { // Add operations to the worklist in postorder. region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) + if (!config.cseConstants || !insertKnownConstant(op)) addToWorklist(op); }); } else { // Add all nested operations to the worklist in preorder. region.walk([&](Operation *op) { - if (!insertKnownConstant(op)) { + if (!config.cseConstants || !insertKnownConstant(op)) { addToWorklist(op); return WalkResult::advance(); } @@ -894,9 +894,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { } LogicalResult -mlir::applyPatternsAndFoldGreedily(Region ®ion, - const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config, bool *changed) { +mlir::applyPatternsGreedily(Region ®ion, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config, bool *changed) { // The top-level operation must be known to be isolated from above to // prevent performing canonicalizations on operations defined at or above // the region containing 'op'. @@ -1012,7 +1012,7 @@ static Region *findCommonAncestor(ArrayRef ops) { return region; } -LogicalResult mlir::applyOpPatternsAndFold( +LogicalResult mlir::applyOpPatternsGreedily( ArrayRef ops, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config, bool *changed, bool *allErased) { if (ops.empty()) { diff --git a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp index c208716891ef..6474c59595eb 100644 --- a/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp +++ b/mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp @@ -296,7 +296,7 @@ OneToNConversionPattern::matchAndRewrite(Operation *op, namespace mlir { // This function applies the provided patterns using -// `applyPatternsAndFoldGreedily` and then replaces all newly inserted +// `applyPatternsGreedily` and then replaces all newly inserted // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts // from target to source types inserted by a `OneToNConversionPattern` normally // fold away with the "forward" casts from source to target types inserted by @@ -317,7 +317,7 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter, #endif // NDEBUG // Apply provided conversion patterns. - if (failed(applyPatternsAndFoldGreedily(op, patterns))) { + if (failed(applyPatternsGreedily(op, patterns))) { emitError(op->getLoc()) << "failed to apply conversion patterns"; return failure(); } diff --git a/mlir/test/Transforms/test-operation-folder.mlir b/mlir/test/Transforms/test-operation-folder.mlir index 3c0cd15dc6c5..86ed6c25a227 100644 --- a/mlir/test/Transforms/test-operation-folder.mlir +++ b/mlir/test/Transforms/test-operation-folder.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s // RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s +// RUN: mlir-opt -test-greedy-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE +// RUN: mlir-opt -test-greedy-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD func.func @foo() -> i32 { %c42 = arith.constant 42 : i32 @@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) { } func.func @test_dont_reorder_constants() -> (i32, i32, i32) { - // Test that we don't reorder existing constants during folding if it isn't necessary. + // Test that we don't reorder existing constants during folding if it isn't + // necessary. // CHECK: %[[CST:.+]] = arith.constant 1 // CHECK-NEXT: %[[CST:.+]] = arith.constant 2 // CHECK-NEXT: %[[CST:.+]] = arith.constant 3 @@ -34,3 +37,46 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) { %2 = arith.constant 3 : i32 return %0, %1, %2 : i32, i32, i32 } + +// CHECK-LABEL: test_fold_nofold_nocse +// NOCSE-LABEL: test_fold_nofold_nocse +// NOFOLD-LABEL: test_fold_nofold_nocse +func.func @test_fold_nofold_nocse() -> (i32, i32, i32, i32, i32, i32) { + // Test either not folding or deduping constants. + + // Testing folding. There should be only 4 constants here. + // CHECK-NOT: arith.constant + // CHECK-DAG: %[[CST:.+]] = arith.constant 0 + // CHECK-DAG: %[[CST:.+]] = arith.constant 1 + // CHECK-DAG: %[[CST:.+]] = arith.constant 2 + // CHECK-DAG: %[[CST:.+]] = arith.constant 3 + // CHECK-NOT: arith.constant + // CHECK-NEXT: return + + // Testing not-CSE'ing. In this case we have the 3 original constants and 3 + // produced by folding. + // NOCSE-DAG: arith.constant 0 : i32 + // NOCSE-DAG: arith.constant 1 : i32 + // NOCSE-DAG: arith.constant 2 : i32 + // NOCSE-DAG: arith.constant 1 : i32 + // NOCSE-DAG: arith.constant 2 : i32 + // NOCSE-DAG: arith.constant 3 : i32 + // NOCSE-NEXT: return + + // Testing not folding. In this case we just have the original constants. + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 0 + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 1 + // NOFOLD-DAG: %[[CST:.+]] = arith.constant 2 + // NOFOLD: arith.addi + // NOFOLD: arith.addi + // NOFOLD: arith.addi + + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %0 = arith.addi %c0, %c1 : i32 + %1 = arith.addi %0, %c1 : i32 + %2 = arith.addi %c2, %c1 : i32 + return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32 +} + diff --git a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp index e17fe12b9088..1e45ab57ebcc 100644 --- a/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp +++ b/mlir/test/lib/Conversion/MathToVCIX/TestMathToVCIXConversion.cpp @@ -248,7 +248,7 @@ struct TestMathToVCIX RewritePatternSet patterns(ctx); patterns.add( ctx); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp index 1864d2f7f503..d49b4e391a68 100644 --- a/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp +++ b/mlir/test/lib/Conversion/VectorToSPIRV/TestVectorReductionToSPIRVDotProd.cpp @@ -41,7 +41,7 @@ struct TestVectorReductionToSPIRVDotProd void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorReductionToSPIRVDotProductPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp index b418a457473a..404f34ebee17 100644 --- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp +++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp @@ -136,7 +136,7 @@ void TestAffineDataCopy::runOnOperation() { } GreedyRewriteConfig config; config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - (void)applyOpPatternsAndFold(copyOps, std::move(patterns), config); + (void)applyOpPatternsGreedily(copyOps, std::move(patterns), config); } namespace mlir { diff --git a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp index f6bfd9f85828..03c80b601a34 100644 --- a/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp +++ b/mlir/test/lib/Dialect/ArmNeon/TestLowerToArmNeon.cpp @@ -47,7 +47,7 @@ void TestLowerToArmNeon::runOnOperation() { MLIRContext *context = &getContext(); RewritePatternSet patterns(context); populateLowerContractionToSMMLAPatternPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp index 74d057c0b7b6..a49d304baf5c 100644 --- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp +++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp @@ -38,7 +38,7 @@ struct TestGpuRewritePass void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateGpuRewritePatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -85,7 +85,7 @@ struct TestGpuSubgroupReduceLoweringPass patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32); } - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp index 4cf2460150d1..d0700f9a4f1a 100644 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -34,8 +34,7 @@ struct TestDataLayoutPropagationPass RewritePatternSet patterns(context); linalg::populateDataLayoutPropagationPatterns( patterns, [](OpOperand *opOperand) { return true; }); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp index 311244aeffb9..0143a27bfe84 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp @@ -43,8 +43,8 @@ struct TestLinalgDecomposeOps RewritePatternSet decompositionPatterns(context); linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns, removeDeadArgsAndResults); - if (failed(applyPatternsAndFoldGreedily( - getOperation(), std::move(decompositionPatterns)))) { + if (failed(applyPatternsGreedily(getOperation(), + std::move(decompositionPatterns)))) { return signalPassFailure(); } } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp index 7f68f4aec3a1..e4883e47f206 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -155,8 +155,8 @@ struct TestLinalgElementwiseFusion RewritePatternSet fusionPatterns(context); auto controlFn = [](OpOperand *operand) { return true; }; linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -166,8 +166,8 @@ struct TestLinalgElementwiseFusion linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, setFusedOpOperandLimit<4>); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -176,8 +176,8 @@ struct TestLinalgElementwiseFusion RewritePatternSet fusionPatterns(context); linalg::populateFoldReshapeOpsByExpansionPatterns( fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; }); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -212,8 +212,8 @@ struct TestLinalgElementwiseFusion linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, controlReshapeFusionFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(fusionPatterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) return signalPassFailure(); return; } @@ -222,8 +222,7 @@ struct TestLinalgElementwiseFusion RewritePatternSet patterns(context); linalg::populateFoldReshapeOpsByCollapsingPatterns( patterns, [](OpOperand * /*fusedOperand */) { return true; }); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -239,8 +238,7 @@ struct TestLinalgElementwiseFusion return true; }; linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -248,8 +246,7 @@ struct TestLinalgElementwiseFusion if (fuseMultiUseProducer) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } @@ -265,8 +262,7 @@ struct TestLinalgElementwiseFusion }; RewritePatternSet patterns(context); linalg::populateCollapseDimensions(patterns, collapseFn); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp index 2d8ee2f9bb6e..81e7eedabd5d 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp @@ -84,7 +84,7 @@ struct TestLinalgGreedyFusion pm.addPass(createCanonicalizerPass()); pm.addPass(createCSEPass()); do { - (void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); + (void)applyPatternsGreedily(getOperation(), frozenPatterns); if (failed(runPipeline(pm, getOperation()))) this->signalPassFailure(); } while (succeeded(fuseLinalgOpsGreedily(getOperation()))); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp index 8b455d7d68c3..750ba6b5d987 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgRankReduceContractionOps.cpp @@ -49,8 +49,7 @@ struct TestLinalgRankReduceContractionOps RewritePatternSet patterns(context); linalg::populateContractionOpRankReducingPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), - std::move(patterns)))) + if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns)))) return signalPassFailure(); return; } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 25aec75c3c14..fa2a27dcfa99 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -147,14 +147,14 @@ static void applyPatterns(func::FuncOp funcOp) { //===--------------------------------------------------------------------===// patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { RewritePatternSet forwardPattern(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); forwardPattern.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern)); + (void)applyPatternsGreedily(funcOp, std::move(forwardPattern)); } static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { @@ -163,68 +163,68 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { patterns.add(ctx); populatePadOpVectorizationPatterns(patterns); populateConvolutionVectorizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposePadPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); patterns.add(funcOp.getContext()); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateBubbleUpExtractSliceOpPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateSwapExtractSliceWithFillPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateEraseUnusedOperandsAndResultsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateEraseUnnecessaryInputsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyWinogradConv2D(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3); populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } static void applyDecomposeWinogradOps(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); populateDecomposeWinogradOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); } /// Apply transformations specified as patterns. diff --git a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp index 073e0d8d4e14..b927767038a9 100644 --- a/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestPadFusion.cpp @@ -36,8 +36,7 @@ struct TestPadFusionPass MLIRContext *context = &getContext(); RewritePatternSet patterns(context); linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp index 084a59221524..42491d4c716c 100644 --- a/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp +++ b/mlir/test/lib/Dialect/Math/TestAlgebraicSimplification.cpp @@ -40,7 +40,7 @@ struct TestMathAlgebraicSimplificationPass void TestMathAlgebraicSimplificationPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateMathAlgebraicSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp index 69af2a08b97b..0139eabba373 100644 --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -53,7 +53,7 @@ void TestExpandMathPass::runOnOperation() { populateExpandRoundFPattern(patterns); populateExpandRoundEvenPattern(patterns); populateExpandRsqrtPattern(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp index 8a01ac509c30..9fdd200e2b2c 100644 --- a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp @@ -59,7 +59,7 @@ void TestMathPolynomialApproximationPass::runOnOperation() { MathPolynomialApproximationOptions approxOptions; approxOptions.enableAvx2 = enableAvx2; populateMathPolynomialApproximationPatterns(patterns, approxOptions); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } namespace mlir { diff --git a/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp index 02a9dbbe263f..08d22ab59f94 100644 --- a/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp +++ b/mlir/test/lib/Dialect/MemRef/TestComposeSubView.cpp @@ -38,7 +38,7 @@ void TestComposeSubViewPass::getDependentDialects( void TestComposeSubViewPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateComposeSubViewPatterns(patterns, &getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp index 1f836be1ae7a..dbae93b380f2 100644 --- a/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestOpLowering.cpp @@ -26,9 +26,9 @@ struct TestAllSliceOpLoweringPass SymbolTableCollection symbolTableCollection; mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; - assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { mesh::registerAllSliceOpLoweringDialects(registry); @@ -51,9 +51,9 @@ struct TestMultiIndexOpLoweringPass mesh::populateProcessMultiIndexOpLoweringPatterns(patterns, symbolTableCollection); LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); (void)status; - assert(succeeded(status) && "applyPatternsAndFoldGreedily failed."); + assert(succeeded(status) && "applyPatternsGreedily failed."); } void getDependentDialects(DialectRegistry ®istry) const override { mesh::registerProcessMultiIndexOpLoweringDialects(registry); diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp index 98992c4cc11f..102e64de4bd1 100644 --- a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp @@ -97,8 +97,8 @@ struct TestMeshReshardingPass void runOnOperation() override { RewritePatternSet patterns(&getContext()); patterns.insert(&getContext()); - if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation().getOperation(), + std::move(patterns)))) { return signalPassFailure(); } } diff --git a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp index 512b16af64c9..01e196d29f7a 100644 --- a/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp +++ b/mlir/test/lib/Dialect/Mesh/TestSimplifications.cpp @@ -34,7 +34,7 @@ void TestMeshSimplificationsPass::runOnOperation() { SymbolTableCollection symbolTableCollection; mesh::populateSimplificationPatterns(patterns, symbolTableCollection); [[maybe_unused]] LogicalResult status = - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyPatternsGreedily(getOperation(), std::move(patterns)); assert(succeeded(status) && "Rewrite patters application did not converge."); } diff --git a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp index 8ca29257b812..0099dc8caf42 100644 --- a/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp +++ b/mlir/test/lib/Dialect/NVGPU/TestNVGPUTransforms.cpp @@ -60,7 +60,7 @@ struct TestMmaSyncF32ToTF32Patterns RewritePatternSet patterns(&getContext()); populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp index a3be1f94fa28..b4f3fa30f8ab 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp @@ -226,7 +226,7 @@ struct TestSCFPipeliningPass options.peelEpilogue = false; } scf::populateSCFLoopPipeliningPatterns(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); getOperation().walk([](Operation *op) { // Clean up the markers. op->removeAttr(kTestPipeliningStageMarker); diff --git a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp index 7e51d67702b0..856cde19edd5 100644 --- a/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp +++ b/mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp @@ -59,7 +59,7 @@ struct TestWrapWhileLoopInZeroTripCheckPass } else { RewritePatternSet patterns(context); scf::populateSCFRotateWhileLoopPatterns(patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } diff --git a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp index 468bc0ca7848..cf123fe28024 100644 --- a/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp +++ b/mlir/test/lib/Dialect/SCF/TestUpliftWhileToFor.cpp @@ -34,7 +34,7 @@ struct TestSCFUpliftWhileToFor MLIRContext *ctx = op->getContext(); RewritePatternSet patterns(ctx); scf::populateUpliftWhileToForPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + if (failed(applyPatternsGreedily(op, std::move(patterns)))) signalPassFailure(); } }; diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index 34de600132f5..173bfd8955f2 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -104,19 +104,19 @@ struct TestTensorTransforms static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateReassociativeReshapeFoldingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateBubbleUpExpandShapePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateFoldIntoPackAndUnpackPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { @@ -132,26 +132,26 @@ static void applyFoldConstantExtractSlicePatterns(Operation *rootOp) { }; tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyFoldConsecutiveInsertExtractSlicePatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } static void applySimplifyPackUnpackPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateSimplifyPackAndUnpackPatterns(patterns); - (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + (void)applyPatternsGreedily(rootOp, std::move(patterns)); } namespace { @@ -293,7 +293,7 @@ applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, else patterns.add( rootOp->getContext()); - return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); + return applyPatternsGreedily(rootOp, std::move(patterns)); } namespace { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 8a0bc597c56b..ce2820b80a94 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -388,8 +388,9 @@ struct TestGreedyPatternDriver GreedyRewriteConfig config; config.useTopDownTraversal = this->useTopDownTraversal; config.maxIterations = this->maxIterations; - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + config.fold = this->fold; + config.cseConstants = this->cseConstants; + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); } Option useTopDownTraversal{ @@ -400,6 +401,11 @@ struct TestGreedyPatternDriver *this, "max-iterations", llvm::cl::desc("Max. iterations in the GreedyRewriteConfig"), llvm::cl::init(GreedyRewriteConfig().maxIterations)}; + Option fold{*this, "fold", llvm::cl::desc("Whether to fold"), + llvm::cl::init(GreedyRewriteConfig().fold)}; + Option cseConstants{*this, "cse-constants", + llvm::cl::desc("Whether to CSE constants"), + llvm::cl::init(GreedyRewriteConfig().cseConstants)}; }; struct DumpNotifications : public RewriterBase::Listener { @@ -511,8 +517,8 @@ public: // operation will trigger the assertion while processing. bool changed = false; bool allErased = false; - (void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config, - &changed, &allErased); + (void)applyOpPatternsGreedily(ArrayRef(ops), std::move(patterns), config, + &changed, &allErased); Builder b(ctx); getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed)); getOperation()->setAttr("pattern_driver_all_erased", @@ -2101,7 +2107,7 @@ struct TestSelectiveReplacementPatternDriver MLIRContext *context = &getContext(); mlir::RewritePatternSet patterns(context); patterns.add(context); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp index 031e1062dac7..d8763f562cbe 100644 --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -38,8 +38,8 @@ struct TestTraitFolder StringRef getArgument() const final { return "test-trait-folder"; } StringRef getDescription() const final { return "Run trait folding"; } void runOnOperation() override { - (void)applyPatternsAndFoldGreedily(getOperation(), - RewritePatternSet(&getContext())); + (void)applyPatternsGreedily(getOperation(), + RewritePatternSet(&getContext())); } }; } // namespace diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp index e5a3e2b6fcca..ac904c3e01c9 100644 --- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp +++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp @@ -197,7 +197,7 @@ void TosaTestQuantUtilAPI::runOnOperation() { patterns.add(ctx); patterns.add(ctx); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } } // namespace diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index f67a24755ac0..74838bc0ca2f 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -73,7 +73,7 @@ struct TestVectorToVectorLowering populateVectorToVectorCanonicalizationPatterns(patterns); populateBubbleVectorBitCastOpPatterns(patterns); populateCastAwayVectorLeadingOneDimPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } private: @@ -137,7 +137,7 @@ struct TestVectorContractionPrepareForMMTLowering MLIRContext *ctx = &getContext(); RewritePatternSet patterns(ctx); vector::populateVectorContractCanonicalizeMatmulToMMT(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -223,7 +223,7 @@ struct TestVectorUnrollingPatterns })); } populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } ListOption unrollOrder{*this, "unroll-order", @@ -283,7 +283,7 @@ struct TestVectorTransferUnrollingPatterns } populateVectorUnrollPatterns(patterns, opts); populateVectorToVectorCanonicalizationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } Option reverseUnrollOrder{ @@ -326,7 +326,7 @@ struct TestScalarVectorTransferLoweringPatterns RewritePatternSet patterns(ctx); vector::populateScalarVectorTransferLoweringPatterns( patterns, /*benefit=*/1, allowMultipleUses.getValue()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -370,7 +370,7 @@ struct TestVectorTransferCollapseInnerMostContiguousDims void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorTransferCollapseInnerMostContiguousDimsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -395,7 +395,7 @@ struct TestVectorSinkPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateSinkVectorOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -415,7 +415,7 @@ struct TestVectorReduceToContractPatternsPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorReductionToContractPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -434,7 +434,7 @@ struct TestVectorChainedReductionFoldingPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateChainedVectorReductionFoldingPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -455,7 +455,7 @@ struct TestVectorBreakDownReductionPatterns RewritePatternSet patterns(&getContext()); populateBreakDownVectorReductionPatterns(patterns, /*maxNumElementsToExtract=*/2); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -496,7 +496,7 @@ struct TestFlattenVectorTransferPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -512,7 +512,7 @@ struct TestVectorScanLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorScanLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -662,18 +662,18 @@ struct TestVectorDistribution /*readBenefit=*/0); vector::populateDistributeReduction(patterns, warpReduction, 1); populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } else if (distributeTransferWriteOps) { RewritePatternSet patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn, maxTransferWriteElements); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } else if (propagateDistribution) { RewritePatternSet patterns(ctx); vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); vector::populateDistributeReduction(patterns, warpReduction); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } WarpExecuteOnLane0LoweringOptions options; options.warpAllocationFn = allocateGlobalSharedMemory; @@ -684,7 +684,7 @@ struct TestVectorDistribution // Test on one pattern in isolation. if (warpOpToSCF) { populateWarpExecuteOnLane0OpToScfForPattern(patterns, options); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); return; } } @@ -706,7 +706,7 @@ struct TestVectorExtractStridedSliceLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorExtractStridedSliceToExtractInsertChainPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -726,7 +726,7 @@ struct TestVectorBreakDownBitCast populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) { return op.getSourceVectorType().getShape().back() > 4; }); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -782,7 +782,7 @@ struct TestVectorGatherLowering void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorGatherLoweringPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -809,7 +809,7 @@ struct TestFoldArithExtensionIntoVectorContractPatterns void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateFoldArithExtensionPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; @@ -834,7 +834,7 @@ struct TestVectorEmulateMaskedLoadStore final void runOnOperation() override { RewritePatternSet patterns(&getContext()); populateVectorMaskedLoadStoreEmulationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp index 77aa30f847dc..7b96bf5e28d3 100644 --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -161,8 +161,8 @@ struct TestPDLByteCodePass patternList.add(std::move(pdlPattern)); // Invoke the pattern driver with the provided patterns. - (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), - std::move(patternList)); + (void)applyPatternsGreedily(irModule.getBodyRegion(), + std::move(patternList)); } }; } // namespace diff --git a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp index db45d0eadf81..f6b2b2b1c683 100644 --- a/mlir/test/lib/Tools/PDLL/TestPDLL.cpp +++ b/mlir/test/lib/Tools/PDLL/TestPDLL.cpp @@ -39,7 +39,7 @@ struct TestPDLLPass : public PassWrapper> { void runOnOperation() final { // Invoke the pattern driver with the provided patterns. - (void)applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)applyPatternsGreedily(getOperation(), patterns); } FrozenRewritePatternSet patterns; diff --git a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp index 2ec0334ae0d0..5ea35759bb72 100644 --- a/mlir/test/lib/Transforms/TestCommutativityUtils.cpp +++ b/mlir/test/lib/Transforms/TestCommutativityUtils.cpp @@ -36,7 +36,7 @@ struct CommutativityUtils RewritePatternSet patterns(context); populateCommutativityUtilsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; } // namespace diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp index 82fa6cdb68d2..4e0213c0e4cf 100644 --- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp +++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp @@ -123,7 +123,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (simple) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; @@ -132,7 +132,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (cloneOpsWithNoOperands) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return; @@ -141,7 +141,7 @@ void TestMakeIsolatedFromAbovePass::runOnOperation() { if (cloneOpsWithOperands) { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) { + if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) { return signalPassFailure(); } return;