[mlir] GreedyPatternRewriteDriver: Move strict mode to GreedyPatternRewriteDriver
`strictMode` is moved to GreedyRewriteConfig to simplify the API and state of rewriter classes. The region-based GreedyPatternRewriteDriver now also supports strict mode. MultiOpPatternRewriteDriver becomes simpler: fewer method must be overridden. Differential Revision: https://reviews.llvm.org/D142623
This commit is contained in:
parent
72121a20cd
commit
6bdecbcb99
@ -63,6 +63,18 @@ public:
|
||||
/// Only ops within the scope are added to the worklist. If no scope is
|
||||
/// specified, the closest enclosing region is used as a scope.
|
||||
Region *scope = nullptr;
|
||||
|
||||
/// Strict mode can restrict the ops that are added to the worklist during
|
||||
/// the rewrite.
|
||||
///
|
||||
/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
|
||||
/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing ops (that
|
||||
/// were on the worklist at the very beginning) and newly created ops are
|
||||
/// enqueued. All other ops are excluded.
|
||||
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops (that were
|
||||
/// were on the worklist at the very beginning) enqueued. All other ops are
|
||||
/// excluded.
|
||||
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -105,14 +117,8 @@ inline LogicalResult applyPatternsAndFoldGreedily(
|
||||
///
|
||||
/// Newly created ops and other pre-existing ops that use results of rewritten
|
||||
/// ops or supply operands to such ops are simplified, unless such ops are
|
||||
/// excluded via `strictMode`. Any other ops remain unmodified (i.e., regardless
|
||||
/// of `strictMode`).
|
||||
///
|
||||
/// * GreedyRewriteStrictness::AnyOp: No ops are excluded.
|
||||
/// * GreedyRewriteStrictness::ExistingAndNewOps: Only pre-existing and newly
|
||||
/// created ops are simplified. All other ops are excluded.
|
||||
/// * GreedyRewriteStrictness::ExistingOps: Only pre-existing ops are
|
||||
/// simplified. All other ops are excluded.
|
||||
/// excluded via `config.strictMode`. Any other ops remain unmodified (i.e.,
|
||||
/// regardless of `strictMode`).
|
||||
///
|
||||
/// In addition to strictness, a region scope can be specified. Only ops within
|
||||
/// the scope are simplified. This is similar to `applyPatternsAndFoldGreedily`,
|
||||
@ -130,23 +136,17 @@ inline LogicalResult applyPatternsAndFoldGreedily(
|
||||
LogicalResult
|
||||
applyOpPatternsAndFold(ArrayRef<Operation *> ops,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteStrictness strictMode,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *changed = nullptr, bool *allErased = nullptr);
|
||||
|
||||
/// Applies the specified patterns on `op` alone while also trying to fold it,
|
||||
/// by selecting the highest benefits patterns in a greedy manner. Returns
|
||||
/// success if no more patterns can be matched. `erased` is set to true if `op`
|
||||
/// was folded away or erased as a result of becoming dead.
|
||||
///
|
||||
/// Returns success if the iterative process converged and no more patterns can
|
||||
/// be matched.
|
||||
/// Applies the specified patterns on `op` while also trying to fold it.
|
||||
/// This function is a shortcut for the ArrayRef<Operation *> overload and
|
||||
/// behaves the same way.
|
||||
inline LogicalResult
|
||||
applyOpPatternsAndFold(Operation *op, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteConfig config = GreedyRewriteConfig(),
|
||||
bool *erased = nullptr) {
|
||||
return applyOpPatternsAndFold(ArrayRef(op), patterns,
|
||||
GreedyRewriteStrictness::ExistingOps, config,
|
||||
return applyOpPatternsAndFold(ArrayRef(op), patterns, config,
|
||||
/*changed=*/nullptr, erased);
|
||||
}
|
||||
|
||||
|
||||
@ -130,10 +130,10 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results,
|
||||
patterns.insert<SimplifyAffineMinMaxOp<AffineMinOp>,
|
||||
SimplifyAffineMinMaxOp<AffineMaxOp>>(getContext(), cstr);
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
// Apply the simplification pattern to a fixpoint.
|
||||
if (failed(
|
||||
applyOpPatternsAndFold(targets, frozenPatterns,
|
||||
GreedyRewriteStrictness::ExistingAndNewOps))) {
|
||||
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) {
|
||||
auto diag = emitDefiniteFailure()
|
||||
<< "affine.min/max simplification did not converge";
|
||||
return diag;
|
||||
|
||||
@ -239,6 +239,7 @@ void AffineDataCopyGeneration::runOnOperation() {
|
||||
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
|
||||
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
(void)applyOpPatternsAndFold(copyOps, frozenPatterns,
|
||||
GreedyRewriteStrictness::ExistingAndNewOps);
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(copyOps, frozenPatterns, config);
|
||||
}
|
||||
|
||||
@ -105,6 +105,7 @@ void SimplifyAffineStructures::runOnOperation() {
|
||||
if (isa<AffineForOp, AffineIfOp, AffineApplyOp>(op))
|
||||
opsToSimplify.push_back(op);
|
||||
});
|
||||
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns,
|
||||
GreedyRewriteStrictness::ExistingAndNewOps);
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(opsToSimplify, frozenPatterns, config);
|
||||
}
|
||||
|
||||
@ -321,9 +321,10 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
|
||||
// Simplify/canonicalize the affine.for.
|
||||
RewritePatternSet patterns(res.getContext());
|
||||
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
bool erased;
|
||||
(void)applyOpPatternsAndFold(res, std::move(patterns),
|
||||
GreedyRewriteConfig(), &erased);
|
||||
(void)applyOpPatternsAndFold(res, std::move(patterns), config, &erased);
|
||||
if (!erased && !prologue)
|
||||
prologue = res;
|
||||
if (!erased)
|
||||
|
||||
@ -413,10 +413,11 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
|
||||
// in which case we return with `folded` being set.
|
||||
RewritePatternSet patterns(ifOp.getContext());
|
||||
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
|
||||
bool erased;
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, GreedyRewriteConfig(),
|
||||
&erased);
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
bool erased;
|
||||
(void)applyOpPatternsAndFold(ifOp, frozenPatterns, config, &erased);
|
||||
if (erased) {
|
||||
if (folded)
|
||||
*folded = true;
|
||||
|
||||
@ -60,10 +60,13 @@ static void applyPatterns(Region ®ion,
|
||||
// matching in above iteration. Besides, erase op not-in-range may end up in
|
||||
// invalid module, so `applyOpPatternsAndFold` should come before that
|
||||
// transform.
|
||||
for (Operation *op : opsInRange)
|
||||
for (Operation *op : opsInRange) {
|
||||
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
|
||||
// because we don't have expectation this reduction will be success or not.
|
||||
(void)applyOpPatternsAndFold(op, patterns);
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
(void)applyOpPatternsAndFold(op, patterns, config);
|
||||
}
|
||||
|
||||
if (eraseOpNotInRange)
|
||||
for (Operation *op : opsNotInRange) {
|
||||
|
||||
@ -59,7 +59,7 @@ public:
|
||||
|
||||
protected:
|
||||
/// Add the given operation to the worklist.
|
||||
virtual void addSingleOpToWorklist(Operation *op);
|
||||
void addSingleOpToWorklist(Operation *op);
|
||||
|
||||
// Implement the hook for inserting operations, and make sure that newly
|
||||
// inserted ops are added to the worklist for processing.
|
||||
@ -102,6 +102,12 @@ protected:
|
||||
/// Configuration information for how to simplify.
|
||||
const GreedyRewriteConfig config;
|
||||
|
||||
/// The list of ops we are restricting our rewrites to. These include the
|
||||
/// supplied set of ops as well as new ops created while rewriting those ops
|
||||
/// depending on `strictMode`. This set is not maintained when
|
||||
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
|
||||
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
|
||||
|
||||
private:
|
||||
#ifndef NDEBUG
|
||||
/// A logger used to emit information during the application process.
|
||||
@ -150,6 +156,12 @@ bool GreedyPatternRewriteDriver::simplify(Region ®ion) && {
|
||||
return false;
|
||||
};
|
||||
|
||||
// Populate strict mode ops.
|
||||
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
|
||||
strictModeFilteredOps.clear();
|
||||
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
|
||||
}
|
||||
|
||||
bool changed = false;
|
||||
int64_t iteration = 0;
|
||||
do {
|
||||
@ -323,12 +335,15 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
|
||||
}
|
||||
|
||||
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
|
||||
// Check to see if the worklist already contains this op.
|
||||
if (worklistMap.count(op))
|
||||
return;
|
||||
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
|
||||
strictModeFilteredOps.contains(op)) {
|
||||
// Check to see if the worklist already contains this op.
|
||||
if (worklistMap.count(op))
|
||||
return;
|
||||
|
||||
worklistMap[op] = worklist.size();
|
||||
worklist.push_back(op);
|
||||
worklistMap[op] = worklist.size();
|
||||
worklist.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
Operation *GreedyPatternRewriteDriver::popFromWorklist() {
|
||||
@ -355,6 +370,8 @@ void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
|
||||
logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
|
||||
<< ")\n";
|
||||
});
|
||||
if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
|
||||
strictModeFilteredOps.insert(op);
|
||||
addToWorklist(op);
|
||||
}
|
||||
|
||||
@ -391,6 +408,9 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
|
||||
removeFromWorklist(operation);
|
||||
folder.notifyRemoval(operation);
|
||||
});
|
||||
|
||||
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
|
||||
strictModeFilteredOps.erase(op);
|
||||
}
|
||||
|
||||
void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op,
|
||||
@ -459,10 +479,10 @@ class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
|
||||
public:
|
||||
explicit MultiOpPatternRewriteDriver(
|
||||
MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteStrictness strictMode, const GreedyRewriteConfig &config,
|
||||
const GreedyRewriteConfig &config,
|
||||
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr)
|
||||
: GreedyPatternRewriteDriver(ctx, patterns, config),
|
||||
strictMode(strictMode), survivingOps(survivingOps) {}
|
||||
survivingOps(survivingOps) {}
|
||||
|
||||
/// Performs the specified rewrites on `ops` while also trying to fold these
|
||||
/// ops. `strictMode` controls which other ops are simplified. Only ops
|
||||
@ -476,38 +496,13 @@ public:
|
||||
LogicalResult simplifyLocally(ArrayRef<Operation *> op,
|
||||
bool *changed = nullptr) &&;
|
||||
|
||||
protected:
|
||||
void addSingleOpToWorklist(Operation *op) override {
|
||||
if (strictMode == GreedyRewriteStrictness::AnyOp ||
|
||||
strictModeFilteredOps.contains(op))
|
||||
GreedyPatternRewriteDriver::addSingleOpToWorklist(op);
|
||||
}
|
||||
|
||||
private:
|
||||
void notifyOperationInserted(Operation *op) override {
|
||||
if (strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
|
||||
strictModeFilteredOps.insert(op);
|
||||
GreedyPatternRewriteDriver::notifyOperationInserted(op);
|
||||
}
|
||||
|
||||
void notifyOperationRemoved(Operation *op) override {
|
||||
GreedyPatternRewriteDriver::notifyOperationRemoved(op);
|
||||
if (survivingOps)
|
||||
survivingOps->erase(op);
|
||||
if (strictMode != GreedyRewriteStrictness::AnyOp)
|
||||
strictModeFilteredOps.erase(op);
|
||||
}
|
||||
|
||||
/// `strictMode` control which ops are added to the worklist during
|
||||
/// simplification.
|
||||
const GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
|
||||
|
||||
/// The list of ops we are restricting our rewrites to. These include the
|
||||
/// supplied set of ops as well as new ops created while rewriting those ops
|
||||
/// depending on `strictMode`. This set is not maintained when `strictMode`
|
||||
/// is GreedyRewriteStrictness::AnyOp.
|
||||
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
|
||||
|
||||
/// An optional set of ops that survived the rewrite. This set is populated
|
||||
/// at the beginning of `simplifyLocally` with the inititally provided list
|
||||
/// of ops.
|
||||
@ -524,7 +519,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
|
||||
survivingOps->insert(ops.begin(), ops.end());
|
||||
}
|
||||
|
||||
if (strictMode != GreedyRewriteStrictness::AnyOp) {
|
||||
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
|
||||
strictModeFilteredOps.clear();
|
||||
strictModeFilteredOps.insert(ops.begin(), ops.end());
|
||||
}
|
||||
@ -549,7 +544,7 @@ MultiOpPatternRewriteDriver::simplifyLocally(ArrayRef<Operation *> ops,
|
||||
if (op == nullptr)
|
||||
continue;
|
||||
|
||||
assert((strictMode == GreedyRewriteStrictness::AnyOp ||
|
||||
assert((config.strictMode == GreedyRewriteStrictness::AnyOp ||
|
||||
strictModeFilteredOps.contains(op)) &&
|
||||
"unexpected op was inserted under strict mode");
|
||||
|
||||
@ -637,8 +632,7 @@ static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
|
||||
|
||||
LogicalResult mlir::applyOpPatternsAndFold(
|
||||
ArrayRef<Operation *> ops, const FrozenRewritePatternSet &patterns,
|
||||
GreedyRewriteStrictness strictMode, GreedyRewriteConfig config,
|
||||
bool *changed, bool *allErased) {
|
||||
GreedyRewriteConfig config, bool *changed, bool *allErased) {
|
||||
if (ops.empty()) {
|
||||
if (changed)
|
||||
*changed = false;
|
||||
@ -664,8 +658,7 @@ LogicalResult mlir::applyOpPatternsAndFold(
|
||||
// Start the pattern driver.
|
||||
llvm::SmallDenseSet<Operation *, 4> surviving;
|
||||
MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
|
||||
strictMode, config,
|
||||
allErased ? &surviving : nullptr);
|
||||
config, allErased ? &surviving : nullptr);
|
||||
LogicalResult converged = std::move(driver).simplifyLocally(ops, changed);
|
||||
if (allErased)
|
||||
*allErased = surviving.empty();
|
||||
|
||||
@ -132,8 +132,9 @@ void TestAffineDataCopy::runOnOperation() {
|
||||
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
|
||||
}
|
||||
}
|
||||
(void)applyOpPatternsAndFold(copyOps, std::move(patterns),
|
||||
GreedyRewriteStrictness::ExistingAndNewOps);
|
||||
GreedyRewriteConfig config;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
(void)applyOpPatternsAndFold(copyOps, std::move(patterns), config);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -266,13 +266,13 @@ public:
|
||||
}
|
||||
});
|
||||
|
||||
GreedyRewriteStrictness mode;
|
||||
GreedyRewriteConfig config;
|
||||
if (strictMode == "AnyOp") {
|
||||
mode = GreedyRewriteStrictness::AnyOp;
|
||||
config.strictMode = GreedyRewriteStrictness::AnyOp;
|
||||
} else if (strictMode == "ExistingAndNewOps") {
|
||||
mode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
|
||||
} else if (strictMode == "ExistingOps") {
|
||||
mode = GreedyRewriteStrictness::ExistingOps;
|
||||
config.strictMode = GreedyRewriteStrictness::ExistingOps;
|
||||
} else {
|
||||
llvm_unreachable("invalid strictness option");
|
||||
}
|
||||
@ -282,8 +282,8 @@ public:
|
||||
// operation will trigger the assertion while processing.
|
||||
bool changed = false;
|
||||
bool allErased = false;
|
||||
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), mode,
|
||||
GreedyRewriteConfig(), &changed, &allErased);
|
||||
(void)applyOpPatternsAndFold(ArrayRef(ops), std::move(patterns), config,
|
||||
&changed, &allErased);
|
||||
Builder b(ctx);
|
||||
getOperation()->setAttr("pattern_driver_changed", b.getBoolAttr(changed));
|
||||
getOperation()->setAttr("pattern_driver_all_erased",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user