[mlir] Add fast walk-based pattern rewrite driver (#113825)

This is intended as a fast pattern rewrite driver for the cases when a
simple walk gets the job done but we would still want to implement it in
terms of rewrite patterns (that can be used with the greedy pattern
rewrite driver downstream).

The new driver is inspired by the discussion in
https://github.com/llvm/llvm-project/pull/112454 and the LLVM Dev
presentation from @matthias-springer earlier this week.

This limitation comes with some limitations:
* It does not repeat until a fixpoint or revisit ops modified in place
or newly created ops. In general, it only walks forward (in the
post-order).
* `matchAndRewrite` can only erase the matched op or its descendants.
  This is verified under expensive checks.
* It does not perform folding / DCE.
 
We could probably relax some of these in the future without sacrificing
too much performance.
This commit is contained in:
Jakub Kuderski 2024-10-31 11:10:09 -04:00 committed by GitHub
parent 1d0370872f
commit 0f8a6b7d03
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 455 additions and 67 deletions

View File

@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
```c++
/// A custom Action can be defined minimally by deriving from
/// `tracing::ActionImpl`. It can has any members!
/// `tracing::ActionImpl`. It can have any members!
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
public:
using Base = tracing::ActionImpl<MyCustomAction>;

View File

@ -320,15 +320,41 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
framework also provides support for type conversions. More information on this
driver can be found [here](DialectConversion.md).
### Walk Pattern Rewrite Driver
This is a fast and simple driver that walks the given op and applies patterns
that locally have the most benefit. The benefit of a pattern is decided solely
by the benefit specified on the pattern, and the relative order of the pattern
within the pattern list (when two patterns have the same local benefit).
The driver performs a post-order traversal. Note that it walks regions of the
given op but does not visit the op.
This driver does not (re)visit modified or newly replaced ops, and does not
allow for progressive rewrites of the same op. Op and block erasure is only
supported for the currently matched op and its descendant. If your pattern
set requires these, consider using the Greedy Pattern Rewrite Driver instead,
at the expense of extra overhead.
This driver is exposed using the `walkAndApplyPatterns` function.
Note: This driver listens for IR changes via the callbacks provided by
`RewriterBase`. It is important that patterns announce all IR changes to the
rewriter and do not bypass the rewriter API by modifying ops directly.
#### Debugging
You can debug the Walk Pattern Rewrite Driver by passing the
`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
ops.
### Greedy Pattern Rewrite Driver
This driver processes ops in a worklist-driven fashion and greedily applies the
patterns that locally have the most benefit. The benefit of a pattern is decided
solely by the benefit specified on the pattern, and the relative order of the
pattern within the pattern list (when two patterns have the same local benefit).
Patterns are iteratively applied to operations until a fixed point is reached or
until the configurable maximum number of iterations exhausted, at which point
the driver finishes.
patterns that locally have the most benefit (same as the Walk Pattern Rewrite
Driver). Patterns are iteratively applied to operations until a fixed point is
reached or until the configurable maximum number of iterations exhausted, at
which point the driver finishes.
This driver comes in two fashions:
@ -368,7 +394,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
[pass](Passes.md/#-canonicalize) in MLIR.
### Debugging
#### Debugging
To debug the execution of the greedy pattern rewrite driver,
`-debug-only=greedy-rewriter` may be used. This command line flag activates

View File

@ -461,54 +461,60 @@ public:
/// struct can be used as a base to create listener chains, so that multiple
/// listeners can be notified of IR changes.
struct ForwardingListener : public RewriterBase::Listener {
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
ForwardingListener(OpBuilder::Listener *listener)
: listener(listener),
rewriteListener(
dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
listener->notifyOperationInserted(op, previous);
if (listener)
listener->notifyOperationInserted(op, previous);
}
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
listener->notifyBlockInserted(block, previous, previousIt);
if (listener)
listener->notifyBlockInserted(block, previous, previousIt);
}
void notifyBlockErased(Block *block) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyBlockErased(block);
}
void notifyOperationModified(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationModified(op);
}
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, newOp);
}
void notifyOperationReplaced(Operation *op,
ValueRange replacement) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationReplaced(op, replacement);
}
void notifyOperationErased(Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyOperationErased(op);
}
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyPatternBegin(pattern, op);
}
void notifyPatternEnd(const Pattern &pattern,
LogicalResult status) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyPatternEnd(pattern, status);
}
void notifyMatchFailure(
Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override {
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
if (rewriteListener)
rewriteListener->notifyMatchFailure(loc, reasonCallback);
}
private:
OpBuilder::Listener *listener;
RewriterBase::Listener *rewriteListener;
};
/// Move the blocks that belong to "region" before the given position in

