diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index ee5c642c943c..2111e2912056 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -13,12 +13,14 @@ #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "walk-rewriter" @@ -88,20 +90,97 @@ void walkAndApplyPatterns(Operation *op, PatternApplicator applicator(patterns); applicator.applyDefaultCostModel(); + // Iterator on all reachable operations in the region. + // Also keep track if we visited the nested regions of the current op + // already to drive the post-order traversal. + struct RegionReachableOpIterator { + RegionReachableOpIterator(Region *region) : region(region) { + regionIt = region->begin(); + if (regionIt != region->end()) + blockIt = regionIt->begin(); + } + // Advance the iterator to the next reachable operation. + void advance() { + assert(regionIt != region->end()); + hasVisitedRegions = false; + if (blockIt == regionIt->end()) { + ++regionIt; + if (regionIt != region->end()) + blockIt = regionIt->begin(); + return; + } + ++blockIt; + if (blockIt != regionIt->end()) { + LDBG() << "Incrementing block iterator, next op: " + << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions()); + } + } + // The region we're iterating over. + Region *region; + // The Block currently being iterated over. + Region::iterator regionIt; + // The Operation currently being iterated over. + Block::iterator blockIt; + // Whether we've visited the nested regions of the current op already. + bool hasVisitedRegions = false; + }; + + // Worklist of regions to visit to drive the post-order traversal. + SmallVector worklist; + + LDBG() << "Starting walk-based pattern rewrite driver"; ctx->executeAction( [&] { + // Perform a post-order traversal of the regions, visiting each + // reachable operation. for (Region ®ion : op->getRegions()) { - region.walk([&](Operation *visitedOp) { - LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n";); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - erasedListener.visitedOp = visitedOp; -#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); + assert(worklist.empty()); + if (region.empty()) + continue; + + // Prime the worklist with the entry block of this region. + worklist.push_back({®ion}); + while (!worklist.empty()) { + RegionReachableOpIterator &it = worklist.back(); + if (it.regionIt == it.region->end()) { + // We're done with this region. + worklist.pop_back(); + continue; } - }); + if (it.blockIt == it.regionIt->end()) { + // We're done with this block. + it.advance(); + continue; + } + Operation *op = &*it.blockIt; + // If we haven't visited the nested regions of this op yet, + // enqueue them. + if (!it.hasVisitedRegions) { + it.hasVisitedRegions = true; + for (Region &nestedRegion : llvm::reverse(op->getRegions())) { + if (nestedRegion.empty()) + continue; + worklist.push_back({&nestedRegion}); + } + } + // If we're not at the back of the worklist, we've enqueued some + // nested region for processing. We'll come back to this op later + // (post-order) + if (&it != &worklist.back()) + continue; + + // Preemptively increment the iterator, in case the current op + // would be erased. + it.advance(); + + LDBG() << "Visiting op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + erasedListener.visitedOp = op; +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + LDBG() << "\tOp matched and rewritten"; + } } }, {op}); diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir index 02f7e60671c9..c75c478ec373 100644 --- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir +++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir @@ -40,12 +40,12 @@ func.func @move_before(%cond : i1) { } // Check that the driver handles rewriter.moveAfter. In this case, we expect -// the moved op to be visited only once since walk uses `make_early_inc_range`. +// the moved op to be visited twice. // CHECK-LABEL: func.func @move_after( // CHECK: scf.if // CHECK: } // CHECK: "test.move_after_parent_op" -// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> () +// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> () // CHECK: return func.func @move_after(%cond : i1) { scf.if %cond {