[mlir] Enable decoupling two kinds of greedy behavior. (#104649)

The greedy rewriter is used in many different flows and it has a lot of
conveniences (worklist management, debugging actions, tracing, etc.). But
it combines two kinds of greedy behavior: 1) how ops are matched, and 2)
folding wherever it can.

These are independent forms of greediness, and coupling them leads to
inefficiency. E.g., in cases where one needs to create different lowering
phases and apply patterns in a specific order split across different
passes, using the driver means needlessly retrying folding and running
multiple rounds of folding attempts where one final run would have
sufficed.

Of course, folks can avoid this behavior locally by building their own
driver, but this is also a commonly requested feature that folks keep
working around locally in suboptimal ways.

For downstream users, there should be no behavioral change. Updating from
the deprecated API should just be a find-and-replace (e.g., of the `find ./
-type f -exec sed -i
's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety),
as the API arguments haven't changed between the two.
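
As a rough sketch (not part of the commit itself; the `fold` and
`cseConstants` fields and the `applyPatternsGreedily` overloads are taken
from the header changes in the diff below), a downstream pass that wants
greedy pattern application without folding could now do:

```cpp
// Hedged sketch: assumes the GreedyRewriteConfig fields added by this commit
// (fold, cseConstants) and the applyPatternsGreedily(Operation *, ...) overload.
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult
applyPatternsNoFold(mlir::Operation *op,
                    const mlir::FrozenRewritePatternSet &patterns) {
  mlir::GreedyRewriteConfig config;
  // Keep the greedy matching/worklist machinery, but skip folding (and
  // constant CSE) so a single dedicated folding pass can run later instead
  // of repeated folding rounds in every phase.
  config.fold = false;
  config.cseConstants = false;
  return mlir::applyPatternsGreedily(op, patterns, config);
}
```

Setting `config.fold = true` instead reproduces the behavior of the
deprecated `applyPatternsAndFoldGreedily` entry points, which the commit
keeps as thin, `LLVM_DEPRECATED`-marked forwarders.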
Jacques Pienaar 2024-12-20 08:15:48 -08:00 committed by GitHub
parent 412e1af19a
commit 09dfc5713d
110 changed files with 313 additions and 246 deletions

View File

@ -125,7 +125,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.insert<InlineElementalConversion>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR elemental inlining");

View File

@ -520,8 +520,8 @@ public:
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
module, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
mlir::emitError(mlir::UnknownLoc::get(context),
"failure in HLFIR intrinsic lowering");
signalPassFailure();

View File

@ -1372,7 +1372,7 @@ public:
// patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
// patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR optimized bufferization");

View File

@ -491,7 +491,7 @@ public:
patterns.insert<SumAsElementalConversion>(context);
patterns.insert<CShiftAsElementalConversion>(context);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
if (mlir::failed(mlir::applyPatternsGreedily(
getOperation(), std::move(patterns), config))) {
mlir::emitError(getOperation()->getLoc(),
"failure in HLFIR intrinsic simplification");

View File

@ -39,8 +39,7 @@ struct AlgebraicSimplification
void AlgebraicSimplification::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateMathAlgebraicSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config);
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
}
std::unique_ptr<mlir::Pass> fir::createAlgebraicSimplificationPass() {

View File

@ -154,7 +154,7 @@ public:
mlir::GreedyRewriteConfig config;
config.enableRegionSimplification =
mlir::GreedySimplifyRegionLevel::Disabled;
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
(void)applyPatternsGreedily(mod, std::move(patterns), config);
}
};
} // namespace

View File

@ -173,8 +173,8 @@ public:
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
patterns.insert<CallOpRewriter>(context, *di);
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
mod, std::move(patterns), config))) {
if (mlir::failed(
mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
mlir::emitError(mod.getLoc(),
"error in constant globalisation optimization\n");
signalPassFailure();

View File

@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() {
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
patterns.insert<AllocMemConversion>(&context, *candidateOps);
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
std::move(patterns), config))) {
if (mlir::failed(mlir::applyOpPatternsGreedily(
opsToConvert, std::move(patterns), config))) {
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
signalPassFailure();
}

View File

@ -358,7 +358,7 @@ which point the driver finishes.
This driver comes in two fashions:
* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
* `applyPatternsGreedily` ("region-based driver") applies patterns to
all ops in a given region or a given container op (but not the container op
itself). I.e., the worklist is initialized with all containing ops.
* `applyOpPatternsAndFold` ("op-based driver") applies patterns to the

View File

@ -39,7 +39,7 @@ public:
RewritePatternSet patterns(&getContext());
patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
if (failed(applyPatternsGreedily(getOperation(), patternSet)))
signalPassFailure();
}
};

View File

@ -91,6 +91,13 @@ public:
/// An optional listener that should be notified about IR modifications.
RewriterBase::Listener *listener = nullptr;
/// Whether this should fold while greedily rewriting.
bool fold = true;
/// If set to "true", constants are CSE'd (even across multiple regions that
/// are in a parent-ancestor relationship).
bool cseConstants = true;
};
//===----------------------------------------------------------------------===//
@ -104,8 +111,8 @@ public:
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// A region scope can be set in the configuration parameter. By default, the
/// scope is set to the specified region. Only in-scope ops are added to the
@ -117,10 +124,20 @@ public:
///
/// Note: This method does not apply patterns to the region's parent operation.
LogicalResult
applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
/// Same as `applyPatternsGreedily` above with folding.
/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr);
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(region, patterns, config, changed);
}
/// Rewrite ops nested under the given operation, which must be isolated from
/// above, by repeatedly applying the highest benefit patterns in a greedy
@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region &region,
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// This overload runs a separate greedy rewrite for each region of the
/// specified op. A region scope can be set in the configuration parameter. By
@ -147,23 +164,32 @@ applyPatternsAndFoldGreedily(Region &region,
///
/// Note: This method does not apply patterns to the given operation itself.
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
bool anyRegionChanged = false;
bool failed = false;
for (Region &region : op->getRegions()) {
bool regionChanged;
failed |=
applyPatternsAndFoldGreedily(region, patterns, config, &regionChanged)
.failed();
failed |= applyPatternsGreedily(region, patterns, config, &regionChanged)
.failed();
anyRegionChanged |= regionChanged;
}
if (changed)
*changed = anyRegionChanged;
return failure(failed);
}
/// Same as `applyPatternsGreedily` above with folding.
/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
inline LogicalResult
applyPatternsAndFoldGreedily(Operation *op,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr) {
config.fold = true;
return applyPatternsGreedily(op, patterns, config, changed);
}
/// Rewrite the specified ops by repeatedly applying the highest benefit
/// patterns in a greedy worklist driven manner until a fixpoint is reached.
@ -171,8 +197,8 @@ applyPatternsAndFoldGreedily(Operation *op,
/// The greedy rewrite may prematurely stop after a maximum number of
/// iterations, which can be configured in the configuration parameter.
///
/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.
/// Also performs simple dead-code elimination before attempting to match any of
/// the provided patterns.
///
/// Newly created ops and other pre-existing ops that use results of rewritten
/// ops or supply operands to such ops are also processed, unless such ops are
@ -180,24 +206,36 @@ applyPatternsAndFoldGreedily(Operation *op,
/// regardless of `strictMode`).
///
/// In addition to strictness, a region scope can be specified. Only ops within
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
/// the scope are simplified. This is similar to `applyPatternsGreedily`,
/// where only ops within the given region/op are simplified by default. If no
/// scope is specified, it is assumed to be the first common enclosing region of
/// the given ops.
///
/// Note that ops in `ops` could be erased as result of folding, becoming dead,
/// or via pattern rewrites. If more far reaching simplification is desired,
/// `applyPatternsAndFoldGreedily` should be used.
/// `applyPatternsGreedily` should be used.
///
/// Returns "success" if the iterative process converged (i.e., fixpoint was
/// reached) and no more patterns can be matched. `changed` is set to "true" if
/// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
/// `ops` were erased.
LogicalResult
applyOpPatternsGreedily(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
/// Same as `applyOpPatternsGreedily` with folding.
/// FIXME: Remove this once transition to above is completed.
LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead",
"applyOpPatternsGreedily")
inline LogicalResult
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
bool *changed = nullptr, bool *allErased = nullptr) {
config.fold = true;
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
}
} // namespace mlir

View File

@ -289,8 +289,7 @@ MlirLogicalResult
mlirApplyPatternsAndFoldGreedily(MlirModule op,
MlirFrozenRewritePatternSet patterns,
MlirGreedyRewriteDriverConfig) {
return wrap(
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
}
//===----------------------------------------------------------------------===//

View File

@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
arith::populateArithToAMDGPUConversionPatterns(
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
*maybeChipset);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}

View File

@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
arith::populateArithToArmSMEConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr
RewritePatternSet patterns(context);
populateConvertArmNeon2dToIntrPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass
{
RewritePatternSet patterns(m.getContext());
populateGpuRewritePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
if (failed(applyPatternsGreedily(m, std::move(patterns))))
return signalPassFailure();
}

View File

@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass
RewritePatternSet patterns(ctx);
populateGpuRewritePatterns(patterns);
arith::populateExpandBFloat16Patterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
(void)applyPatternsGreedily(m, std::move(patterns));
}
LLVMTypeConverter converter(ctx, options);

View File

@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass
ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
ctx);
(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns));
(void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -62,7 +62,7 @@ class ConvertShapeConstraints
RewritePatternSet patterns(context);
populateConvertShapeConstraintsConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
if (failed(applyPatternsGreedily(func, std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateVectorToArmSMEPatterns(patterns, getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> mlir::createConvertVectorToArmSMEPass() {

View File

@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
IRRewriter rewriter(&getContext());

View File

@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorInsertExtractStridedSliceTransforms(patterns);
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
// Convert to the LLVM IR dialect.

View File

@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass
RewritePatternSet lowerTransferPatterns(&getContext());
mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
lowerTransferPatterns);
(void)applyPatternsAndFoldGreedily(getOperation(),
std::move(lowerTransferPatterns));
(void)applyPatternsGreedily(getOperation(),
std::move(lowerTransferPatterns));
RewritePatternSet patterns(&getContext());
populateVectorToSCFConversionPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToXeGPUConversionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
static_cast<RewriterBase::Listener *>(rewriter.getListener());
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;

View File

@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() {
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
}

View File

@ -198,8 +198,7 @@ public:
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateAffineExpandIndexOpsPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -79,8 +79,7 @@ public:
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateAffineExpandIndexOpsAsAffinePatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -111,5 +111,5 @@ void SimplifyAffineStructures::runOnOperation() {
});
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
}

View File

@ -318,8 +318,8 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsAndFold(res.getOperation(), std::move(patterns),
config, /*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
config, /*changed=*/nullptr, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)

View File

@ -425,8 +425,8 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config,
/*changed=*/nullptr, &erased);
(void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
/*changed=*/nullptr, &erased);
if (erased) {
if (folded)
*folded = true;
@ -454,7 +454,7 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
// a sequence of affine.fors that are all perfectly nested).
(void)applyPatternsAndFoldGreedily(
(void)applyPatternsGreedily(
hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
frozenPatterns);

View File

@ -489,7 +489,7 @@ struct IntRangeOptimizationsPass final
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};
@ -518,7 +518,7 @@ struct IntRangeNarrowingPass final
config.useTopDownTraversal = false;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
signalPassFailure();
}
};

View File

@ -523,8 +523,7 @@ struct OuterProductFusionPass
RewritePatternSet patterns(&getContext());
populateOuterProductFusionPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};

View File

@ -317,8 +317,7 @@ struct LegalizeVectorStorage
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateLegalizeVectorStoragePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
ConversionTarget target(getContext());

View File

@ -931,7 +931,7 @@ void AsyncParallelForPass::runOnOperation() {
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
return builder.create<arith::ConstantIndexOp>(minTaskSize);
});
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}

View File

@ -470,8 +470,8 @@ struct BufferDeallocationSimplificationPass
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
config)))
if (failed(
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
signalPassFailure();
}
};

View File

@ -60,7 +60,7 @@ void EmptyTensorToAllocTensor::runOnOperation() {
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateEmptyTensorToAllocTensorPattern(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
signalPassFailure();
}

View File

@ -47,7 +47,7 @@ struct FormExpressionsPass
RewritePatternSet patterns(context);
populateExpressionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
return signalPassFailure();
}

View File

@ -227,8 +227,7 @@ struct GpuDecomposeMemrefsPass
populateGpuDecomposeMemrefsPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -630,7 +630,7 @@ class GpuEliminateBarriersPass
auto funcOp = getOperation();
RewritePatternSet patterns(&getContext());
mlir::populateGpuEliminateBarriersPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

View File

@ -96,7 +96,7 @@ void NVVMOptimizeForTarget::runOnOperation() {
MLIRContext *ctx = getOperation()->getContext();
RewritePatternSet patterns(ctx);
patterns.add<ExpandDivF16>(ctx);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

View File

@ -3511,7 +3511,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
TrackingListener listener(state, *this);
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
return emitDefaultDefiniteFailure(target);
results.push_back(target);

View File

@ -301,7 +301,7 @@ struct LinalgBlockPackMatmul
};
linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -563,8 +563,7 @@ struct LinalgDetensorize
RewritePatternSet canonPatterns(context);
tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(canonPatterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
signalPassFailure();
// Get rid of the dummy entry block we created in the beginning to work

View File

@ -831,7 +831,7 @@ struct LinalgFoldUnitExtentDimsPass
}
linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
populateMoveInitOperandsToInputPattern(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
(void)applyPatternsGreedily(op, std::move(patterns));
}
};

View File

@ -2206,7 +2206,7 @@ struct LinalgElementwiseOpFusionPass
// Use TopDownTraversal for compile time reasons
GreedyRewriteConfig grc;
grc.useTopDownTraversal = true;
(void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
(void)applyPatternsGreedily(op, std::move(patterns), grc);
}
};

View File

@ -89,7 +89,7 @@ struct LinalgGeneralizeNamedOpsPass
void LinalgGeneralizeNamedOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(

View File

@ -113,7 +113,7 @@ struct LinalgInlineScalarOperandsPass
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
populateInlineConstantOperandsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
(void)applyPatternsGreedily(op, std::move(patterns));
}
};
} // namespace

View File

@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
// Just apply the patterns greedily.
(void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
(void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}
struct LowerToAffineLoops

View File

@ -152,7 +152,7 @@ struct LinalgNamedOpConversionPass
Operation *op = getOperation();
RewritePatternSet patterns(op->getContext());
populateLinalgNamedOpConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -349,7 +349,7 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() {
populateLinalgGenericOpsSpecializationPatterns(patterns);
populateDecomposeProjectedPermutationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}

View File

@ -66,8 +66,7 @@ struct MathUpliftToFMA final
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateUpliftToFMAPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -1223,7 +1223,7 @@ struct ExpandStridedMetadataPass final
void ExpandStridedMetadataPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateExpandStridedMetadataPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {

View File

@ -857,7 +857,7 @@ struct FoldMemRefAliasOpsPass final
void FoldMemRefAliasOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateFoldMemRefAliasOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {

View File

@ -195,7 +195,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
@ -203,7 +203,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

View File

@ -112,7 +112,7 @@ struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
MLIRContext *ctx = parentOp->getContext();
RewritePatternSet patterns(ctx);
patterns.add<ForLoopLoweringPattern>(ctx);
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
(void)applyPatternsGreedily(parentOp, std::move(patterns));
}
};
} // namespace

View File

@ -167,7 +167,7 @@ struct SCFForLoopCanonicalization
MLIRContext *ctx = parentOp->getContext();
RewritePatternSet patterns(ctx);
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
signalPassFailure();
}
};

View File

@ -331,7 +331,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
MLIRContext *ctx = parentOp->getContext();
RewritePatternSet patterns(ctx);
patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
(void)applyPatternsGreedily(parentOp, std::move(patterns));
// Drop the markers.
parentOp->walk([](Operation *op) {

View File

@ -1430,7 +1430,7 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
GreedyRewriteConfig config;
config.listener = this;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
return applyOpPatternsAndFold(ops, patterns.value(), config);
return applyOpPatternsGreedily(ops, patterns.value(), config);
}
void SliceTrackingListener::notifyOperationInserted(

View File

@ -29,8 +29,7 @@ public:
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
spirv::populateSPIRVGLCanonicalizationPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -1354,7 +1354,7 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
// looking for newly created func ops.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
return applyPatternsGreedily(op, std::move(patterns), config);
}
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
@ -1366,7 +1366,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
populateVectorUnrollPatterns(patterns, options);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
}
@ -1378,7 +1378,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
vector::VectorTransposeLowering::EltWise);
vector::populateVectorTransposeLoweringPatterns(patterns, options);
vector::populateVectorShapeCastLoweringPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
}
@ -1403,7 +1403,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
if (failed(applyPatternsGreedily(op, std::move(patterns))))
return failure();
}
return success();

View File

@ -236,8 +236,7 @@ struct WebGPUPreparePass final
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}
};

View File

@ -207,7 +207,7 @@ void OutlineShapeComputationPass::runOnOperation() {
MLIRContext *context = funcOp.getContext();
RewritePatternSet prevPatterns(context);
prevPatterns.insert<TensorDimOpRewriter>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
return signalPassFailure();
// initialize class member `onlyUsedByWithShapes`
@ -254,7 +254,7 @@ void OutlineShapeComputationPass::runOnOperation() {
}
// Apply patterns, note this also performs DCE.
if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
if (failed(applyPatternsGreedily(funcOp, {})))
return signalPassFailure();
});
}

View File

@ -55,7 +55,7 @@ class RemoveShapeConstraintsPass
RewritePatternSet patterns(&ctx);
populateRemoveShapeConstraintsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -57,7 +57,7 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseAssembler(patterns, directOut);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -73,7 +73,7 @@ struct SparseReinterpretMap
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseReinterpretMap(patterns, scope);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -87,7 +87,7 @@ struct PreSparsificationRewritePass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populatePreSparsificationRewriting(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -110,7 +110,7 @@ struct SparsificationPass
RewritePatternSet patterns(ctx);
populateSparsificationPatterns(patterns, options);
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -122,7 +122,7 @@ struct StageSparseOperationsPass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateStageSparseOperationsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -141,7 +141,7 @@ struct LowerSparseOpsToForeachPass
RewritePatternSet patterns(ctx);
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
enableConvert);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -154,7 +154,7 @@ struct LowerForeachToSCFPass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -329,7 +329,7 @@ struct SparseBufferRewritePass
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateSparseBufferRewriting(patterns, enableBufferInitialization);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -351,7 +351,7 @@ struct SparseVectorizationPass
populateSparseVectorizationPatterns(
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -371,7 +371,7 @@ struct SparseGPUCodegenPass
populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
else
populateSparseGPUCodegenPatterns(patterns, numThreads);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -277,7 +277,7 @@ struct FoldTensorSubsetOpsPass final
void FoldTensorSubsetOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
tensor::populateFoldTensorSubsetOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {

View File

@ -60,7 +60,7 @@ struct TosaLayerwiseConstantFoldPass
aggressiveReduceConstant);
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
if (applyPatternsGreedily(func, std::move(patterns)).failed())
signalPassFailure();
}
};

View File

@ -246,7 +246,7 @@ public:
patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
(void)applyPatternsGreedily(func, std::move(patterns));
}
};
} // namespace

View File

@ -42,7 +42,7 @@ struct TosaOptionalDecompositions
mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
if (applyPatternsGreedily(func, std::move(patterns)).failed())
signalPassFailure();
}
};

View File

@ -417,7 +417,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
// Op is isolated from above. Apply patterns and also perform region
// simplification.
result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
result = applyPatternsGreedily(target, frozenPatterns, config);
} else {
// Manually gather list of ops because the other
// GreedyPatternRewriteDriver overloads only accepts ops that are isolated
@ -429,7 +429,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
if (target != nestedOp)
ops.push_back(nestedOp);
});
result = applyOpPatternsAndFold(ops, frozenPatterns, config);
result = applyOpPatternsGreedily(ops, frozenPatterns, config);
}
// A failure typically indicates that the pattern application did not

View File

@ -286,7 +286,7 @@ struct LowerVectorMaskPass
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
}

View File

@ -486,7 +486,7 @@ struct LowerVectorMultiReductionPass
populateVectorMultiReductionLoweringPatterns(loweringPatterns,
this->loweringStrategy);
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
signalPassFailure();
}

View File

@ -78,5 +78,5 @@ struct XeGPUFoldAliasOpsPass final
void XeGPUFoldAliasOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
xegpu::populateXeGPUFoldAliasOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}

View File

@ -65,7 +65,7 @@ static void applyPatterns(Region &region,
// because we don't have expectation this reduction will be success or not.
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyOpPatternsAndFold(op, patterns, config);
(void)applyOpPatternsGreedily(op, patterns, config);
}
if (eraseOpNotInRange)

View File

@ -60,7 +60,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
}
void runOnOperation() override {
LogicalResult converged =
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
applyPatternsGreedily(getOperation(), *patterns, config);
// Canonicalization is best-effort. Non-convergence is not a pass failure.
if (testConvergence && failed(converged))
signalPassFailure();

View File

@ -6,7 +6,7 @@
//
//===----------------------------------------------------------------------===//
//
// This file implements mlir::applyPatternsAndFoldGreedily.
// This file implements mlir::applyPatternsGreedily.
//
//===----------------------------------------------------------------------===//
@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// infinite folding loop, as every constant op would be folded to an
// Attribute and then immediately be rematerialized as a constant op, which
// is then put on the worklist.
if (!op->hasTrait<OpTrait::ConstantLike>()) {
if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
@ -852,13 +852,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
if (!config.useTopDownTraversal) {
// Add operations to the worklist in postorder.
region.walk([&](Operation *op) {
if (!insertKnownConstant(op))
if (!config.cseConstants || !insertKnownConstant(op))
addToWorklist(op);
});
} else {
// Add all nested operations to the worklist in preorder.
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (!insertKnownConstant(op)) {
if (!config.cseConstants || !insertKnownConstant(op)) {
addToWorklist(op);
return WalkResult::advance();
}
@ -894,9 +894,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
}
LogicalResult
mlir::applyPatternsAndFoldGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed) {
mlir::applyPatternsGreedily(Region &region,
const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed) {
// The top-level operation must be known to be isolated from above to
// prevent performing canonicalizations on operations defined at or above
// the region containing 'op'.
@ -1012,7 +1012,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
return region;
}
LogicalResult mlir::applyOpPatternsAndFold(
LogicalResult mlir::applyOpPatternsGreedily(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config, bool *changed, bool *allErased) {
if (ops.empty()) {

View File

@ -296,7 +296,7 @@ OneToNConversionPattern::matchAndRewrite(Operation *op,
namespace mlir {
// This function applies the provided patterns using
// `applyPatternsAndFoldGreedily` and then replaces all newly inserted
// `applyPatternsGreedily` and then replaces all newly inserted
// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
// from target to source types inserted by a `OneToNConversionPattern` normally
// fold away with the "forward" casts from source to target types inserted by
@ -317,7 +317,7 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
#endif // NDEBUG
// Apply provided conversion patterns.
if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
if (failed(applyPatternsGreedily(op, patterns))) {
emitError(op->getLoc()) << "failed to apply conversion patterns";
return failure();
}

View File

@ -1,5 +1,7 @@
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE
// RUN: mlir-opt -test-greedy-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD
func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32
@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) {
}
func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
// Test that we don't reorder existing constants during folding if it isn't necessary.
// Test that we don't reorder existing constants during folding if it isn't
// necessary.
// CHECK: %[[CST:.+]] = arith.constant 1
// CHECK-NEXT: %[[CST:.+]] = arith.constant 2
// CHECK-NEXT: %[[CST:.+]] = arith.constant 3
@ -34,3 +37,46 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
%2 = arith.constant 3 : i32
return %0, %1, %2 : i32, i32, i32
}
// CHECK-LABEL: test_fold_nofold_nocse
// NOCSE-LABEL: test_fold_nofold_nocse
// NOFOLD-LABEL: test_fold_nofold_nocse
func.func @test_fold_nofold_nocse() -> (i32, i32, i32, i32, i32, i32) {
// Test either not folding or deduping constants.
// Testing folding. There should be only 4 constants here.
// CHECK-NOT: arith.constant
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
// CHECK-DAG: %[[CST:.+]] = arith.constant 1
// CHECK-DAG: %[[CST:.+]] = arith.constant 2
// CHECK-DAG: %[[CST:.+]] = arith.constant 3
// CHECK-NOT: arith.constant
// CHECK-NEXT: return
// Testing not-CSE'ing. In this case we have the 3 original constants and 3
// produced by folding.
// NOCSE-DAG: arith.constant 0 : i32
// NOCSE-DAG: arith.constant 1 : i32
// NOCSE-DAG: arith.constant 2 : i32
// NOCSE-DAG: arith.constant 1 : i32
// NOCSE-DAG: arith.constant 2 : i32
// NOCSE-DAG: arith.constant 3 : i32
// NOCSE-NEXT: return
// Testing not folding. In this case we just have the original constants.
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 0
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 1
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 2
// NOFOLD: arith.addi
// NOFOLD: arith.addi
// NOFOLD: arith.addi
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
%0 = arith.addi %c0, %c1 : i32
%1 = arith.addi %0, %c1 : i32
%2 = arith.addi %c2, %c1 : i32
return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32
}

View File

@ -248,7 +248,7 @@ struct TestMathToVCIX
RewritePatternSet patterns(ctx);
patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
ctx);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -41,7 +41,7 @@ struct TestVectorReductionToSPIRVDotProd
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorReductionToSPIRVDotProductPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -136,7 +136,7 @@ void TestAffineDataCopy::runOnOperation() {
}
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
(void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
}
namespace mlir {

View File

@ -47,7 +47,7 @@ void TestLowerToArmNeon::runOnOperation() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
populateLowerContractionToSMMLAPatternPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}

View File

@ -38,7 +38,7 @@ struct TestGpuRewritePass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
@ -85,7 +85,7 @@ struct TestGpuSubgroupReduceLoweringPass
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
}
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};
} // namespace

View File

@ -34,8 +34,7 @@ struct TestDataLayoutPropagationPass
RewritePatternSet patterns(context);
linalg::populateDataLayoutPropagationPatterns(
patterns, [](OpOperand *opOperand) { return true; });
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -43,8 +43,8 @@ struct TestLinalgDecomposeOps
RewritePatternSet decompositionPatterns(context);
linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
removeDeadArgsAndResults);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(decompositionPatterns)))) {
if (failed(applyPatternsGreedily(getOperation(),
std::move(decompositionPatterns)))) {
return signalPassFailure();
}
}

View File

@ -155,8 +155,8 @@ struct TestLinalgElementwiseFusion
RewritePatternSet fusionPatterns(context);
auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@ -166,8 +166,8 @@ struct TestLinalgElementwiseFusion
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
setFusedOpOperandLimit<4>);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@ -176,8 +176,8 @@ struct TestLinalgElementwiseFusion
RewritePatternSet fusionPatterns(context);
linalg::populateFoldReshapeOpsByExpansionPatterns(
fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@ -212,8 +212,8 @@ struct TestLinalgElementwiseFusion
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
controlReshapeFusionFn);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(),
std::move(fusionPatterns))))
return signalPassFailure();
return;
}
@ -222,8 +222,7 @@ struct TestLinalgElementwiseFusion
RewritePatternSet patterns(context);
linalg::populateFoldReshapeOpsByCollapsingPatterns(
patterns, [](OpOperand * /*fusedOperand */) { return true; });
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
return;
}
@ -239,8 +238,7 @@ struct TestLinalgElementwiseFusion
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
return;
}
@ -248,8 +246,7 @@ struct TestLinalgElementwiseFusion
if (fuseMultiUseProducer) {
RewritePatternSet patterns(context);
patterns.insert<TestMultiUseProducerFusion>(context);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
return;
}
@ -265,8 +262,7 @@ struct TestLinalgElementwiseFusion
};
RewritePatternSet patterns(context);
linalg::populateCollapseDimensions(patterns, collapseFn);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
return;
}

View File

@ -84,7 +84,7 @@ struct TestLinalgGreedyFusion
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());
do {
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
(void)applyPatternsGreedily(getOperation(), frozenPatterns);
if (failed(runPipeline(pm, getOperation())))
this->signalPassFailure();
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));

View File

@ -49,8 +49,7 @@ struct TestLinalgRankReduceContractionOps
RewritePatternSet patterns(context);
linalg::populateContractionOpRankReducingPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(patterns))))
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
return signalPassFailure();
return;
}

View File

@ -147,14 +147,14 @@ static void applyPatterns(func::FuncOp funcOp) {
//===--------------------------------------------------------------------===//
patterns.add<CopyVectorizationPattern>(ctx);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
RewritePatternSet forwardPattern(funcOp.getContext());
forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
(void)applyPatternsGreedily(funcOp, std::move(forwardPattern));
}
static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
@ -163,68 +163,68 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
patterns.add<CopyVectorizationPattern>(ctx);
populatePadOpVectorizationPatterns(patterns);
populateConvolutionVectorizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyDecomposePadPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<DecomposePadOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateBubbleUpExtractSliceOpPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateSwapExtractSliceWithFillPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateEraseUnusedOperandsAndResultsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateEraseUnnecessaryInputsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyWinogradConv2D(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
RewritePatternSet patterns(funcOp.getContext());
populateDecomposeWinogradOpsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsGreedily(funcOp, std::move(patterns));
}
/// Apply transformations specified as patterns.

View File

@ -36,8 +36,7 @@ struct TestPadFusionPass
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
if (failed(
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
return signalPassFailure();
}
};

View File

@ -40,7 +40,7 @@ struct TestMathAlgebraicSimplificationPass
void TestMathAlgebraicSimplificationPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateMathAlgebraicSimplificationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
namespace mlir {

View File

@ -53,7 +53,7 @@ void TestExpandMathPass::runOnOperation() {
populateExpandRoundFPattern(patterns);
populateExpandRoundEvenPattern(patterns);
populateExpandRsqrtPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
namespace mlir {

View File

@ -59,7 +59,7 @@ void TestMathPolynomialApproximationPass::runOnOperation() {
MathPolynomialApproximationOptions approxOptions;
approxOptions.enableAvx2 = enableAvx2;
populateMathPolynomialApproximationPatterns(patterns, approxOptions);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
namespace mlir {

View File

@ -38,7 +38,7 @@ void TestComposeSubViewPass::getDependentDialects(
void TestComposeSubViewPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
memref::populateComposeSubViewPatterns(patterns, &getContext());
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
} // namespace

View File

@ -26,9 +26,9 @@ struct TestAllSliceOpLoweringPass
SymbolTableCollection symbolTableCollection;
mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
mesh::registerAllSliceOpLoweringDialects(registry);
@ -51,9 +51,9 @@ struct TestMultiIndexOpLoweringPass
mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
symbolTableCollection);
LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
applyPatternsGreedily(getOperation(), std::move(patterns));
(void)status;
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
assert(succeeded(status) && "applyPatternsGreedily failed.");
}
void getDependentDialects(DialectRegistry &registry) const override {
mesh::registerProcessMultiIndexOpLoweringDialects(registry);

View File

@ -97,8 +97,8 @@ struct TestMeshReshardingPass
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
std::move(patterns)))) {
if (failed(applyPatternsGreedily(getOperation().getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}

View File

@ -34,7 +34,7 @@ void TestMeshSimplificationsPass::runOnOperation() {
SymbolTableCollection symbolTableCollection;
mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
[[maybe_unused]] LogicalResult status =
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
applyPatternsGreedily(getOperation(), std::move(patterns));
assert(succeeded(status) && "Rewrite patters application did not converge.");
}

View File

@ -60,7 +60,7 @@ struct TestMmaSyncF32ToTF32Patterns
RewritePatternSet patterns(&getContext());
populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
};

View File

@ -226,7 +226,7 @@ struct TestSCFPipeliningPass
options.peelEpilogue = false;
}
scf::populateSCFLoopPipeliningPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
getOperation().walk([](Operation *op) {
// Clean up the markers.
op->removeAttr(kTestPipeliningStageMarker);

View File

@ -59,7 +59,7 @@ struct TestWrapWhileLoopInZeroTripCheckPass
} else {
RewritePatternSet patterns(context);
scf::populateSCFRotateWhileLoopPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
(void)applyPatternsGreedily(func, std::move(patterns));
}
}

Some files were not shown because too many files have changed in this diff Show More