[MLIR] Refactor the walkAndApplyPatterns driver to remove the recursion (#154037)

This is in preparation of a follow-up change to stop traversing
unreachable blocks.

This is not NFC because of a subtlety of the early_inc. On a test case
like:

```
  scf.if %cond {
    "test.move_after_parent_op"() ({
      "test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
    }) : () -> ()
  }
```

We recursively traverse the nested regions, and process an op when the
region is done (post-order).
We need to pre-increment the iterator before processing an operation in
case it gets deleted. However
we can do this before or after processing the nested region. This
implementation does the latter.
This commit is contained in:
Mehdi Amini 2025-08-18 11:07:19 +02:00 committed by GitHub
parent a0f325bd41
commit 16aa283344
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 92 additions and 13 deletions

View File

@ -13,12 +13,14 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
// Iterator on all reachable operations in the region.
// Also keep track if we visited the nested regions of the current op
// already to drive the post-order traversal.
struct RegionReachableOpIterator {
RegionReachableOpIterator(Region *region) : region(region) {
regionIt = region->begin();
if (regionIt != region->end())
blockIt = regionIt->begin();
}
// Advance the iterator to the next reachable operation.
void advance() {
assert(regionIt != region->end());
hasVisitedRegions = false;
if (blockIt == regionIt->end()) {
++regionIt;
if (regionIt != region->end())
blockIt = regionIt->begin();
return;
}
++blockIt;
if (blockIt != regionIt->end()) {
LDBG() << "Incrementing block iterator, next op: "
<< OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
}
}
// The region we're iterating over.
Region *region;
// The Block currently being iterated over.
Region::iterator regionIt;
// The Operation currently being iterated over.
Block::iterator blockIt;
// Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false;
};
// Worklist of regions to visit to drive the post-order traversal.
SmallVector<RegionReachableOpIterator> worklist;
LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
// Perform a post-order traversal of the regions, visiting each
// reachable operation.
for (Region &region : op->getRegions()) {
region.walk([&](Operation *visitedOp) {
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = visitedOp;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
assert(worklist.empty());
if (region.empty())
continue;
// Prime the worklist with the entry block of this region.
worklist.push_back({&region});
while (!worklist.empty()) {
RegionReachableOpIterator &it = worklist.back();
if (it.regionIt == it.region->end()) {
// We're done with this region.
worklist.pop_back();
continue;
}
});
if (it.blockIt == it.regionIt->end()) {
// We're done with this block.
it.advance();
continue;
}
Operation *op = &*it.blockIt;
// If we haven't visited the nested regions of this op yet,
// enqueue them.
if (!it.hasVisitedRegions) {
it.hasVisitedRegions = true;
for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
if (nestedRegion.empty())
continue;
worklist.push_back({&nestedRegion});
}
}
// If we're not at the back of the worklist, we've enqueued some
// nested region for processing. We'll come back to this op later
// (post-order)
if (&it != &worklist.back())
continue;
// Preemptively increment the iterator, in case the current op
// would be erased.
it.advance();
LDBG() << "Visiting op: "
<< OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(op, rewriter)))
LDBG() << "\tOp matched and rewritten";
}
}
},
{op});

View File

@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) {
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited only once since walk uses `make_early_inc_range`.
// the moved op to be visited twice.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {