[mlir] Add fast walk-based pattern rewrite driver (#113825)
This is intended as a fast pattern rewrite driver for the cases when a simple walk gets the job done but we would still want to implement it in terms of rewrite patterns (that can be used with the greedy pattern rewrite driver downstream). The new driver is inspired by the discussion in https://github.com/llvm/llvm-project/pull/112454 and the LLVM Dev presentation from @matthias-springer earlier this week. This limitation comes with some limitations: * It does not repeat until a fixpoint or revisit ops modified in place or newly created ops. In general, it only walks forward (in the post-order). * `matchAndRewrite` can only erase the matched op or its descendants. This is verified under expensive checks. * It does not perform folding / DCE. We could probably relax some of these in the future without sacrificing too much performance.
This commit is contained in:
parent
1d0370872f
commit
0f8a6b7d03
@ -86,7 +86,7 @@ An action can also carry arbitrary payload, for example we can extend the
|
||||
|
||||
```c++
|
||||
/// A custom Action can be defined minimally by deriving from
|
||||
/// `tracing::ActionImpl`. It can has any members!
|
||||
/// `tracing::ActionImpl`. It can have any members!
|
||||
class MyCustomAction : public tracing::ActionImpl<MyCustomAction> {
|
||||
public:
|
||||
using Base = tracing::ActionImpl<MyCustomAction>;
|
||||
|
@ -320,15 +320,41 @@ conversion target, via a set of pattern-based operation rewriting patterns. This
|
||||
framework also provides support for type conversions. More information on this
|
||||
driver can be found [here](DialectConversion.md).
|
||||
|
||||
### Walk Pattern Rewrite Driver
|
||||
|
||||
This is a fast and simple driver that walks the given op and applies patterns
|
||||
that locally have the most benefit. The benefit of a pattern is decided solely
|
||||
by the benefit specified on the pattern, and the relative order of the pattern
|
||||
within the pattern list (when two patterns have the same local benefit).
|
||||
|
||||
The driver performs a post-order traversal. Note that it walks regions of the
|
||||
given op but does not visit the op.
|
||||
|
||||
This driver does not (re)visit modified or newly replaced ops, and does not
|
||||
allow for progressive rewrites of the same op. Op and block erasure is only
|
||||
supported for the currently matched op and its descendant. If your pattern
|
||||
set requires these, consider using the Greedy Pattern Rewrite Driver instead,
|
||||
at the expense of extra overhead.
|
||||
|
||||
This driver is exposed using the `walkAndApplyPatterns` function.
|
||||
|
||||
Note: This driver listens for IR changes via the callbacks provided by
|
||||
`RewriterBase`. It is important that patterns announce all IR changes to the
|
||||
rewriter and do not bypass the rewriter API by modifying ops directly.
|
||||
|
||||
#### Debugging
|
||||
|
||||
You can debug the Walk Pattern Rewrite Driver by passing the
|
||||
`--debug-only=walk-rewriter` CLI flag. This will print the visited and matched
|
||||
ops.
|
||||
|
||||
### Greedy Pattern Rewrite Driver
|
||||
|
||||
This driver processes ops in a worklist-driven fashion and greedily applies the
|
||||
patterns that locally have the most benefit. The benefit of a pattern is decided
|
||||
solely by the benefit specified on the pattern, and the relative order of the
|
||||
pattern within the pattern list (when two patterns have the same local benefit).
|
||||
Patterns are iteratively applied to operations until a fixed point is reached or
|
||||
until the configurable maximum number of iterations exhausted, at which point
|
||||
the driver finishes.
|
||||
patterns that locally have the most benefit (same as the Walk Pattern Rewrite
|
||||
Driver). Patterns are iteratively applied to operations until a fixed point is
|
||||
reached or until the configurable maximum number of iterations exhausted, at
|
||||
which point the driver finishes.
|
||||
|
||||
This driver comes in two fashions:
|
||||
|
||||
@ -368,7 +394,7 @@ rewriter and do not bypass the rewriter API by modifying ops directly.
|
||||
Note: This driver is the one used by the [canonicalization](Canonicalization.md)
|
||||
[pass](Passes.md/#-canonicalize) in MLIR.
|
||||
|
||||
### Debugging
|
||||
#### Debugging
|
||||
|
||||
To debug the execution of the greedy pattern rewrite driver,
|
||||
`-debug-only=greedy-rewriter` may be used. This command line flag activates
|
||||
|
@ -461,54 +461,60 @@ public:
|
||||
/// struct can be used as a base to create listener chains, so that multiple
|
||||
/// listeners can be notified of IR changes.
|
||||
struct ForwardingListener : public RewriterBase::Listener {
|
||||
ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
|
||||
ForwardingListener(OpBuilder::Listener *listener)
|
||||
: listener(listener),
|
||||
rewriteListener(
|
||||
dyn_cast_if_present<RewriterBase::Listener>(listener)) {}
|
||||
|
||||
void notifyOperationInserted(Operation *op, InsertPoint previous) override {
|
||||
listener->notifyOperationInserted(op, previous);
|
||||
if (listener)
|
||||
listener->notifyOperationInserted(op, previous);
|
||||
}
|
||||
void notifyBlockInserted(Block *block, Region *previous,
|
||||
Region::iterator previousIt) override {
|
||||
listener->notifyBlockInserted(block, previous, previousIt);
|
||||
if (listener)
|
||||
listener->notifyBlockInserted(block, previous, previousIt);
|
||||
}
|
||||
void notifyBlockErased(Block *block) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyBlockErased(block);
|
||||
}
|
||||
void notifyOperationModified(Operation *op) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyOperationModified(op);
|
||||
}
|
||||
void notifyOperationReplaced(Operation *op, Operation *newOp) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyOperationReplaced(op, newOp);
|
||||
}
|
||||
void notifyOperationReplaced(Operation *op,
|
||||
ValueRange replacement) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyOperationReplaced(op, replacement);
|
||||
}
|
||||
void notifyOperationErased(Operation *op) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyOperationErased(op);
|
||||
}
|
||||
void notifyPatternBegin(const Pattern &pattern, Operation *op) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyPatternBegin(pattern, op);
|
||||
}
|
||||
void notifyPatternEnd(const Pattern &pattern,
|
||||
LogicalResult status) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyPatternEnd(pattern, status);
|
||||
}
|
||||
void notifyMatchFailure(
|
||||
Location loc,
|
||||
function_ref<void(Diagnostic &)> reasonCallback) override {
|
||||
if (auto *rewriteListener = dyn_cast<RewriterBase::Listener>(listener))
|
||||
if (rewriteListener)
|
||||
rewriteListener->notifyMatchFailure(loc, reasonCallback);
|
||||
}
|
||||
|
||||
private:
|
||||
OpBuilder::Listener *listener;
|
||||
RewriterBase::Listener *rewriteListener;
|
||||
};
|
||||
|
||||
/// Move the blocks that belong to "region" before the given position in
|
||||
|
37
mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
Normal file
37
mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h
Normal file
@ -0,0 +1,37 @@
|
||||
//===- WALKPATTERNREWRITEDRIVER.h - Walk Pattern Rewrite Driver -*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Declares a helper function to walk the given op and apply rewrite patterns.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
|
||||
#define MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
|
||||
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// A fast walk-based pattern rewrite driver. Rewrites ops nested under the
|
||||
/// given operation by walking it and applying the highest benefit patterns.
|
||||
/// This rewriter *does not* wait until a fixpoint is reached and *does not*
|
||||
/// visit modified or newly replaced ops. Also *does not* perform folding or
|
||||
/// dead-code elimination.
|
||||
///
|
||||
/// This is intended as the simplest and most lightweight pattern rewriter in
|
||||
/// cases when a simple walk gets the job done.
|
||||
///
|
||||
/// Note: Does not apply patterns to the given operation itself.
|
||||
void walkAndApplyPatterns(Operation *op,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
RewriterBase::Listener *listener = nullptr);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_TRANSFORMS_WALKPATTERNREWRITEDRIVER_H_
|
@ -14,7 +14,7 @@
|
||||
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace arith {
|
||||
@ -157,11 +157,7 @@ struct ArithUnsignedWhenEquivalentPass
|
||||
RewritePatternSet patterns(ctx);
|
||||
populateUnsignedWhenEquivalentPatterns(patterns, solver);
|
||||
|
||||
GreedyRewriteConfig config;
|
||||
config.listener = &listener;
|
||||
|
||||
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config)))
|
||||
signalPassFailure();
|
||||
walkAndApplyPatterns(op, std::move(patterns), &listener);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -10,6 +10,7 @@ add_mlir_library(MLIRTransformUtils
|
||||
LoopInvariantCodeMotionUtils.cpp
|
||||
OneToNTypeConversion.cpp
|
||||
RegionUtils.cpp
|
||||
WalkPatternRewriteDriver.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
|
||||
|
116
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
Normal file
116
mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
Normal file
@ -0,0 +1,116 @@
|
||||
//===- WalkPatternRewriteDriver.cpp - A fast walk-based rewriter ---------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Implements mlir::walkAndApplyPatterns.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Rewrite/PatternApplicator.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "walk-rewriter"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
namespace {
|
||||
struct WalkAndApplyPatternsAction final
|
||||
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(WalkAndApplyPatternsAction)
|
||||
using ActionImpl::ActionImpl;
|
||||
static constexpr StringLiteral tag = "walk-and-apply-patterns";
|
||||
void print(raw_ostream &os) const override { os << tag; }
|
||||
};
|
||||
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
// Forwarding listener to guard against unsupported erasures of non-descendant
|
||||
// ops/blocks. Because we use walk-based pattern application, erasing the
|
||||
// op/block from the *next* iteration (e.g., a user of the visited op) is not
|
||||
// valid. Note that this is only used with expensive pattern API checks.
|
||||
struct ErasedOpsListener final : RewriterBase::ForwardingListener {
|
||||
using RewriterBase::ForwardingListener::ForwardingListener;
|
||||
|
||||
void notifyOperationErased(Operation *op) override {
|
||||
checkErasure(op);
|
||||
ForwardingListener::notifyOperationErased(op);
|
||||
}
|
||||
|
||||
void notifyBlockErased(Block *block) override {
|
||||
checkErasure(block->getParentOp());
|
||||
ForwardingListener::notifyBlockErased(block);
|
||||
}
|
||||
|
||||
void checkErasure(Operation *op) const {
|
||||
Operation *ancestorOp = op;
|
||||
while (ancestorOp && ancestorOp != visitedOp)
|
||||
ancestorOp = ancestorOp->getParentOp();
|
||||
|
||||
if (ancestorOp != visitedOp)
|
||||
llvm::report_fatal_error(
|
||||
"unsupported erasure in WalkPatternRewriter; "
|
||||
"erasure is only supported for matched ops and their descendants");
|
||||
}
|
||||
|
||||
Operation *visitedOp = nullptr;
|
||||
};
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
} // namespace
|
||||
|
||||
void walkAndApplyPatterns(Operation *op,
|
||||
const FrozenRewritePatternSet &patterns,
|
||||
RewriterBase::Listener *listener) {
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
if (failed(verify(op)))
|
||||
llvm::report_fatal_error("walk pattern rewriter input IR failed to verify");
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
|
||||
MLIRContext *ctx = op->getContext();
|
||||
PatternRewriter rewriter(ctx);
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
ErasedOpsListener erasedListener(listener);
|
||||
rewriter.setListener(&erasedListener);
|
||||
#else
|
||||
rewriter.setListener(listener);
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
|
||||
PatternApplicator applicator(patterns);
|
||||
applicator.applyDefaultCostModel();
|
||||
|
||||
ctx->executeAction<WalkAndApplyPatternsAction>(
|
||||
[&] {
|
||||
for (Region ®ion : op->getRegions()) {
|
||||
region.walk([&](Operation *visitedOp) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
|
||||
llvm::dbgs(), OpPrintingFlags().skipRegions());
|
||||
llvm::dbgs() << "\n";);
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
erasedListener.visitedOp = visitedOp;
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
|
||||
LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
{op});
|
||||
|
||||
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
if (failed(verify(op)))
|
||||
llvm::report_fatal_error(
|
||||
"walk pattern rewriter result IR failed to verify");
|
||||
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
|
||||
}
|
||||
|
||||
} // namespace mlir
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s | mlir-opt -test-patterns | FileCheck %s
|
||||
// RUN: mlir-opt %s | mlir-opt -test-greedy-patterns | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @test_enum_attr_roundtrip
|
||||
func.func @test_enum_attr_roundtrip() -> () {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -test-patterns="max-iterations=1" \
|
||||
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1" \
|
||||
// RUN: -allow-unregistered-dialect --split-input-file | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @add_to_worklist_after_inplace_update()
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt %s -test-patterns="max-iterations=1 top-down=true" \
|
||||
// RUN: mlir-opt %s -test-greedy-patterns="max-iterations=1 top-down=true" \
|
||||
// RUN: --split-input-file | FileCheck %s
|
||||
|
||||
// Tests for https://github.com/llvm/llvm-project/issues/86765. Ensure
|
||||
|
121
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
Normal file
121
mlir/test/IR/test-walk-pattern-rewrite-driver.mlir
Normal file
@ -0,0 +1,121 @@
|
||||
// RUN: mlir-opt %s --test-walk-pattern-rewrite-driver="dump-notifications=true" \
|
||||
// RUN: --allow-unregistered-dialect --split-input-file | FileCheck %s
|
||||
|
||||
// The following op is updated in-place and will not be added back to the worklist.
|
||||
// CHECK-LABEL: func.func @inplace_update()
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
|
||||
func.func @inplace_update() {
|
||||
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
|
||||
"test.any_attr_of_i32_str"() {attr = 1 : i32} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the driver does not fold visited ops.
|
||||
// CHECK-LABEL: func.func @add_no_fold()
|
||||
// CHECK: arith.constant
|
||||
// CHECK: arith.constant
|
||||
// CHECK: %[[RES:.+]] = arith.addi
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @add_no_fold() -> i32 {
|
||||
%c0 = arith.constant 0 : i32
|
||||
%c1 = arith.constant 1 : i32
|
||||
%res = arith.addi %c0, %c1 : i32
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// Check that the driver handles rewriter.moveBefore.
|
||||
// CHECK-LABEL: func.func @move_before(
|
||||
// CHECK: "test.move_before_parent_op"
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: scf.if
|
||||
// CHECK: return
|
||||
func.func @move_before(%cond : i1) {
|
||||
scf.if %cond {
|
||||
"test.move_before_parent_op"() ({
|
||||
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
|
||||
}) : () -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the driver handles rewriter.moveAfter. In this case, we expect
|
||||
// the moved op to be visited only once since walk uses `make_early_inc_range`.
|
||||
// CHECK-LABEL: func.func @move_after(
|
||||
// CHECK: scf.if
|
||||
// CHECK: }
|
||||
// CHECK: "test.move_after_parent_op"
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: return
|
||||
func.func @move_after(%cond : i1) {
|
||||
scf.if %cond {
|
||||
"test.move_after_parent_op"() ({
|
||||
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
|
||||
}) : () -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Check that the driver handles rewriter.moveAfter. In this case, we expect
|
||||
// the moved op to be visited twice since we advance its position to the next
|
||||
// node after the parent.
|
||||
// CHECK-LABEL: func.func @move_forward_and_revisit(
|
||||
// CHECK: scf.if
|
||||
// CHECK: }
|
||||
// CHECK: arith.addi
|
||||
// CHECK: "test.move_after_parent_op"
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 2 : i32}> : () -> ()
|
||||
// CHECK: arith.addi
|
||||
// CHECK: return
|
||||
func.func @move_forward_and_revisit(%cond : i1) {
|
||||
scf.if %cond {
|
||||
"test.move_after_parent_op"() ({
|
||||
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
|
||||
}) {advance = 1 : i32} : () -> ()
|
||||
}
|
||||
%a = arith.addi %cond, %cond : i1
|
||||
%b = arith.addi %a, %cond : i1
|
||||
return
|
||||
}
|
||||
|
||||
// Operation inserted just after the currently visited one won't be visited.
|
||||
// CHECK-LABEL: func.func @insert_just_after
|
||||
// CHECK: "test.clone_me"() ({
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: }) {was_cloned} : () -> ()
|
||||
// CHECK: "test.clone_me"() ({
|
||||
// CHECK: "test.any_attr_of_i32_str"() <{attr = 1 : i32}> : () -> ()
|
||||
// CHECK: }) : () -> ()
|
||||
// CHECK: return
|
||||
func.func @insert_just_after(%cond : i1) {
|
||||
"test.clone_me"() ({
|
||||
"test.any_attr_of_i32_str"() {attr = 0 : i32} : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// Check that we can replace the current operation with a new one.
|
||||
// Note that the new op won't be visited.
|
||||
// CHECK-LABEL: func.func @replace_with_new_op
|
||||
// CHECK: %[[NEW:.+]] = "test.new_op"
|
||||
// CHECK: %[[RES:.+]] = arith.addi %[[NEW]], %[[NEW]]
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @replace_with_new_op() -> i32 {
|
||||
%a = "test.replace_with_new_op"() : () -> (i32)
|
||||
%res = arith.addi %a, %a : i32
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// Check that we can erase nested blocks.
|
||||
// CHECK-LABEL: func.func @erase_nested_block
|
||||
// CHECK: %[[RES:.+]] = "test.erase_first_block"
|
||||
// CHECK-NEXT: foo.bar
|
||||
// CHECK: return %[[RES]]
|
||||
func.func @erase_nested_block() -> i32 {
|
||||
%a = "test.erase_first_block"() ({
|
||||
"foo.foo"() : () -> ()
|
||||
^bb1:
|
||||
"foo.bar"() : () -> ()
|
||||
}): () -> (i32)
|
||||
return %a : i32
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt --pass-pipeline="builtin.module(test-patterns)" %s | FileCheck %s
|
||||
// RUN: mlir-opt --pass-pipeline="builtin.module(test-greedy-patterns)" %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @test_reorder_constants_and_match
|
||||
func.func @test_reorder_constants_and_match(%arg0 : i32) -> (i32) {
|
||||
|
@ -1,5 +1,5 @@
|
||||
// RUN: mlir-opt -test-patterns='top-down=false' %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-patterns='top-down=true' %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-greedy-patterns='top-down=false' %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-greedy-patterns='top-down=true' %s | FileCheck %s
|
||||
|
||||
func.func @foo() -> i32 {
|
||||
%c42 = arith.constant 42 : i32
|
||||
|
@ -13,12 +13,16 @@
|
||||
#include "mlir/Dialect/Func/IR/FuncOps.h"
|
||||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/Visitors.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/ScopeExit.h"
|
||||
#include <cstdint>
|
||||
|
||||
using namespace mlir;
|
||||
using namespace test;
|
||||
@ -214,6 +218,30 @@ struct MoveBeforeParentOp : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
/// This pattern moves "test.move_after_parent_op" after the parent op.
|
||||
struct MoveAfterParentOp : public RewritePattern {
|
||||
MoveAfterParentOp(MLIRContext *context)
|
||||
: RewritePattern("test.move_after_parent_op", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Do not hoist past functions.
|
||||
if (isa<FunctionOpInterface>(op->getParentOp()))
|
||||
return failure();
|
||||
|
||||
int64_t moveForwardBy = 0;
|
||||
if (auto advanceBy = op->getAttrOfType<IntegerAttr>("advance"))
|
||||
moveForwardBy = advanceBy.getInt();
|
||||
|
||||
Operation *moveAfter = op->getParentOp();
|
||||
for (int64_t i = 0; i < moveForwardBy; ++i)
|
||||
moveAfter = moveAfter->getNextNode();
|
||||
|
||||
rewriter.moveOpAfter(op, moveAfter);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// This pattern inlines blocks that are nested in
|
||||
/// "test.inline_blocks_into_parent" into the parent block.
|
||||
struct InlineBlocksIntoParent : public RewritePattern {
|
||||
@ -286,14 +314,65 @@ struct CloneRegionBeforeOp : public RewritePattern {
|
||||
}
|
||||
};
|
||||
|
||||
struct TestPatternDriver
|
||||
: public PassWrapper<TestPatternDriver, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPatternDriver)
|
||||
/// Replace an operation may introduce the re-visiting of its users.
|
||||
class ReplaceWithNewOp : public RewritePattern {
|
||||
public:
|
||||
ReplaceWithNewOp(MLIRContext *context)
|
||||
: RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
|
||||
|
||||
TestPatternDriver() = default;
|
||||
TestPatternDriver(const TestPatternDriver &other) : PassWrapper(other) {}
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *newOp;
|
||||
if (op->hasAttr("create_erase_op")) {
|
||||
newOp = rewriter.create(
|
||||
op->getLoc(),
|
||||
OperationName("test.erase_op", op->getContext()).getIdentifier(),
|
||||
ValueRange(), TypeRange());
|
||||
} else {
|
||||
newOp = rewriter.create(
|
||||
op->getLoc(),
|
||||
OperationName("test.new_op", op->getContext()).getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
}
|
||||
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
|
||||
// A "notifyOperationReplaced" callback is triggered in either case.
|
||||
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
StringRef getArgument() const final { return "test-patterns"; }
|
||||
/// Erases the first child block of the matched "test.erase_first_block"
|
||||
/// operation.
|
||||
class EraseFirstBlock : public RewritePattern {
|
||||
public:
|
||||
EraseFirstBlock(MLIRContext *context)
|
||||
: RewritePattern("test.erase_first_block", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
llvm::errs() << "Num regions: " << op->getNumRegions() << "\n";
|
||||
for (Region &r : op->getRegions()) {
|
||||
for (Block &b : r.getBlocks()) {
|
||||
rewriter.eraseBlock(&b);
|
||||
llvm::errs() << "Erasing block: " << b << "\n";
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct TestGreedyPatternDriver
|
||||
: public PassWrapper<TestGreedyPatternDriver, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGreedyPatternDriver)
|
||||
|
||||
TestGreedyPatternDriver() = default;
|
||||
TestGreedyPatternDriver(const TestGreedyPatternDriver &other)
|
||||
: PassWrapper(other) {}
|
||||
|
||||
StringRef getArgument() const final { return "test-greedy-patterns"; }
|
||||
StringRef getDescription() const final { return "Run test dialect patterns"; }
|
||||
void runOnOperation() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
@ -470,34 +549,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
// Replace an operation may introduce the re-visiting of its users.
|
||||
class ReplaceWithNewOp : public RewritePattern {
|
||||
public:
|
||||
ReplaceWithNewOp(MLIRContext *context)
|
||||
: RewritePattern("test.replace_with_new_op", /*benefit=*/1, context) {}
|
||||
|
||||
LogicalResult matchAndRewrite(Operation *op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
Operation *newOp;
|
||||
if (op->hasAttr("create_erase_op")) {
|
||||
newOp = rewriter.create(
|
||||
op->getLoc(),
|
||||
OperationName("test.erase_op", op->getContext()).getIdentifier(),
|
||||
ValueRange(), TypeRange());
|
||||
} else {
|
||||
newOp = rewriter.create(
|
||||
op->getLoc(),
|
||||
OperationName("test.new_op", op->getContext()).getIdentifier(),
|
||||
op->getOperands(), op->getResultTypes());
|
||||
}
|
||||
// "replaceOp" could be used instead of "replaceAllOpUsesWith"+"eraseOp".
|
||||
// A "notifyOperationReplaced" callback is triggered in either case.
|
||||
rewriter.replaceAllOpUsesWith(op, newOp->getResults());
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Remove an operation may introduce the re-visiting of its operands.
|
||||
class EraseOp : public RewritePattern {
|
||||
public:
|
||||
@ -560,6 +611,39 @@ private:
|
||||
};
|
||||
};
|
||||
|
||||
struct TestWalkPatternDriver final
|
||||
: PassWrapper<TestWalkPatternDriver, OperationPass<>> {
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestWalkPatternDriver)
|
||||
|
||||
TestWalkPatternDriver() = default;
|
||||
TestWalkPatternDriver(const TestWalkPatternDriver &other)
|
||||
: PassWrapper(other) {}
|
||||
|
||||
StringRef getArgument() const override {
|
||||
return "test-walk-pattern-rewrite-driver";
|
||||
}
|
||||
StringRef getDescription() const override {
|
||||
return "Run test walk pattern rewrite driver";
|
||||
}
|
||||
void runOnOperation() override {
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Patterns for testing the WalkPatternRewriteDriver.
|
||||
patterns.add<IncrementIntAttribute<3>, MoveBeforeParentOp,
|
||||
MoveAfterParentOp, CloneOp, ReplaceWithNewOp, EraseFirstBlock>(
|
||||
&getContext());
|
||||
|
||||
DumpNotifications dumpListener;
|
||||
walkAndApplyPatterns(getOperation(), std::move(patterns),
|
||||
dumpNotifications ? &dumpListener : nullptr);
|
||||
}
|
||||
|
||||
Option<bool> dumpNotifications{
|
||||
*this, "dump-notifications",
|
||||
llvm::cl::desc("Print rewrite listener notifications"),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -1978,8 +2062,9 @@ void registerPatternsTestPass() {
|
||||
|
||||
PassRegistration<TestDerivedAttributeDriver>();
|
||||
|
||||
PassRegistration<TestPatternDriver>();
|
||||
PassRegistration<TestGreedyPatternDriver>();
|
||||
PassRegistration<TestStrictPatternDriver>();
|
||||
PassRegistration<TestWalkPatternDriver>();
|
||||
|
||||
PassRegistration<TestLegalizePatternDriver>([] {
|
||||
return std::make_unique<TestLegalizePatternDriver>(legalizerConversionMode);
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: mlir-opt -test-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
|
||||
// RUN: mlir-opt -test-greedy-patterns -mlir-print-debuginfo -mlir-print-local-scope %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: verifyFusedLocs
|
||||
func.func @verifyFusedLocs(%arg0 : i32) -> i32 {
|
||||
|
Loading…
x
Reference in New Issue
Block a user