[mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver

`strictMode` is moved to GreedyRewriteConfig to simplify the API and state of rewriter classes. The region-based GreedyPatternRewriteDriver now also supports strict mode.

MultiOpPatternRewriteDriver becomes simpler: fewer method must be overridden.

Differential Revision: https://reviews.llvm.org/D142623
This commit is contained in:
Matthias Springer 2023-01-27 15:44:12 +01:00
parent 72121a20cd
commit 6bdecbcb99
10 changed files with 80 additions and 79 deletions

View File

@ -63,6 +63,18 @@ public:
/// Only ops within the scope are added to the worklist. If no scope is
/// specified, the closest enclosing region is used as a scope.
Region *scope = nullptr;
/// Strict mode can restrict the ops that are added to the worklist during
/// the rewrite.
///
/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing ops (that
/// were on the worklist at the very beginning) and newly created ops are
/// enqueued. All other ops are excluded.
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
/// were on the worklist at the very beginning) enqueued. All other ops are
/// excluded.
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
};
//===----------------------------------------------------------------------===//
@ -105,14 +117,8 @@ inline LogicalResult applyPatternsAndFoldGreedily(
///
/// Newly created ops and other pre-existing ops that use results of rewritten
/// ops or supply operands to such ops are simplified, unless such ops are
/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
/// of `strictMode`).
///
/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly
/// created ops are simplified. All other ops are excluded.
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
/// simplified. All other ops are excluded.
/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
/// regardless of `strictMode`).
///
/// In addition to strictness, a region scope can be specified. Only ops within
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
@ -130,23 +136,17 @@ inline LogicalResult applyPatternsAndFoldGreedily(
LogicalResult
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
const FrozenRewritePatternSet &patterns,
GreedyRewriteStrictness strictMode,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *changed = nullptr, bool *allErased = nullptr);
/// Applies the specified patterns on `op` alone while also trying to fold it,
/// by selecting the highest benefits patterns in a greedy manner. Returns
/// success if no more patterns can be matched. `erased` is set to true if `op`
/// was folded away or erased as a result of becoming dead.
///
/// Returns success if the iterative process converged and no more patterns can
/// be matched.
/// Applies the specified patterns on `op` while also trying to fold it.
/// This function is a shortcut for the ArrayRef<Operation *> overload and
/// behaves the same way.
inline LogicalResult
applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
GreedyRewriteConfig config = GreedyRewriteConfig(),
bool *erased = nullptr) {
return applyOpPatternsAndFold(ArrayRef(op), patterns,
GreedyRewriteStrictness::ExistingOps, config,
return applyOpPatternsAndFold(ArrayRef(op), patterns, config,
/*changed=*/nullptr, erased);
}

View File

@ -130,10 +130,10 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
// Apply the simplification pattern to a fixpoint.
if (failed(
applyOpPatternsAndFold(targets, frozenPatterns,
GreedyRewriteStrictness::ExistingAndNewOps))) {
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
auto diag = emitDefiniteFailure()
<< "affine.min/max simplification did not converge";
return diag;

View File

@ -239,6 +239,7 @@ void AffineDataCopyGeneration::runOnOperation() {
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
(void)applyOpPatternsAndFold(copyOps, frozenPatterns,
GreedyRewriteStrictness::ExistingAndNewOps);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
}

View File

@ -105,6 +105,7 @@ void SimplifyAffineStructures::runOnOperation() {
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
opsToSimplify.push_back(op);
});
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
GreedyRewriteStrictness::ExistingAndNewOps);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
}

View File

@ -321,9 +321,10 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
// Simplify/canonicalize the affine.for.
RewritePatternSet patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsAndFold(res, std::move(patterns),
GreedyRewriteConfig(), &erased);
(void)applyOpPatternsAndFold(res, std::move(patterns), config, &erased);
if (!erased && !prologue)
prologue = res;
if (!erased)

View File

