[mlir][reducer] Add eraseRedundantBlocksInRegion and getSuccessorForwardOperands API to BranchOpInterface (#187864)

To simplify the output of the reduction-tree pass, this PR introduces
the eraseRedundantBlocksInRegion. For regions containing multiple
execution paths, this functionality selects the shortest 'interesting'
path. Additionally, this PR adds the getSuccessorForwardOperands API to
BranchOpInterface. This allows us to extract the ForwardOperands for a
specific path chosen from multiple alternatives, enabling the creation
of a cf.br operation for the redirected jump.
This commit is contained in:
lonely eagle 2026-03-28 15:22:46 +08:00 committed by GitHub
parent 097abb3d64
commit eb53972051
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 233 additions and 4 deletions

View File

@ -65,7 +65,8 @@ def AssertOp : CF_Op<"assert",
//===----------------------------------------------------------------------===//
def BranchOp : CF_Op<"br", [
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
Pure, Terminator
]> {
let summary = "Branch operation";
@ -114,8 +115,8 @@ def BranchOp : CF_Op<"br", [
def CondBranchOp
: CF_Op<"cond_br", [AttrSizedOperandSegments,
DeclareOpInterfaceMethods<
BranchOpInterface, ["getSuccessorForOperands"]>,
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
WeightedBranchOpInterface, Pure, Terminator]> {
let summary = "Conditional branch operation";
let description = [{
@ -241,7 +242,8 @@ def CondBranchOp
def SwitchOp : CF_Op<"switch",
[AttrSizedOperandSegments,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
DeclareOpInterfaceMethods<BranchOpInterface,
["getSuccessorForOperands", "getSuccessorForwardOperands"]>,
Pure, Terminator]> {
let summary = "Switch operation";
let description = [{

View File

@ -98,6 +98,15 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
[{ return lhs == rhs; }]
>,
InterfaceMethod<[{
This method is called to returns the operands of this operation that
are passed to the specified successor's block arguments. If the successor
is not valid for this operation, or no operands are forwarded, an empty
ValueRange is returned.
}],
"ValueRange", "getSuccessorForwardOperands",
(ins "Block *":$successor), [{}],[{ return {};}]
>,
];
let verify = [{

View File

@ -90,6 +90,9 @@ public:
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion,
IRMapping &mapper);
private:
/// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.

View File

@ -296,6 +296,12 @@ Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
return getDest();
}
ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
if (successor == getDest())
return getDestOperands();
return {};
}
//===----------------------------------------------------------------------===//
// CondBranchOp
//===----------------------------------------------------------------------===//
@ -583,6 +589,14 @@ Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
return nullptr;
}
ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
if (successor == getTrueDest())
return getTrueOperands();
else if (successor == getFalseDest())
return getFalseOperands();
return {};
}
//===----------------------------------------------------------------------===//
// SwitchOp
//===----------------------------------------------------------------------===//
@ -1034,6 +1048,16 @@ void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
.add<SimplifyUniformBlockArguments>(context);
}
ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
if (successor == getDefaultDestination())
return getDefaultOperands();
SuccessorRange caseDests = getCaseDestinations();
auto it = llvm::find(caseDests, successor);
if (it == caseDests.end())
return {};
return getCaseOperands(std::distance(caseDests.begin(), it));
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

View File

@ -45,6 +45,16 @@ LogicalResult ReductionNode::initialize(ModuleOp parentModule,
return success();
}
LogicalResult ReductionNode::initialize(ModuleOp parentModule,
Region &targetRegion,
IRMapping &mapper) {
module = cast<ModuleOp>(parentModule->clone(mapper));
// Use the first block of targetRegion to locate the cloned region.
Block *block = mapper.lookup(&*targetRegion.begin());
region = block->getParent();
return success();
}
/// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call.

View File

@ -14,7 +14,10 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
@ -24,6 +27,9 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "reduction-tree"
namespace mlir {
#define GEN_PASS_DEF_REDUCTIONTREEPASS
@ -184,6 +190,113 @@ static LogicalResult eraseAllOpsInRegion(ModuleOp module, Region &region,
return failure();
}
// Returns the first branching terminator (cond_br, switch, etc.) found in the
// region.
static Operation *getBranchTerminatorInRegion(Region &region) {
for (Block &block : region.getBlocks()) {
if (block.getNumSuccessors() > 1)
return block.getTerminator();
}
return {};
}
/// Reduces the control flow in a region by iteratively forcing branching
/// terminators to point to a single successor. It evaluates each potential
/// branch path and commits the reduction that results in the smallest
/// "interesting" module.
static LogicalResult eraseRedundantBlocksInRegion(ModuleOp module,
Region &region,
const Tester &test) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
// node. Thus the root node must be interesting.
if (initStatus.first != Tester::Interestingness::True)
return module.emitWarning() << "uninterested module will not be reduced";
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
// We set the simplification level to Aggressive to enable block merging.
GreedyRewriteConfig config;
config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive);
config.setUseTopDownTraversal(true);
// Populate canonicalization patterns for cf ops. When all targets of a
// 'cf.cond_br' or 'cf.switch' point to the same block, they will be
// canonicalized into a 'cf.br'.
auto context = region.getContext();
RewritePatternSet patterns(context);
cf::BranchOp::getCanonicalizationPatterns(patterns, context);
cf::CondBranchOp::getCanonicalizationPatterns(patterns, context);
cf::SwitchOp::getCanonicalizationPatterns(patterns, context);
FrozenRewritePatternSet fPatterns = std::move(patterns);
ReductionNode *smallestNode = nullptr;
mlir::OpBuilder b(context);
while (Operation *branchTerminator = getBranchTerminatorInRegion(region)) {
size_t numSuccessor = branchTerminator->getNumSuccessors();
// We allocate memory on the heap because the object will be assigned to
// 'smallestNode'.
ReductionNode *root = allocator.Allocate();
std::vector<ReductionNode::Range> ranges{
{0, std::distance(region.op_begin(), region.op_end())}};
// Iterate through each successor of the branching terminator to try
// reducing the control flow to a single-path execution.
int branchIdx = -1;
for (int i = 0, e = numSuccessor; i < e; ++i) {
new (root) ReductionNode(nullptr, ranges, allocator);
mlir::IRMapping mapper;
if (failed(root->initialize(module, region, mapper)))
llvm_unreachable("unexpected initialization failure");
Operation *tergetTerminator = mapper.lookup(branchTerminator);
Block *selectedBlock = tergetTerminator->getSuccessor(i);
auto branchOp = cast<BranchOpInterface>(tergetTerminator);
ValueRange selectedBlockOperands =
branchOp.getSuccessorForwardOperands(selectedBlock);
b.setInsertionPointAfter(tergetTerminator);
cf::BranchOp::create(b, tergetTerminator->getLoc(), selectedBlock,
selectedBlockOperands);
tergetTerminator->erase();
// Apply canonicalization patterns to collapse the now-redundant branches
(void)applyPatternsGreedily(root->getRegion().getParentOp(), fPatterns,
config);
root->update(test.isInteresting(root->getModule()));
// Track the smallest "interesting" version of the IR found so far.
if (root->isInteresting() == Tester::Interestingness::True &&
(smallestNode == nullptr ||
root->getSize() < smallestNode->getSize())) {
smallestNode = root;
branchIdx = i;
}
}
// If an interesting reduced branch was found, commit the change to the
// original region and re-apply patterns for a final cleanup.
if (branchIdx != -1) {
Block *selectedBlock = branchTerminator->getSuccessor(branchIdx);
auto branchOp = cast<BranchOpInterface>(branchTerminator);
ValueRange selectedBlockOperands =
branchOp.getSuccessorForwardOperands(selectedBlock);
b.setInsertionPointAfter(branchTerminator);
cf::BranchOp::create(b, branchTerminator->getLoc(), selectedBlock,
selectedBlockOperands);
branchTerminator->erase();
(void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
}
}
// If no branching terminators were found (skipping the while loop),
// there might still be opportunities for linear block merging or
// We apply patterns here as a final cleanup to ensure the region is fully
// simplified.
if (smallestNode == nullptr)
(void)applyPatternsGreedily(region.getParentOp(), fPatterns, config);
return success();
}
template <typename IteratorType>
static LogicalResult findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
@ -196,6 +309,8 @@ static LogicalResult findOptimal(ModuleOp module, Region &region,
if (succeeded(eraseAllOpsInRegion(module, region, test)))
return success();
(void)eraseRedundantBlocksInRegion(module, region, test);
// In the second phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,

View File

@ -58,3 +58,69 @@ func.func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func.func @simple5() {
return
}
// -----
// CHECK-LABEL: func @br_reduction
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @br_reduction(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])
// -----
// CHECK-LABEL: func @br_reduction_loop
// CHECK-SAME: %[[ARG0:.*]]: i1,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @br_reduction_loop(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cf.cond_br %arg0, ^bb1, ^bb2
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
cf.cond_br %arg0, ^bb3(%1: memref<2xf32>), ^bb4
^bb4:
return
}
// CHECK-NEXT: cf.br ^bb1(%[[ARG1]] : memref<2xf32>)
// CHECK-NEXT: ^bb1(%[[VAL_0:.*]]: memref<2xf32>):
// CHECK-NEXT: "test.op_crash"(%[[VAL_0]], %[[ARG2]])
// CHECK-NEXT: cf.br ^bb1(%[[VAL_0]] : memref<2xf32>)
// -----
// CHECK-LABEL: func @switch_reduction
// CHECK-SAME: %[[ARG0:.*]]: i32,
// CHECK-SAME: %[[ARG1:.*]]: memref<2xf32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<2xf32>) {
func.func @switch_reduction(%arg0: i32, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cf.switch %arg0 : i32, [
default: ^bb3(%arg1 : memref<2xf32>),
0: ^bb1,
1: ^bb2
]
^bb1:
cf.br ^bb3(%arg1 : memref<2xf32>)
^bb2:
%0 = memref.alloc() : memref<2xf32>
cf.br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>):
"test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return
}
// CHECK-NEXT: "test.op_crash"(%[[ARG1]], %[[ARG2]])