[mlir] Enable decoupling two kinds of greedy behavior. (#104649)
The greedy rewriter is used in many different flows and it has a lot of convenience (work list management, debugging actions, tracing, etc). But it combines two kinds of greedy behavior 1) how ops are matched, 2) folding wherever it can. These are independent forms of greedy and leads to inefficiency. E.g., cases where one need to create different phases in lowering and is required to applying patterns in specific order split across different passes. Using the driver one ends up needlessly retrying folding/having multiple rounds of folding attempts, where one final run would have sufficed. Of course folks can locally avoid this behavior by just building their own, but this is also a common requested feature that folks keep on working around locally in suboptimal ways. For downstream users, there should be no behavioral change. Updating from the deprecated should just be a find and replace (e.g., `find ./ -type f -exec sed -i 's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety) as the API arguments hasn't changed between the two.
This commit is contained in:
parent
412e1af19a
commit
09dfc5713d
@ -125,7 +125,7 @@ public:
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.insert<InlineElementalConversion>(context);
|
||||
|
||||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
|
||||
if (mlir::failed(mlir::applyPatternsGreedily(
|
||||
getOperation(), std::move(patterns), config))) {
|
||||
mlir::emitError(getOperation()->getLoc(),
|
||||
"failure in HLFIR elemental inlining");
|
||||
|
@ -520,8 +520,8 @@ public:
|
||||
config.enableRegionSimplification =
|
||||
mlir::GreedySimplifyRegionLevel::Disabled;
|
||||
|
||||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
|
||||
module, std::move(patterns), config))) {
|
||||
if (mlir::failed(
|
||||
mlir::applyPatternsGreedily(module, std::move(patterns), config))) {
|
||||
mlir::emitError(mlir::UnknownLoc::get(context),
|
||||
"failure in HLFIR intrinsic lowering");
|
||||
signalPassFailure();
|
||||
|
@ -1372,7 +1372,7 @@ public:
|
||||
// patterns.insert<ReductionMaskConversion<hlfir::MaxvalOp>>(context);
|
||||
// patterns.insert<ReductionMaskConversion<hlfir::MinvalOp>>(context);
|
||||
|
||||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
|
||||
if (mlir::failed(mlir::applyPatternsGreedily(
|
||||
getOperation(), std::move(patterns), config))) {
|
||||
mlir::emitError(getOperation()->getLoc(),
|
||||
"failure in HLFIR optimized bufferization");
|
||||
|
@ -491,7 +491,7 @@ public:
|
||||
patterns.insert<SumAsElementalConversion>(context);
|
||||
patterns.insert<CShiftAsElementalConversion>(context);
|
||||
|
||||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
|
||||
if (mlir::failed(mlir::applyPatternsGreedily(
|
||||
getOperation(), std::move(patterns), config))) {
|
||||
mlir::emitError(getOperation()->getLoc(),
|
||||
"failure in HLFIR intrinsic simplification");
|
||||
|
@ -39,8 +39,7 @@ struct AlgebraicSimplification
|
||||
void AlgebraicSimplification::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateMathAlgebraicSimplificationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config);
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns), config);
|
||||
}
|
||||
|
||||
std::unique_ptr<mlir::Pass> fir::createAlgebraicSimplificationPass() {
|
||||
|
@ -154,7 +154,7 @@ public:
|
||||
mlir::GreedyRewriteConfig config;
|
||||
config.enableRegionSimplification =
|
||||
mlir::GreedySimplifyRegionLevel::Disabled;
|
||||
(void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config);
|
||||
(void)applyPatternsGreedily(mod, std::move(patterns), config);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -173,8 +173,8 @@ public:
|
||||
config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps;
|
||||
|
||||
patterns.insert<CallOpRewriter>(context, *di);
|
||||
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
|
||||
mod, std::move(patterns), config))) {
|
||||
if (mlir::failed(
|
||||
mlir::applyPatternsGreedily(mod, std::move(patterns), config))) {
|
||||
mlir::emitError(mod.getLoc(),
|
||||
"error in constant globalisation optimization\n");
|
||||
signalPassFailure();
|
||||
|
@ -793,8 +793,8 @@ void StackArraysPass::runOnOperation() {
|
||||
config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
|
||||
|
||||
patterns.insert<AllocMemConversion>(&context, *candidateOps);
|
||||
if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert,
|
||||
std::move(patterns), config))) {
|
||||
if (mlir::failed(mlir::applyOpPatternsGreedily(
|
||||
opsToConvert, std::move(patterns), config))) {
|
||||
mlir::emitError(func->getLoc(), "error in stack arrays optimization\n");
|
||||
signalPassFailure();
|
||||
}
|
||||
|
@ -358,7 +358,7 @@ which point the driver finishes.
|
||||
|
||||
This driver comes in two fashions:
|
||||
|
||||
* `applyPatternsAndFoldGreedily` ("region-based driver") applies patterns to
|
||||
* `applyPatternsGreedily` ("region-based driver") applies patterns to
|
||||
all ops in a given region or a given container op (but not the container op
|
||||
itself). I.e., the worklist is initialized with all containing ops.
|
||||
* `applyOpPatternsAndFold` ("op-based driver") applies patterns to the
|
||||
|
@ -39,7 +39,7 @@ public:
|
||||
RewritePatternSet patterns(&getContext());
|
||||
patterns.add<StandaloneSwitchBarFooRewriter>(&getContext());
|
||||
FrozenRewritePatternSet patternSet(std::move(patterns));
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet)))
|
||||
if (failed(applyPatternsGreedily(getOperation(), patternSet)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -91,6 +91,13 @@ public:
|
||||
|
||||
/// An optional listener that should be notified about IR modifications.
|
||||
RewriterBase::Listener *listener = nullptr;
|
||||
|
||||
/// Whether this should fold while greedily rewriting.
|
||||
bool fold = true;
|
||||
|
||||
/// If set to "true", constants are CSE'd (even across multiple regions that
|
||||
/// are in a parent-ancestor relationship).
|
||||
bool cseConstants = true;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -104,8 +111,8 @@ public:
|
||||
/// The greedy rewrite may prematurely stop after a maximum number of
|
||||
/// iterations, which can be configured in the configuration parameter.
|
||||
///
|
||||
/// Also performs folding and simple dead-code elimination before attempting to
|
||||
/// match any of the provided patterns.
|
||||
/// Also performs simple dead-code elimination before attempting to match any of
|
||||
/// the provided patterns.
|
||||
///
|
||||
/// A region scope can be set in the configuration parameter. By default, the
|
||||
/// scope is set to the specified region. Only in-scope ops are added to the
|
||||
@ -117,10 +124,20 @@ public:
|
||||
///
|
||||
/// Note: This method does not apply patterns to the region's parent operation.
|
||||
LogicalResult
|
||||
applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr);
|
||||
/// Same as `applyPatternsAndGreedily` above with folding.
|
||||
/// FIXME: Remove this once transition to above is complieted.
|
||||
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
|
||||
inline LogicalResult
|
||||
applyPatternsAndFoldGreedily(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr);
|
||||
bool *changed = nullptr) {
|
||||
config.fold = true;
|
||||
return applyPatternsGreedily(region, patterns, config, changed);
|
||||
}
|
||||
|
||||
/// Rewrite ops nested under the given operation, which must be isolated from
|
||||
/// above, by repeatedly applying the highest benefit patterns in a greedy
|
||||
@ -129,8 +146,8 @@ applyPatternsAndFoldGreedily(Region ®ion,
|
||||
/// The greedy rewrite may prematurely stop after a maximum number of
|
||||
/// iterations, which can be configured in the configuration parameter.
|
||||
///
|
||||
/// Also performs folding and simple dead-code elimination before attempting to
|
||||
/// match any of the provided patterns.
|
||||
/// Also performs simple dead-code elimination before attempting to match any of
|
||||
/// the provided patterns.
|
||||
///
|
||||
/// This overload runs a separate greedy rewrite for each region of the
|
||||
/// specified op. A region scope can be set in the configuration parameter. By
|
||||
@ -147,23 +164,32 @@ applyPatternsAndFoldGreedily(Region ®ion,
|
||||
///
|
||||
/// Note: This method does not apply patterns to the given operation itself.
|
||||
inline LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr) {
|
||||
applyPatternsGreedily(Operation *op, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr) {
|
||||
bool anyRegionChanged = false;
|
||||
bool failed = false;
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
bool regionChanged;
|
||||
failed |=
|
||||
applyPatternsAndFoldGreedily(region, patterns, config, ®ionChanged)
|
||||
.failed();
|
||||
failed |= applyPatternsGreedily(region, patterns, config, ®ionChanged)
|
||||
.failed();
|
||||
anyRegionChanged |= regionChanged;
|
||||
}
|
||||
if (changed)
|
||||
*changed = anyRegionChanged;
|
||||
return failure(failed);
|
||||
}
|
||||
/// Same as `applyPatternsGreedily` above with folding.
|
||||
/// FIXME: Remove this once transition to above is complieted.
|
||||
LLVM_DEPRECATED("Use applyPatternsGreedily() instead", "applyPatternsGreedily")
|
||||
inline LogicalResult
|
||||
applyPatternsAndFoldGreedily(Operation *op,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr) {
|
||||
config.fold = true;
|
||||
return applyPatternsGreedily(op, patterns, config, changed);
|
||||
}
|
||||
|
||||
/// Rewrite the specified ops by repeatedly applying the highest benefit
|
||||
/// patterns in a greedy worklist driven manner until a fixpoint is reached.
|
||||
@ -171,8 +197,8 @@ applyPatternsAndFoldGreedily(Operation *op,
|
||||
/// The greedy rewrite may prematurely stop after a maximum number of
|
||||
/// iterations, which can be configured in the configuration parameter.
|
||||
///
|
||||
/// Also performs folding and simple dead-code elimination before attempting to
|
||||
/// match any of the provided patterns.
|
||||
/// Also performs simple dead-code elimination before attempting to match any of
|
||||
/// the provided patterns.
|
||||
///
|
||||
/// Newly created ops and other pre-existing ops that use results of rewritten
|
||||
/// ops or supply operands to such ops are also processed, unless such ops are
|
||||
@ -180,24 +206,36 @@ applyPatternsAndFoldGreedily(Operation *op,
|
||||
/// regardless of `strictMode`).
|
||||
///
|
||||
/// In addition to strictness, a region scope can be specified. Only ops within
|
||||
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
|
||||
/// the scope are simplified. This is similar to `applyPatternsGreedily`,
|
||||
/// where only ops within the given region/op are simplified by default. If no
|
||||
/// scope is specified, it is assumed to be the first common enclosing region of
|
||||
/// the given ops.
|
||||
///
|
||||
/// Note that ops in `ops` could be erased as result of folding, becoming dead,
|
||||
/// or via pattern rewrites. If more far reaching simplification is desired,
|
||||
/// `applyPatternsAndFoldGreedily` should be used.
|
||||
/// `applyPatternsGreedily` should be used.
|
||||
///
|
||||
/// Returns "success" if the iterative process converged (i.e., fixpoint was
|
||||
/// reached) and no more patterns can be matched. `changed` is set to "true" if
|
||||
/// the IR was modified at all. `allOpsErased` is set to "true" if all ops in
|
||||
/// `ops` were erased.
|
||||
LogicalResult
|
||||
applyOpPatternsGreedily(ArrayRef<Operation *> ops,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr, bool *allErased = nullptr);
|
||||
/// Same as `applyOpPatternsGreedily` with folding.
|
||||
/// FIXME: Remove this once transition to above is complieted.
|
||||
LLVM_DEPRECATED("Use applyOpPatternsGreedily() instead",
|
||||
"applyOpPatternsGreedily")
|
||||
inline LogicalResult
|
||||
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr, bool *allErased = nullptr);
|
||||
bool *changed = nullptr, bool *allErased = nullptr) {
|
||||
config.fold = true;
|
||||
return applyOpPatternsGreedily(ops, patterns, config, changed, allErased);
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
|
@ -289,8 +289,7 @@ MlirLogicalResult
|
||||
mlirApplyPatternsAndFoldGreedily(MlirModule op,
|
||||
MlirFrozenRewritePatternSet patterns,
|
||||
MlirGreedyRewriteDriverConfig) {
|
||||
return wrap(
|
||||
mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns)));
|
||||
return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -385,6 +385,6 @@ void ArithToAMDGPUConversionPass::runOnOperation() {
|
||||
arith::populateArithToAMDGPUConversionPatterns(
|
||||
patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
|
||||
*maybeChipset);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
@ -117,8 +117,7 @@ struct ArithToArmSMEConversionPass final
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
arith::populateArithToArmSMEConversionPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -59,8 +59,7 @@ class ConvertArmNeon2dToIntr
|
||||
RewritePatternSet patterns(context);
|
||||
populateConvertArmNeon2dToIntrPatterns(patterns);
|
||||
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -271,7 +271,7 @@ struct LowerGpuOpsToNVVMOpsPass
|
||||
{
|
||||
RewritePatternSet patterns(m.getContext());
|
||||
populateGpuRewritePatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(m, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(m, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -271,7 +271,7 @@ struct LowerGpuOpsToROCDLOpsPass
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateGpuRewritePatterns(patterns);
|
||||
arith::populateExpandBFloat16Patterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
|
||||
(void)applyPatternsGreedily(m, std::move(patterns));
|
||||
}
|
||||
|
||||
LLVMTypeConverter converter(ctx, options);
|
||||
|
@ -427,8 +427,7 @@ struct ConvertMeshToMPIPass
|
||||
ConvertProcessLinearIndexOp, ConvertProcessMultiIndexOp>(
|
||||
ctx);
|
||||
|
||||
(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns));
|
||||
(void)mlir::applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -62,7 +62,7 @@ class ConvertShapeConstraints
|
||||
RewritePatternSet patterns(context);
|
||||
populateConvertShapeConstraintsConversionPatterns(patterns);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(func, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -33,7 +33,7 @@ void ConvertVectorToArmSMEPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorToArmSMEPatterns(patterns, getContext());
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::createConvertVectorToArmSMEPass() {
|
||||
|
@ -1326,8 +1326,7 @@ struct ConvertVectorToGPUPass
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
IRRewriter rewriter(&getContext());
|
||||
|
@ -82,7 +82,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
|
||||
populateVectorInsertExtractStridedSliceTransforms(patterns);
|
||||
populateVectorStepLoweringPatterns(patterns);
|
||||
populateVectorRankReducingFMAPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
// Convert to the LLVM IR dialect.
|
||||
|
@ -1730,12 +1730,12 @@ struct ConvertVectorToSCFPass
|
||||
RewritePatternSet lowerTransferPatterns(&getContext());
|
||||
mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
|
||||
lowerTransferPatterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(lowerTransferPatterns));
|
||||
(void)applyPatternsGreedily(getOperation(),
|
||||
std::move(lowerTransferPatterns));
|
||||
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorToSCFConversionPatterns(patterns, options);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -318,8 +318,7 @@ struct ConvertVectorToXeGPUPass
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorToXeGPUConversionPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -132,7 +132,7 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
|
||||
static_cast<RewriterBase::Listener *>(rewriter.getListener());
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
// Apply the simplification pattern to a fixpoint.
|
||||
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
|
||||
if (failed(applyOpPatternsGreedily(targets, frozenPatterns, config))) {
|
||||
auto diag = emitDefiniteFailure()
|
||||
<< "affine.min/max simplification did not converge";
|
||||
return diag;
|
||||
|
@ -239,5 +239,5 @@ void AffineDataCopyGeneration::runOnOperation() {
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
|
||||
(void)applyOpPatternsGreedily(copyOps, frozenPatterns, config);
|
||||
}
|
||||
|
@ -198,8 +198,7 @@ public:
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
populateAffineExpandIndexOpsPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -79,8 +79,7 @@ public:
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
populateAffineExpandIndexOpsAsAffinePatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -111,5 +111,5 @@ void SimplifyAffineStructures::runOnOperation() {
|
||||
});
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
|
||||
(void)applyOpPatternsGreedily(opsToSimplify, frozenPatterns, config);
|
||||
}
|
||||
|
@ -318,8 +318,8 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
bool erased;
|
||||
(void)applyOpPatternsAndFold(res.getOperation(), std::move(patterns),
|
||||
config, /*changed=*/nullptr, &erased);
|
||||
(void)applyOpPatternsGreedily(res.getOperation(), std::move(patterns),
|
||||
config, /*changed=*/nullptr, &erased);
|
||||
if (!erased && !prologue)
|
||||
prologue = res;
|
||||
if (!erased)
|
||||
|
@ -425,8 +425,8 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
bool erased;
|
||||
(void)applyOpPatternsAndFold(ifOp.getOperation(), frozenPatterns, config,
|
||||
/*changed=*/nullptr, &erased);
|
||||
(void)applyOpPatternsGreedily(ifOp.getOperation(), frozenPatterns, config,
|
||||
/*changed=*/nullptr, &erased);
|
||||
if (erased) {
|
||||
if (folded)
|
||||
*folded = true;
|
||||
@ -454,7 +454,7 @@ LogicalResult mlir::affine::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
|
||||
|
||||
// Canonicalize to remove dead else blocks (happens whenever an 'if' moves up
|
||||
// a sequence of affine.fors that are all perfectly nested).
|
||||
(void)applyPatternsAndFoldGreedily(
|
||||
(void)applyPatternsGreedily(
|
||||
hoistedIfOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(),
|
||||
frozenPatterns);
|
||||
|
||||
|
@ -489,7 +489,7 @@ struct IntRangeOptimizationsPass final
|
||||
GreedyRewriteConfig config;
|
||||
config.listener = &listener;
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
@ -518,7 +518,7 @@ struct IntRangeNarrowingPass final
|
||||
config.useTopDownTraversal = false;
|
||||
config.listener = &listener;
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns), config)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -523,8 +523,7 @@ struct OuterProductFusionPass
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateOuterProductFusionPatterns(patterns);
|
||||
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -317,8 +317,7 @@ struct LegalizeVectorStorage
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateLegalizeVectorStoragePatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(patterns)))) {
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
ConversionTarget target(getContext());
|
||||
|
@ -931,7 +931,7 @@ void AsyncParallelForPass::runOnOperation() {
|
||||
[&](ImplicitLocOpBuilder builder, scf::ParallelOp op) {
|
||||
return builder.create<arith::ConstantIndexOp>(minTaskSize);
|
||||
});
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -470,8 +470,8 @@ struct BufferDeallocationSimplificationPass
|
||||
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
|
||||
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
|
||||
config)))
|
||||
if (failed(
|
||||
applyPatternsGreedily(getOperation(), std::move(patterns), config)))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -60,7 +60,7 @@ void EmptyTensorToAllocTensor::runOnOperation() {
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
populateEmptyTensorToAllocTensorPattern(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,7 @@ struct FormExpressionsPass
|
||||
RewritePatternSet patterns(context);
|
||||
populateExpressionPatterns(patterns);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(rootOp, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -227,8 +227,7 @@ struct GpuDecomposeMemrefsPass
|
||||
|
||||
populateGpuDecomposeMemrefsPatterns(patterns);
|
||||
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -630,7 +630,7 @@ class GpuEliminateBarriersPass
|
||||
auto funcOp = getOperation();
|
||||
RewritePatternSet patterns(&getContext());
|
||||
mlir::populateGpuEliminateBarriersPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +96,7 @@ void NVVMOptimizeForTarget::runOnOperation() {
|
||||
MLIRContext *ctx = getOperation()->getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<ExpandDivF16>(ctx);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -3511,7 +3511,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
|
||||
TrackingListener listener(state, *this);
|
||||
GreedyRewriteConfig config;
|
||||
config.listener = &listener;
|
||||
if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns), config)))
|
||||
if (failed(applyPatternsGreedily(target, std::move(patterns), config)))
|
||||
return emitDefaultDefiniteFailure(target);
|
||||
|
||||
results.push_back(target);
|
||||
|
@ -301,7 +301,7 @@ struct LinalgBlockPackMatmul
|
||||
};
|
||||
|
||||
linalg::populateBlockPackMatmulPatterns(patterns, controlFn);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -563,8 +563,7 @@ struct LinalgDetensorize
|
||||
|
||||
RewritePatternSet canonPatterns(context);
|
||||
tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(),
|
||||
std::move(canonPatterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
|
||||
signalPassFailure();
|
||||
|
||||
// Get rid of the dummy entry block we created in the beginning to work
|
||||
|
@ -831,7 +831,7 @@ struct LinalgFoldUnitExtentDimsPass
|
||||
}
|
||||
linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
|
||||
populateMoveInitOperandsToInputPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
|
||||
(void)applyPatternsGreedily(op, std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -2206,7 +2206,7 @@ struct LinalgElementwiseOpFusionPass
|
||||
// Use TopDownTraversal for compile time reasons
|
||||
GreedyRewriteConfig grc;
|
||||
grc.useTopDownTraversal = true;
|
||||
(void)applyPatternsAndFoldGreedily(op, std::move(patterns), grc);
|
||||
(void)applyPatternsGreedily(op, std::move(patterns), grc);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -89,7 +89,7 @@ struct LinalgGeneralizeNamedOpsPass
|
||||
void LinalgGeneralizeNamedOpsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateLinalgNamedOpsGeneralizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
|
||||
|
@ -113,7 +113,7 @@ struct LinalgInlineScalarOperandsPass
|
||||
MLIRContext &ctx = getContext();
|
||||
RewritePatternSet patterns(&ctx);
|
||||
populateInlineConstantOperandsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
|
||||
(void)applyPatternsGreedily(op, std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -321,7 +321,7 @@ static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
|
||||
affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
|
||||
patterns.add<FoldAffineOp>(context);
|
||||
// Just apply the patterns greedily.
|
||||
(void)applyPatternsAndFoldGreedily(enclosingOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(enclosingOp, std::move(patterns));
|
||||
}
|
||||
|
||||
struct LowerToAffineLoops
|
||||
|
@ -152,7 +152,7 @@ struct LinalgNamedOpConversionPass
|
||||
Operation *op = getOperation();
|
||||
RewritePatternSet patterns(op->getContext());
|
||||
populateLinalgNamedOpConversionPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -349,7 +349,7 @@ void LinalgSpecializeGenericOpsPass::runOnOperation() {
|
||||
populateLinalgGenericOpsSpecializationPatterns(patterns);
|
||||
populateDecomposeProjectedPermutationPatterns(patterns);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -66,8 +66,7 @@ struct MathUpliftToFMA final
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateUpliftToFMAPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -1223,7 +1223,7 @@ struct ExpandStridedMetadataPass final
|
||||
void ExpandStridedMetadataPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
memref::populateExpandStridedMetadataPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> memref::createExpandStridedMetadataPass() {
|
||||
|
@ -857,7 +857,7 @@ struct FoldMemRefAliasOpsPass final
|
||||
void FoldMemRefAliasOpsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
memref::populateFoldMemRefAliasOpPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> memref::createFoldMemRefAliasOpsPass() {
|
||||
|
@ -195,7 +195,7 @@ void memref::populateResolveShapedTypeResultDimsPatterns(
|
||||
void ResolveRankedShapeTypeResultDimsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
@ -203,7 +203,7 @@ void ResolveShapedTypeResultDimsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
|
||||
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,7 @@ struct ForToWhileLoop : public impl::SCFForToWhileLoopBase<ForToWhileLoop> {
|
||||
MLIRContext *ctx = parentOp->getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<ForLoopLoweringPattern>(ctx);
|
||||
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(parentOp, std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -167,7 +167,7 @@ struct SCFForLoopCanonicalization
|
||||
MLIRContext *ctx = parentOp->getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
scf::populateSCFForLoopCanonicalizationPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(parentOp, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -331,7 +331,7 @@ struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
|
||||
MLIRContext *ctx = parentOp->getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<ForLoopPeelingPattern>(ctx, peelFront, skipPartial);
|
||||
(void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(parentOp, std::move(patterns));
|
||||
|
||||
// Drop the markers.
|
||||
parentOp->walk([](Operation *op) {
|
||||
|
@ -1430,7 +1430,7 @@ SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
|
||||
GreedyRewriteConfig config;
|
||||
config.listener = this;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
return applyOpPatternsAndFold(ops, patterns.value(), config);
|
||||
return applyOpPatternsGreedily(ops, patterns.value(), config);
|
||||
}
|
||||
|
||||
void SliceTrackingListener::notifyOperationInserted(
|
||||
|
@ -29,8 +29,7 @@ public:
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
spirv::populateSPIRVGLCanonicalizationPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -1354,7 +1354,7 @@ LogicalResult mlir::spirv::unrollVectorsInSignatures(Operation *op) {
|
||||
// looking for newly created func ops.
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
return applyPatternsAndFoldGreedily(op, std::move(patterns), config);
|
||||
return applyPatternsGreedily(op, std::move(patterns), config);
|
||||
}
|
||||
|
||||
LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
|
||||
@ -1366,7 +1366,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
|
||||
auto options = vector::UnrollVectorOptions().setNativeShapeFn(
|
||||
[](auto op) { return mlir::spirv::getNativeVectorShape(op); });
|
||||
populateVectorUnrollPatterns(patterns, options);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -1378,7 +1378,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
|
||||
vector::VectorTransposeLowering::EltWise);
|
||||
vector::populateVectorTransposeLoweringPatterns(patterns, options);
|
||||
vector::populateVectorShapeCastLoweringPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return failure();
|
||||
}
|
||||
|
||||
@ -1403,7 +1403,7 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
|
||||
vector::BroadcastOp::getCanonicalizationPatterns(patterns, context);
|
||||
vector::ShapeCastOp::getCanonicalizationPatterns(patterns, context);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(patterns))))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
|
@ -236,8 +236,7 @@ struct WebGPUPreparePass final
|
||||
populateSPIRVExpandExtendedMultiplicationPatterns(patterns);
|
||||
populateSPIRVExpandNonFiniteArithmeticPatterns(patterns);
|
||||
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -207,7 +207,7 @@ void OutlineShapeComputationPass::runOnOperation() {
|
||||
MLIRContext *context = funcOp.getContext();
|
||||
RewritePatternSet prevPatterns(context);
|
||||
prevPatterns.insert<TensorDimOpRewriter>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(prevPatterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp, std::move(prevPatterns))))
|
||||
return signalPassFailure();
|
||||
|
||||
// initialize class member `onlyUsedByWithShapes`
|
||||
@ -254,7 +254,7 @@ void OutlineShapeComputationPass::runOnOperation() {
|
||||
}
|
||||
|
||||
// Apply patterns, note this also performs DCE.
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp, {})))
|
||||
if (failed(applyPatternsGreedily(funcOp, {})))
|
||||
return signalPassFailure();
|
||||
});
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ class RemoveShapeConstraintsPass
|
||||
RewritePatternSet patterns(&ctx);
|
||||
populateRemoveShapeConstraintsPatterns(patterns);
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -57,7 +57,7 @@ struct SparseAssembler : public impl::SparseAssemblerBase<SparseAssembler> {
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparseAssembler(patterns, directOut);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -73,7 +73,7 @@ struct SparseReinterpretMap
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparseReinterpretMap(patterns, scope);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -87,7 +87,7 @@ struct PreSparsificationRewritePass
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populatePreSparsificationRewriting(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -110,7 +110,7 @@ struct SparsificationPass
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparsificationPatterns(patterns, options);
|
||||
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -122,7 +122,7 @@ struct StageSparseOperationsPass
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateStageSparseOperationsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -141,7 +141,7 @@ struct LowerSparseOpsToForeachPass
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
|
||||
enableConvert);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -154,7 +154,7 @@ struct LowerForeachToSCFPass
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateLowerForeachToSCFPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -329,7 +329,7 @@ struct SparseBufferRewritePass
|
||||
auto *ctx = &getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateSparseBufferRewriting(patterns, enableBufferInitialization);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -351,7 +351,7 @@ struct SparseVectorizationPass
|
||||
populateSparseVectorizationPatterns(
|
||||
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
|
||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -371,7 +371,7 @@ struct SparseGPUCodegenPass
|
||||
populateSparseGPULibgenPatterns(patterns, enableRuntimeLibrary);
|
||||
else
|
||||
populateSparseGPUCodegenPatterns(patterns, numThreads);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -277,7 +277,7 @@ struct FoldTensorSubsetOpsPass final
|
||||
void FoldTensorSubsetOpsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
tensor::populateFoldTensorSubsetOpPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> tensor::createFoldTensorSubsetOpsPass() {
|
||||
|
@ -60,7 +60,7 @@ struct TosaLayerwiseConstantFoldPass
|
||||
aggressiveReduceConstant);
|
||||
populateTosaOpsCanonicalizationPatterns(ctx, patterns);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
|
||||
if (applyPatternsGreedily(func, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -246,7 +246,7 @@ public:
|
||||
patterns.add<ConvertTosaOp<tosa::LogicalXorOp>>(ctx);
|
||||
patterns.add<ConvertTosaOp<tosa::SelectOp>>(ctx);
|
||||
patterns.add<ConvertTosaOp<tosa::PowOp>>(ctx);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
(void)applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -42,7 +42,7 @@ struct TosaOptionalDecompositions
|
||||
mlir::tosa::populateTosaDecomposeTransposeConv(ctx, patterns);
|
||||
mlir::tosa::populateTosaDecomposeDepthwise(ctx, patterns);
|
||||
|
||||
if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
|
||||
if (applyPatternsGreedily(func, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -417,7 +417,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
|
||||
if (target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
|
||||
// Op is isolated from above. Apply patterns and also perform region
|
||||
// simplification.
|
||||
result = applyPatternsAndFoldGreedily(target, frozenPatterns, config);
|
||||
result = applyPatternsGreedily(target, frozenPatterns, config);
|
||||
} else {
|
||||
// Manually gather list of ops because the other
|
||||
// GreedyPatternRewriteDriver overloads only accepts ops that are isolated
|
||||
@ -429,7 +429,7 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
|
||||
if (target != nestedOp)
|
||||
ops.push_back(nestedOp);
|
||||
});
|
||||
result = applyOpPatternsAndFold(ops, frozenPatterns, config);
|
||||
result = applyOpPatternsGreedily(ops, frozenPatterns, config);
|
||||
}
|
||||
|
||||
// A failure typically indicates that the pattern application did not
|
||||
|
@ -286,7 +286,7 @@ struct LowerVectorMaskPass
|
||||
populateVectorMaskLoweringPatternsForSideEffectingOps(loweringPatterns);
|
||||
MaskOp::getCanonicalizationPatterns(loweringPatterns, context);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -486,7 +486,7 @@ struct LowerVectorMultiReductionPass
|
||||
populateVectorMultiReductionLoweringPatterns(loweringPatterns,
|
||||
this->loweringStrategy);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(loweringPatterns))))
|
||||
if (failed(applyPatternsGreedily(op, std::move(loweringPatterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -78,5 +78,5 @@ struct XeGPUFoldAliasOpsPass final
|
||||
void XeGPUFoldAliasOpsPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
xegpu::populateXeGPUFoldAliasOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ static void applyPatterns(Region ®ion,
|
||||
// because we don't have expectation this reduction will be success or not.
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
(void)applyOpPatternsAndFold(op, patterns, config);
|
||||
(void)applyOpPatternsGreedily(op, patterns, config);
|
||||
}
|
||||
|
||||
if (eraseOpNotInRange)
|
||||
|
@ -60,7 +60,7 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
|
||||
}
|
||||
void runOnOperation() override {
|
||||
LogicalResult converged =
|
||||
applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
|
||||
applyPatternsGreedily(getOperation(), *patterns, config);
|
||||
// Canonicalization is best-effort. Non-convergence is not a pass failure.
|
||||
if (testConvergence && failed(converged))
|
||||
signalPassFailure();
|
||||
|
@ -6,7 +6,7 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements mlir::applyPatternsAndFoldGreedily.
|
||||
// This file implements mlir::applyPatternsGreedily.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -488,7 +488,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
|
||||
// infinite folding loop, as every constant op would be folded to an
|
||||
// Attribute and then immediately be rematerialized as a constant op, which
|
||||
// is then put on the worklist.
|
||||
if (!op->hasTrait<OpTrait::ConstantLike>()) {
|
||||
if (config.fold && !op->hasTrait<OpTrait::ConstantLike>()) {
|
||||
SmallVector<OpFoldResult> foldResults;
|
||||
if (succeeded(op->fold(foldResults))) {
|
||||
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
|
||||
@ -852,13 +852,13 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
|
||||
if (!config.useTopDownTraversal) {
|
||||
// Add operations to the worklist in postorder.
|
||||
region.walk([&](Operation *op) {
|
||||
if (!insertKnownConstant(op))
|
||||
if (!config.cseConstants || !insertKnownConstant(op))
|
||||
addToWorklist(op);
|
||||
});
|
||||
} else {
|
||||
// Add all nested operations to the worklist in preorder.
|
||||
region.walk<WalkOrder::PreOrder>([&](Operation *op) {
|
||||
if (!insertKnownConstant(op)) {
|
||||
if (!config.cseConstants || !insertKnownConstant(op)) {
|
||||
addToWorklist(op);
|
||||
return WalkResult::advance();
|
||||
}
|
||||
@ -894,9 +894,9 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
mlir::applyPatternsAndFoldGreedily(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config, bool *changed) {
|
||||
mlir::applyPatternsGreedily(Region ®ion,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config, bool *changed) {
|
||||
// The top-level operation must be known to be isolated from above to
|
||||
// prevent performing canonicalizations on operations defined at or above
|
||||
// the region containing 'op'.
|
||||
@ -1012,7 +1012,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
|
||||
return region;
|
||||
}
|
||||
|
||||
LogicalResult mlir::applyOpPatternsAndFold(
|
||||
LogicalResult mlir::applyOpPatternsGreedily(
|
||||
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config, bool *changed, bool *allErased) {
|
||||
if (ops.empty()) {
|
||||
|
@ -296,7 +296,7 @@ OneToNConversionPattern::matchAndRewrite(Operation *op,
|
||||
namespace mlir {
|
||||
|
||||
// This function applies the provided patterns using
|
||||
// `applyPatternsAndFoldGreedily` and then replaces all newly inserted
|
||||
// `applyPatternsGreedily` and then replaces all newly inserted
|
||||
// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
|
||||
// from target to source types inserted by a `OneToNConversionPattern` normally
|
||||
// fold away with the "forward" casts from source to target types inserted by
|
||||
@ -317,7 +317,7 @@ applyPartialOneToNConversion(Operation *op, TypeConverter &typeConverter,
|
||||
#endif // NDEBUG
|
||||
|
||||
// Apply provided conversion patterns.
|
||||
if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
|
||||
if (failed(applyPatternsGreedily(op, patterns))) {
|
||||
emitError(op->getLoc()) << "failed to apply conversion patterns";
|
||||
return failure();
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-greedy-patterns='cse-constants=false' %s | FileCheck %s --check-prefix=NOCSE
|
||||
// RUN: mlir-opt -test-greedy-patterns='fold=false' %s | FileCheck %s --check-prefix=NOFOLD
|
||||
|
||||
func.func @foo() -> i32 {
|
||||
%c42 = arith.constant 42 : i32
|
||||
@ -25,7 +27,8 @@ func.func @test_fold_before_previously_folded_op() -> (i32, i32) {
|
||||
}
|
||||
|
||||
func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
|
||||
// Test that we don't reorder existing constants during folding if it isn't necessary.
|
||||
// Test that we don't reorder existing constants during folding if it isn't
|
||||
// necessary.
|
||||
// CHECK: %[[CST:.+]] = arith.constant 1
|
||||
// CHECK-NEXT: %[[CST:.+]] = arith.constant 2
|
||||
// CHECK-NEXT: %[[CST:.+]] = arith.constant 3
|
||||
@ -34,3 +37,46 @@ func.func @test_dont_reorder_constants() -> (i32, i32, i32) {
|
||||
%2 = arith.constant 3 : i32
|
||||
return %0, %1, %2 : i32, i32, i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: test_fold_nofold_nocse
|
||||
// NOCSE-LABEL: test_fold_nofold_nocse
|
||||
// NOFOLD-LABEL: test_fold_nofold_nocse
|
||||
func.func @test_fold_nofold_nocse() -> (i32, i32, i32, i32, i32, i32) {
|
||||
// Test either not folding or deduping constants.
|
||||
|
||||
// Testing folding. There should be only 4 constants here.
|
||||
// CHECK-NOT: arith.constant
|
||||
// CHECK-DAG: %[[CST:.+]] = arith.constant 0
|
||||
// CHECK-DAG: %[[CST:.+]] = arith.constant 1
|
||||
// CHECK-DAG: %[[CST:.+]] = arith.constant 2
|
||||
// CHECK-DAG: %[[CST:.+]] = arith.constant 3
|
||||
// CHECK-NOT: arith.constant
|
||||
// CHECK-NEXT: return
|
||||
|
||||
// Testing not-CSE'ing. In this case we have the 3 original constants and 3
|
||||
// produced by folding.
|
||||
// NOCSE-DAG: arith.constant 0 : i32
|
||||
// NOCSE-DAG: arith.constant 1 : i32
|
||||
// NOCSE-DAG: arith.constant 2 : i32
|
||||
// NOCSE-DAG: arith.constant 1 : i32
|
||||
// NOCSE-DAG: arith.constant 2 : i32
|
||||
// NOCSE-DAG: arith.constant 3 : i32
|
||||
// NOCSE-NEXT: return
|
||||
|
||||
// Testing not folding. In this case we just have the original constants.
|
||||
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 0
|
||||
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 1
|
||||
// NOFOLD-DAG: %[[CST:.+]] = arith.constant 2
|
||||
// NOFOLD: arith.addi
|
||||
// NOFOLD: arith.addi
|
||||
// NOFOLD: arith.addi
|
||||
|
||||
%c0 = arith.constant 0 : i32
|
||||
%c1 = arith.constant 1 : i32
|
||||
%c2 = arith.constant 2 : i32
|
||||
%0 = arith.addi %c0, %c1 : i32
|
||||
%1 = arith.addi %0, %c1 : i32
|
||||
%2 = arith.addi %c2, %c1 : i32
|
||||
return %0, %1, %2, %c0, %c1, %c2 : i32, i32, i32, i32, i32, i32
|
||||
}
|
||||
|
||||
|
@ -248,7 +248,7 @@ struct TestMathToVCIX
|
||||
RewritePatternSet patterns(ctx);
|
||||
patterns.add<MathCosToVCIX, MathSinToVCIX, MathTanToVCIX, MathLogToVCIX>(
|
||||
ctx);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -41,7 +41,7 @@ struct TestVectorReductionToSPIRVDotProd
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateVectorReductionToSPIRVDotProductPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -136,7 +136,7 @@ void TestAffineDataCopy::runOnOperation() {
|
||||
}
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
|
||||
(void)applyOpPatternsGreedily(copyOps, std::move(patterns), config);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
@ -47,7 +47,7 @@ void TestLowerToArmNeon::runOnOperation() {
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
populateLowerContractionToSMMLAPatternPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ struct TestGpuRewritePass
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateGpuRewritePatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
@ -85,7 +85,7 @@ struct TestGpuSubgroupReduceLoweringPass
|
||||
patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
|
||||
}
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
@ -34,8 +34,7 @@ struct TestDataLayoutPropagationPass
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateDataLayoutPropagationPatterns(
|
||||
patterns, [](OpOperand *opOperand) { return true; });
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -43,8 +43,8 @@ struct TestLinalgDecomposeOps
|
||||
RewritePatternSet decompositionPatterns(context);
|
||||
linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
|
||||
removeDeadArgsAndResults);
|
||||
if (failed(applyPatternsAndFoldGreedily(
|
||||
getOperation(), std::move(decompositionPatterns)))) {
|
||||
if (failed(applyPatternsGreedily(getOperation(),
|
||||
std::move(decompositionPatterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -155,8 +155,8 @@ struct TestLinalgElementwiseFusion
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
auto controlFn = [](OpOperand *operand) { return true; };
|
||||
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -166,8 +166,8 @@ struct TestLinalgElementwiseFusion
|
||||
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
|
||||
setFusedOpOperandLimit<4>);
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -176,8 +176,8 @@ struct TestLinalgElementwiseFusion
|
||||
RewritePatternSet fusionPatterns(context);
|
||||
linalg::populateFoldReshapeOpsByExpansionPatterns(
|
||||
fusionPatterns, [](OpOperand * /*fusedOperand*/) { return true; });
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -212,8 +212,8 @@ struct TestLinalgElementwiseFusion
|
||||
|
||||
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
|
||||
controlReshapeFusionFn);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(),
|
||||
std::move(fusionPatterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -222,8 +222,7 @@ struct TestLinalgElementwiseFusion
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateFoldReshapeOpsByCollapsingPatterns(
|
||||
patterns, [](OpOperand * /*fusedOperand */) { return true; });
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -239,8 +238,7 @@ struct TestLinalgElementwiseFusion
|
||||
return true;
|
||||
};
|
||||
linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -248,8 +246,7 @@ struct TestLinalgElementwiseFusion
|
||||
if (fuseMultiUseProducer) {
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.insert<TestMultiUseProducerFusion>(context);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
@ -265,8 +262,7 @@ struct TestLinalgElementwiseFusion
|
||||
};
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateCollapseDimensions(patterns, collapseFn);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ struct TestLinalgGreedyFusion
|
||||
pm.addPass(createCanonicalizerPass());
|
||||
pm.addPass(createCSEPass());
|
||||
do {
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), frozenPatterns);
|
||||
(void)applyPatternsGreedily(getOperation(), frozenPatterns);
|
||||
if (failed(runPipeline(pm, getOperation())))
|
||||
this->signalPassFailure();
|
||||
} while (succeeded(fuseLinalgOpsGreedily(getOperation())));
|
||||
|
@ -49,8 +49,7 @@ struct TestLinalgRankReduceContractionOps
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateContractionOpRankReducingPatterns(patterns);
|
||||
if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(),
|
||||
std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(funcOp.getBody(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
return;
|
||||
}
|
||||
|
@ -147,14 +147,14 @@ static void applyPatterns(func::FuncOp funcOp) {
|
||||
//===--------------------------------------------------------------------===//
|
||||
patterns.add<CopyVectorizationPattern>(ctx);
|
||||
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet forwardPattern(funcOp.getContext());
|
||||
forwardPattern.add<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
|
||||
forwardPattern.add<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(forwardPattern));
|
||||
}
|
||||
|
||||
static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
|
||||
@ -163,68 +163,68 @@ static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
|
||||
patterns.add<CopyVectorizationPattern>(ctx);
|
||||
populatePadOpVectorizationPatterns(patterns);
|
||||
populateConvolutionVectorizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyDecomposePadPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
patterns.add<DecomposePadOpPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyDecomposeTensorPackPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
patterns.add<DecomposeOuterUnitDimsPackOpPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyDecomposeTensorUnPackPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyExtractSliceOfPadTensorSwapPattern(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
patterns.add<ExtractSliceOfPadTensorSwapPattern>(funcOp.getContext());
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyBubbleUpExtractSliceOpPattern(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateBubbleUpExtractSliceOpPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applySwapExtractSliceWithFillPattern(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateSwapExtractSliceWithFillPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateEraseUnusedOperandsAndResultsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateEraseUnnecessaryInputsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyWinogradConv2D(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
|
||||
populateWinogradConv2DPatterns(patterns, /*m=*/2, /*r=*/5);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyDecomposeWinogradOps(func::FuncOp funcOp) {
|
||||
RewritePatternSet patterns(funcOp.getContext());
|
||||
populateDecomposeWinogradOpsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
/// Apply transformations specified as patterns.
|
||||
|
@ -36,8 +36,7 @@ struct TestPadFusionPass
|
||||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
linalg::populateFuseTensorPadWithProducerLinalgOpPatterns(patterns);
|
||||
if (failed(
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
|
||||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
@ -40,7 +40,7 @@ struct TestMathAlgebraicSimplificationPass
|
||||
void TestMathAlgebraicSimplificationPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateMathAlgebraicSimplificationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
@ -53,7 +53,7 @@ void TestExpandMathPass::runOnOperation() {
|
||||
populateExpandRoundFPattern(patterns);
|
||||
populateExpandRoundEvenPattern(patterns);
|
||||
populateExpandRsqrtPattern(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
@ -59,7 +59,7 @@ void TestMathPolynomialApproximationPass::runOnOperation() {
|
||||
MathPolynomialApproximationOptions approxOptions;
|
||||
approxOptions.enableAvx2 = enableAvx2;
|
||||
populateMathPolynomialApproximationPatterns(patterns, approxOptions);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
@ -38,7 +38,7 @@ void TestComposeSubViewPass::getDependentDialects(
|
||||
void TestComposeSubViewPass::runOnOperation() {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
memref::populateComposeSubViewPatterns(patterns, &getContext());
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
|
@ -26,9 +26,9 @@ struct TestAllSliceOpLoweringPass
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
mesh::populateAllSliceOpLoweringPatterns(patterns, symbolTableCollection);
|
||||
LogicalResult status =
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
(void)status;
|
||||
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
|
||||
assert(succeeded(status) && "applyPatternsGreedily failed.");
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
mesh::registerAllSliceOpLoweringDialects(registry);
|
||||
@ -51,9 +51,9 @@ struct TestMultiIndexOpLoweringPass
|
||||
mesh::populateProcessMultiIndexOpLoweringPatterns(patterns,
|
||||
symbolTableCollection);
|
||||
LogicalResult status =
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
(void)status;
|
||||
assert(succeeded(status) && "applyPatternsAndFoldGreedily failed.");
|
||||
assert(succeeded(status) && "applyPatternsGreedily failed.");
|
||||
}
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
mesh::registerProcessMultiIndexOpLoweringDialects(registry);
|
||||
|
@ -97,8 +97,8 @@ struct TestMeshReshardingPass
|
||||
void runOnOperation() override {
|
||||
RewritePatternSet patterns(&getContext());
|
||||
patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
|
||||
if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
|
||||
std::move(patterns)))) {
|
||||
if (failed(applyPatternsGreedily(getOperation().getOperation(),
|
||||
std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -34,7 +34,7 @@ void TestMeshSimplificationsPass::runOnOperation() {
|
||||
SymbolTableCollection symbolTableCollection;
|
||||
mesh::populateSimplificationPatterns(patterns, symbolTableCollection);
|
||||
[[maybe_unused]] LogicalResult status =
|
||||
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
assert(succeeded(status) && "Rewrite patters application did not converge.");
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ struct TestMmaSyncF32ToTF32Patterns
|
||||
RewritePatternSet patterns(&getContext());
|
||||
|
||||
populateMmaSyncF32ToTF32Patterns(patterns, tf32Precision);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -226,7 +226,7 @@ struct TestSCFPipeliningPass
|
||||
options.peelEpilogue = false;
|
||||
}
|
||||
scf::populateSCFLoopPipeliningPatterns(patterns, options);
|
||||
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
|
||||
getOperation().walk([](Operation *op) {
|
||||
// Clean up the markers.
|
||||
op->removeAttr(kTestPipeliningStageMarker);
|
||||
|
@ -59,7 +59,7 @@ struct TestWrapWhileLoopInZeroTripCheckPass
|
||||
} else {
|
||||
RewritePatternSet patterns(context);
|
||||
scf::populateSCFRotateWhileLoopPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
|
||||
(void)applyPatternsGreedily(func, std::move(patterns));
|
||||
}
|
||||
}
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user