View File

@ -0,0 +1,37 @@
//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Declares a helper function to walk the given op and apply rewrite patterns.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
namespace mlir {
/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
/// given operation by walking it and applying the highest benefit patterns.
/// This rewriter *does not* wait until a fixpoint is reached and *does not*
/// visit modified or newly replaced ops. Also *does not* perform folding or
/// dead-code elimination.
///
/// This is intended as the simplest and most lightweight pattern rewriter in
/// cases when a simple walk gets the job done.
///
/// Note: Does not apply patterns to the given operation itself.
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener = nullptr);
} // namespace mlir
#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_

View File

@ -14,7 +14,7 @@
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
namespace arith {
@ -157,11 +157,7 @@ struct ArithUnsignedWhenEquivalentPass
RewritePatternSet patterns(ctx);
populateUnsignedWhenEquivalentPatterns(patterns, solver);
GreedyRewriteConfig config;
config.listener = &listener;
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
signalPassFailure();
walkAndApplyPatterns(op, std::move(patterns), &listener);
}
};
} // end anonymous namespace

View File

@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
LoopInvariantCodeMotionUtils.cpp
OneToNTypeConversion.cpp
RegionUtils.cpp
WalkPatternRewriteDriver.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms

View File

@ -0,0 +1,116 @@
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements mlir::walkAndApplyPatterns.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
namespace mlir {
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
using ActionImpl::ActionImpl;
static constexpr StringLiteral tag = "walk-and-apply-patterns";
void print(raw_ostream &os) const override { os << tag; }
};
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Forwarding listener to guard against unsupported erasures of non-descendant
// ops/blocks. Because we use walk-based pattern application, erasing the
// op/block from the *next* iteration (e.g., a user of the visited op) is not
// valid. Note that this is only used with expensive pattern API checks.
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
using RewriterBase::ForwardingListener::ForwardingListener;
void notifyOperationErased(Operation *op) override {
checkErasure(op);
ForwardingListener::notifyOperationErased(op);
}
void notifyBlockErased(Block *block) override {
checkErasure(block->getParentOp());
ForwardingListener::notifyBlockErased(block);
}
void checkErasure(Operation *op) const {
Operation *ancestorOp = op;
while (ancestorOp && ancestorOp != visitedOp)
ancestorOp = ancestorOp->getParentOp();
if (ancestorOp != visitedOp)
llvm::report_fatal_error(
"unsupported erasure in WalkPatternRewriter; "
"erasure is only supported for matched ops and their descendants");
}
Operation *visitedOp = nullptr;
};
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
} // namespace
void walkAndApplyPatterns(Operation *op,
const FrozenRewritePatternSet &patterns,
RewriterBase::Listener *listener) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
MLIRContext *ctx = op->getContext();
PatternRewriter rewriter(ctx);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
ErasedOpsListener erasedListener(listener);
rewriter.setListener(&erasedListener);
#else
rewriter.setListener(listener);
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
for (Region &region : op->getRegions()) {
region.walk([&](Operation *visitedOp) {
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
llvm::dbgs(), OpPrintingFlags().skipRegions());
llvm::dbgs() << "\n";);
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
erasedListener.visitedOp = visitedOp;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
}
});
}
},
{op});
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (failed(verify(op)))
llvm::report_fatal_error(
"walk pattern rewriter result IR failed to verify");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
}
} // namespace mlir

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
// CHECK-LABEL: @test_enum_attr_roundtrip
func.func @test_enum_attr_roundtrip() -> () {

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
// RUN: --split-input-file | FileCheck %s
// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure

View File

@ -0,0 +1,121 @@
// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s
// The following op is updated in-place and will not be added back to the worklist.
// CHECK-LABEL: func.func @inplace_update()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
func.func @inplace_update() {
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
"test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
return
}
// Check that the driver does not fold visited ops.
// CHECK-LABEL: func.func @add_no_fold()
// CHECK: arith.constant
// CHECK: arith.constant
// CHECK: %[[RES:.+]] = arith.addi
// CHECK: return %[[RES]]
func.func @add_no_fold() -> i32 {
%c0 = arith.constant 0 : i32
%c1 = arith.constant 1 : i32
%res = arith.addi %c0, %c1 : i32
return %res : i32
}
// Check that the driver handles rewriter.moveBefore.
// CHECK-LABEL: func.func @move_before(
// CHECK: "test.move_before_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: scf.if
// CHECK: return
func.func @move_before(%cond : i1) {
scf.if %cond {
"test.move_before_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
}
return
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited only once since walk uses `make_early_inc_range`.
// CHECK-LABEL: func.func @move_after(
// CHECK: scf.if
// CHECK: }
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: return
func.func @move_after(%cond : i1) {
scf.if %cond {
"test.move_after_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
}
return
}
// Check that the driver handles rewriter.moveAfter. In this case, we expect
// the moved op to be visited twice since we advance its position to the next
// node after the parent.
// CHECK-LABEL: func.func @move_forward_and_revisit(
// CHECK: scf.if
// CHECK: }
// CHECK: arith.addi
// CHECK: "test.move_after_parent_op"
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
// CHECK: arith.addi
// CHECK: return
func.func @move_forward_and_revisit(%cond : i1) {
scf.if %cond {
"test.move_after_parent_op"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) {advance = 1 : i32} : () -> ()
}
%a = arith.addi %cond, %cond : i1
%b = arith.addi %a, %cond : i1
return
}
// Operation inserted just after the currently visited one won't be visited.
// CHECK-LABEL: func.func @insert_just_after
// CHECK: "test.clone_me"() ({
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: }) {was_cloned} : () -> ()
// CHECK: "test.clone_me"() ({
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
// CHECK: }) : () -> ()
// CHECK: return
func.func @insert_just_after(%cond : i1) {
"test.clone_me"() ({
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
}) : () -> ()
return
}
// Check that we can replace the current operation with a new one.
// Note that the new op won't be visited.
// CHECK-LABEL: func.func @replace_with_new_op
// CHECK: %[[NEW:.+]] = "test.new_op"
// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
// CHECK: return %[[RES]]
func.func @replace_with_new_op() -> i32 {
%a = "test.replace_with_new_op"() : () -> (i32)
%res = arith.addi %a, %a : i32
return %res : i32
}
// Check that we can erase nested blocks.
// CHECK-LABEL: func.func @erase_nested_block
// CHECK: %[[RES:.+]] = "test.erase_first_block"
// CHECK-NEXT: foo.bar
// CHECK: return %[[RES]]
func.func @erase_nested_block() -> i32 {
%a = "test.erase_first_block"() ({
"foo.foo"() : () -> ()
^bb1:
"foo.bar"() : () -> ()
}): () -> (i32)
return %a : i32
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
// CHECK-LABEL: func @test_reorder_constants_and_match
func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {

View File

@ -1,5 +1,5 @@
// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
func.func @foo() -> i32 {
%c42 = arith.constant 42 : i32

View File

@ -13,12 +13,16 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include <cstdint>
using namespace mlir;
using namespace test;
@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
}
};
/// This pattern moves "test.move_after_parent_op" after the parent op.
struct MoveAfterParentOp : public RewritePattern {
MoveAfterParentOp(MLIRContext *context)
: RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
// Do not hoist past functions.
if (isa<FunctionOpInterface>(op->getParentOp()))
return failure();
int64_t moveForwardBy = 0;
if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
moveForwardBy = advanceBy.getInt();
Operation *moveAfter = op->getParentOp();
for (int64_t i = 0; i < moveForwardBy; ++i)
moveAfter = moveAfter->getNextNode();
rewriter.moveOpAfter(op, moveAfter);
return success();
}
};
/// This pattern inlines blocks that are nested in
/// "test.inline_blocks_into_parent" into the parent block.
struct InlineBlocksIntoParent : public RewritePattern {
@ -286,14 +314,65 @@ struct CloneRegionBeforeOp : public RewritePattern {
}
};
struct TestPatternDriver
: public PassWrapper<TestPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
/// Replace an operation may introduce the re-visiting of its users.
class ReplaceWithNewOp : public RewritePattern {
public:
ReplaceWithNewOp(MLIRContext *context)
: RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
TestPatternDriver() = default;
TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *newOp;
if (op->hasAttr("create_erase_op")) {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.erase_op", op->getContext()).getIdentifier(),
ValueRange(), TypeRange());
} else {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
// A "notifyOperationReplaced" callback is triggered in either case.
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
rewriter.eraseOp(op);
return success();
}
};
StringRef getArgument() const final { return "test-patterns"; }
/// Erases the first child block of the matched "test.erase_first_block"
/// operation.
class EraseFirstBlock : public RewritePattern {
public:
EraseFirstBlock(MLIRContext *context)
: RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
for (Region &r : op->getRegions()) {
for (Block &b : r.getBlocks()) {
rewriter.eraseBlock(&b);
llvm::errs() << "Erasing block: " << b << "\n";
return success();
}
}
return failure();
}
};
struct TestGreedyPatternDriver
: public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
TestGreedyPatternDriver() = default;
TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
: PassWrapper(other) {}
StringRef getArgument() const final { return "test-greedy-patterns"; }
StringRef getDescription() const final { return "Run test dialect patterns"; }
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
@ -470,34 +549,6 @@ private:
}
};
// Replace an operation may introduce the re-visiting of its users.
class ReplaceWithNewOp : public RewritePattern {
public:
ReplaceWithNewOp(MLIRContext *context)
: RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
Operation *newOp;
if (op->hasAttr("create_erase_op")) {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.erase_op", op->getContext()).getIdentifier(),
ValueRange(), TypeRange());
} else {
newOp = rewriter.create(
op->getLoc(),
OperationName("test.new_op", op->getContext()).getIdentifier(),
op->getOperands(), op->getResultTypes());
}
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
// A "notifyOperationReplaced" callback is triggered in either case.
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
rewriter.eraseOp(op);
return success();
}
};
// Remove an operation may introduce the re-visiting of its operands.
class EraseOp : public RewritePattern {
public:
@ -560,6 +611,39 @@ private:
};
};
struct TestWalkPatternDriver final
: PassWrapper<TestWalkPatternDriver, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
TestWalkPatternDriver() = default;
TestWalkPatternDriver(const TestWalkPatternDriver &other)
: PassWrapper(other) {}
StringRef getArgument() const override {
return "test-walk-pattern-rewrite-driver";
}
StringRef getDescription() const override {
return "Run test walk pattern rewrite driver";
}
void runOnOperation() override {
mlir::RewritePatternSet patterns(&getContext());
// Patterns for testing the WalkPatternRewriteDriver.
patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
&getContext());
DumpNotifications dumpListener;
walkAndApplyPatterns(getOperation(), std::move(patterns),
dumpNotifications ? &dumpListener : nullptr);
}
Option<bool> dumpNotifications{
*this, "dump-notifications",
llvm::cl::desc("Print rewrite listener notifications"),
llvm::cl::init(false)};
};
} // namespace
//===----------------------------------------------------------------------===//
@ -1978,8 +2062,9 @@ void registerPatternsTestPass() {
PassRegistration<TestDerivedAttributeDriver>();
PassRegistration<TestPatternDriver>();
PassRegistration<TestGreedyPatternDriver>();
PassRegistration<TestStrictPatternDriver>();
PassRegistration<TestWalkPatternDriver>();
PassRegistration<TestLegalizePatternDriver>([] {
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
// CHECK-LABEL: verifyFusedLocs
func.func @verifyFusedLocs(%arg0 : i32) -> i32 {