[MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (#154037)
This is in preparation of a follow-up change to stop traversing unreachable blocks. This is not NFC because of a subtlety of the early_inc. On a test case like: ``` scf.if %cond { "test.move_after_parent_op"() ({ "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> () }) : () -> () } ``` We recursively traverse the nested regions, and process an op when the region is done (post-order). We need to pre-increment the iterator before processing an operation in case it gets deleted. However we can do this before or after processing the nested region. This implementation does the latter.
This commit is contained in:
parent
a0f325bd41
commit
16aa283344
@ -13,12 +13,14 @@
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/DebugLog.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "walk-rewriter"
|
||||
@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
|
||||
PatternApplicator applicator(patterns);
|
||||
applicator.applyDefaultCostModel();
|
||||
|
||||
// Iterator on all reachable operations in the region.
|
||||
// Also keep track if we visited the nested regions of the current op
|
||||
// already to drive the post-order traversal.
|
||||
struct RegionReachableOpIterator {
|
||||
RegionReachableOpIterator(Region *region) : region(region) {
|
||||
regionIt = region->begin();
|
||||
if (regionIt != region->end())
|
||||
blockIt = regionIt->begin();
|
||||
}
|
||||
// Advance the iterator to the next reachable operation.
|
||||
void advance() {
|
||||
assert(regionIt != region->end());
|
||||
hasVisitedRegions = false;
|
||||
if (blockIt == regionIt->end()) {
|
||||
++regionIt;
|
||||
if (regionIt != region->end())
|
||||
blockIt = regionIt->begin();
|
||||
return;
|
||||
}
|
||||
++blockIt;
|
||||
if (blockIt != regionIt->end()) {
|
||||
LDBG() << "Incrementing block iterator, next op: "
|
||||
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
|
||||
}
|
||||
}
|
||||
// The region we're iterating over.
|
||||
Region *region;
|
||||
// The Block currently being iterated over.
|
||||
Region::iterator regionIt;
|
||||
// The Operation currently being iterated over.
|
||||
Block::iterator blockIt;
|
||||
// Whether we've visited the nested regions of the current op already.
|
||||
bool hasVisitedRegions = false;
|
||||
};
|
||||
|
||||
// Worklist of regions to visit to drive the post-order traversal.
|
||||
SmallVector<RegionReachableOpIterator> worklist;
|
||||
|
||||
LDBG() << "Starting walk-based pattern rewrite driver";
|
||||
ctx->executeAction<WalkAndApplyPatternsAction>(
|
||||
[&] {
|
||||
// Perform a post-order traversal of the regions, visiting each
|
||||
// reachable operation.
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
region.walk([&](Operation *visitedOp) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
|
||||
llvm::dbgs(), OpPrintingFlags().skipRegions());
|
||||
llvm::dbgs() << "\n";);
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
erasedListener.visitedOp = visitedOp;
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
|
||||
assert(worklist.empty());
|
||||
if (region.empty())
|
||||
continue;
|
||||
|
||||
// Prime the worklist with the entry block of this region.
|
||||
worklist.push_back({®ion});
|
||||
while (!worklist.empty()) {
|
||||
RegionReachableOpIterator &it = worklist.back();
|
||||
if (it.regionIt == it.region->end()) {
|
||||
// We're done with this region.
|
||||
worklist.pop_back();
|
||||
continue;
|
||||
}
|
||||
});
|
||||
if (it.blockIt == it.regionIt->end()) {
|
||||
// We're done with this block.
|
||||
it.advance();
|
||||
continue;
|
||||
}
|
||||
Operation *op = &*it.blockIt;
|
||||
// If we haven't visited the nested regions of this op yet,
|
||||
// enqueue them.
|
||||
if (!it.hasVisitedRegions) {
|
||||
it.hasVisitedRegions = true;
|
||||
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
|
||||
if (nestedRegion.empty())
|
||||
continue;
|
||||
worklist.push_back({&nestedRegion});
|
||||
}
|
||||
}
|
||||
// If we're not at the back of the worklist, we've enqueued some
|
||||
// nested region for processing. We'll come back to this op later
|
||||
// (post-order)
|
||||
if (&it != &worklist.back())
|
||||
continue;
|
||||
|
||||
// Preemptively increment the iterator, in case the current op
|
||||
// would be erased.
|
||||
it.advance();
|
||||
|
||||
LDBG() << "Visiting op: "
|
||||
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
erasedListener.visitedOp = op;
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
|
||||
LDBG() << "\tOp matched and rewritten";
|
||||
}
|
||||
}
|
||||
},
|
||||
{op});
|
||||
|
@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
|
||||
}
|
||||
|
||||
// Check that the driver handles rewriter.moveAfter. In this case, we expect
|
||||
// the moved op to be visited only once since walk uses `make_early_inc_range`.
|
||||
// the moved op to be visited twice.
|
||||
// CHECK-LABEL: func.func @move_after(
|
||||
// CHECK: scf.if
|
||||
// CHECK: }
|
||||
// CHECK: "test.move_after_parent_op"
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
|
||||
// CHECK: return
|
||||
func.func @move_after(%cond : i1) {
|
||||
scf.if %cond {
|
||||
|
Loading…
x
Reference in New Issue
Block a user