@ -413,10 +413,11 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// in which case we return with `folded` being set.
RewritePatternSet patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
&erased);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
bool erased;
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, config, &erased);
if (erased) {
if (folded)
*folded = true;

View File

@ -60,10 +60,13 @@ static void applyPatterns(Region &region,
// matching in above iteration. Besides, erase op not-in-range may end up in
// invalid module, so `applyOpPatternsAndFold` should come before that
// transform.
for (Operation *op : opsInRange)
for (Operation *op : opsInRange) {
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
// because we don't have expectation this reduction will be success or not.
(void)applyOpPatternsAndFold(op, patterns);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
(void)applyOpPatternsAndFold(op, patterns, config);
}
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {

View File

@ -59,7 +59,7 @@ public:
protected:
/// Add the given operation to the worklist.
virtual void addSingleOpToWorklist(Operation *op);
void addSingleOpToWorklist(Operation *op);
// Implement the hook for inserting operations, and make sure that newly
// inserted ops are added to the worklist for processing.
@ -102,6 +102,12 @@ protected:
/// Configuration information for how to simplify.
const GreedyRewriteConfig config;
/// The list of ops we are restricting our rewrites to. These include the
/// supplied set of ops as well as new ops created while rewriting those ops
/// depending on `strictMode`. This set is not maintained when
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
private:
#ifndef NDEBUG
/// A logger used to emit information during the application process.
@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region &region) && {
return false;
};
// Populate strict mode ops.
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
}
bool changed = false;
int64_t iteration = 0;
do {
@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
}
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op)) {
// Check to see if the worklist already contains this op.
if (worklistMap.count(op))
return;
worklistMap[op] = worklist.size();
worklist.push_back(op);
worklistMap[op] = worklist.size();
worklist.push_back(op);
}
}
Operation *GreedyPatternRewriteDriver::popFromWorklist() {
@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
<< ")\n";
});
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
addToWorklist(op);
}
@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
removeFromWorklist(operation);
folder.notifyRemoval(operation);
});
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
const GreedyRewriteConfig &config,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
: GreedyPatternRewriteDriver(ctx, patterns, config),
strictMode(strictMode), survivingOps(survivingOps) {}
survivingOps(survivingOps) {}
/// Performs the specified rewrites on `ops` while also trying to fold these
/// ops. `strictMode` controls which other ops are simplified. Only ops
@ -476,38 +496,13 @@ public:
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
bool *changed = nullptr) &&;
protected:
void addSingleOpToWorklist(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op))
GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
}
private:
void notifyOperationInserted(Operation *op) override {
if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
strictModeFilteredOps.insert(op);
GreedyPatternRewriteDriver::notifyOperationInserted(op);
}
void notifyOperationRemoved(Operation *op) override {
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
if (survivingOps)
survivingOps->erase(op);
if (strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.erase(op);
}
/// `strictMode` control which ops are added to the worklist during
/// simplification.
const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
/// The list of ops we are restricting our rewrites to. These include the
/// supplied set of ops as well as new ops created while rewriting those ops
/// depending on `strictMode`. This set is not maintained when `strictMode`
/// is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
/// An optional set of ops that survived the rewrite. This set is populated
/// at the beginning of `simplifyLocally` with the inititally provided list
/// of ops.
@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
survivingOps->insert(ops.begin(), ops.end());
}
if (strictMode != GreedyRewriteStrictness::AnyOp) {
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
strictModeFilteredOps.clear();
strictModeFilteredOps.insert(ops.begin(), ops.end());
}
@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
if (op == nullptr)
continue;
assert((strictMode == GreedyRewriteStrictness::AnyOp ||
assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
strictModeFilteredOps.contains(op)) &&
"unexpected op was inserted under strict mode");
@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
LogicalResult mlir::applyOpPatternsAndFold(
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
bool *changed, bool *allErased) {
GreedyRewriteConfig config, bool *changed, bool *allErased) {
if (ops.empty()) {
if (changed)
*changed = false;
@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold(
// Start the pattern driver.
llvm::SmallDenseSet<Operation *, 4> surviving;
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
strictMode, config,
allErased ? &surviving : nullptr);
config, allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
if (allErased)
*allErased = surviving.empty();

View File

@ -132,8 +132,9 @@ void TestAffineDataCopy::runOnOperation() {
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
}
}
(void)applyOpPatternsAndFold(copyOps, std::move(patterns),
GreedyRewriteStrictness::ExistingAndNewOps);
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
(void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
}
namespace mlir {

View File

@ -266,13 +266,13 @@ public:
}
});
GreedyRewriteStrictness mode;
GreedyRewriteConfig config;
if (strictMode == "AnyOp") {
mode = GreedyRewriteStrictness::AnyOp;
config.strictMode = GreedyRewriteStrictness::AnyOp;
} else if (strictMode == "ExistingAndNewOps") {
mode = GreedyRewriteStrictness::ExistingAndNewOps;
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
} else if (strictMode == "ExistingOps") {
mode = GreedyRewriteStrictness::ExistingOps;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
} else {
llvm_unreachable("invalid strictness option");
}
@ -282,8 +282,8 @@ public:
// operation will trigger the assertion while processing.
bool changed = false;
bool allErased = false;
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
GreedyRewriteConfig(), &changed, &allErased);
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
&changed, &allErased);
Builder b(ctx);
getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
getOperation()->setAttr("pattern_driver_all_erased",