llvm-project/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
Matthias Springer c4750d0575
[mlir] Consolidate patterns into RegionBranchOpInterface patterns (#174094)
Instead of op-specific cleanup patterns for region branch ops to remove
unused results / block arguments, etc., add a set of patterns that can
handle all `RegionBranchOpInterface` ops. These patterns are enabled
only for selected SCF dialect ops at the moment:
* `scf.execute_region`
* `scf.for`
* `scf.if`
* `scf.index_switch`
* `scf.while`

It is currently not possible to register canoncalization patterns for op
interfaces and some ops have incorrect interface implementations. In
follow-up PRs, the set of ops will be gradually extended within the SCF
dialect (`scf.forall`) and across other dialects
(`gpu.warp_execute_on_lane0`, (maybe) various affine dialect ops, ...),
and maybe eventually to apply to all `RegionBranchOpInterface` ops.

This commit removes many similar canonicalization patterns from the SCF
dialect. The newly added canonicalization patterns allow users to get
the same canonicalizations for free for their own ops. And even a few
additional new canonicalizations
([example](https://github.com/llvm/llvm-project/pull/174094/files#diff-54318cd685386d5519c42be49818e388b09d934edcbe4280548baa3601802977R2241),
[example](https://github.com/llvm/llvm-project/pull/174094/files#diff-54318cd685386d5519c42be49818e388b09d934edcbe4280548baa3601802977R1101),
...).

Implementation outline: This commit adds 3 canonicalization patterns.
* `MakeRegionBranchOpSuccessorInputsDead`: Remove uses of successor
inputs, by swapping them for successor operand values.
* `RemoveDuplicateSuccessorInputUses`: Remove uses of successor inputs
that are duplicates. (Similar to `WhileRemoveDuplicatedResults` in the
SCF dialect.)
* `RemoveDeadRegionBranchOpSuccessorInputs`: Remove dead successor
inputs if all of their "tied" successor inputs are also dead. (Similar
to `WhileUnusedResult` in the SCF dialect.)
2026-01-13 07:22:09 +00:00

1039 lines
41 KiB
C++

//===- ControlFlowInterfaces.cpp - ControlFlow Interfaces -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include <utility>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/DebugLog.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
: producedOperandCount(0), forwardedOperands(std::move(forwardedOperands)) {
}
SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
MutableOperandRange forwardedOperands)
: producedOperandCount(producedOperandCount),
forwardedOperands(std::move(forwardedOperands)) {}
//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//
/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or
/// std::nullopt if `operandIndex` isn't a successor operand index.
std::optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
unsigned operandIndex, Block *successor) {
LDBG() << "Getting branch successor argument for operand index "
<< operandIndex << " in successor block";
OperandRange forwardedOperands = operands.getForwardedOperands();
// Check that the operands are valid.
if (forwardedOperands.empty()) {
LDBG() << "No forwarded operands, returning nullopt";
return std::nullopt;
}
// Check to ensure that this operand is within the range.
unsigned operandsStart = forwardedOperands.getBeginOperandIndex();
if (operandIndex < operandsStart ||
operandIndex >= (operandsStart + forwardedOperands.size())) {
LDBG() << "Operand index " << operandIndex << " out of range ["
<< operandsStart << ", "
<< (operandsStart + forwardedOperands.size())
<< "), returning nullopt";
return std::nullopt;
}
// Index the successor.
unsigned argIndex =
operands.getProducedOperandCount() + operandIndex - operandsStart;
LDBG() << "Computed argument index " << argIndex << " for successor block";
return successor->getArgument(argIndex);
}
/// Verify that the given operands match those of the given successor block.
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
LDBG() << "Verifying branch successor operands for successor #" << succNo
<< " in operation " << op->getName();
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
LDBG() << "Branch has " << operandCount << " operands, target block has "
<< destBB->getNumArguments() << " arguments";
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
<< ", but target block has "
<< destBB->getNumArguments();
// Check the types.
LDBG() << "Checking type compatibility for "
<< (operandCount - operands.getProducedOperandCount())
<< " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
Type operandType = operands[i].getType();
Type argType = destBB->getArgument(i).getType();
LDBG() << "Checking type compatibility: operand type " << operandType
<< " vs argument type " << argType;
if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
LDBG() << "Branch successor operand verification successful";
return success();
}
//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//
static LogicalResult verifyWeights(Operation *op,
llvm::ArrayRef<int32_t> weights,
std::size_t expectedWeightsNum,
llvm::StringRef weightAnchorName,
llvm::StringRef weightRefName) {
if (weights.empty())
return success();
if (weights.size() != expectedWeightsNum)
return op->emitError() << "expects number of " << weightAnchorName
<< " weights to match number of " << weightRefName
<< ": " << weights.size() << " vs "
<< expectedWeightsNum;
if (llvm::all_of(weights, [](int32_t value) { return value == 0; }))
return op->emitError() << "branch weights cannot all be zero";
return success();
}
LogicalResult detail::verifyBranchWeights(Operation *op) {
llvm::ArrayRef<int32_t> weights =
cast<WeightedBranchOpInterface>(op).getWeights();
return verifyWeights(op, weights, op->getNumSuccessors(), "branch",
"successors");
}
//===----------------------------------------------------------------------===//
// WeightedRegionBranchOpInterface
//===----------------------------------------------------------------------===//
LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
llvm::ArrayRef<int32_t> weights =
cast<WeightedRegionBranchOpInterface>(op).getWeights();
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
}
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
/// Verify that types match along control flow edges described the given op.
LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) {
auto regionInterface = cast<RegionBranchOpInterface>(op);
// Verify all control flow edges from region branch points to region
// successors.
SmallVector<RegionBranchPoint> regionBranchPoints =
regionInterface.getAllRegionBranchPoints();
for (const RegionBranchPoint &branchPoint : regionBranchPoints) {
SmallVector<RegionSuccessor> successors;
regionInterface.getSuccessorRegions(branchPoint, successors);
for (const RegionSuccessor &successor : successors) {
// Helper function that print the region branch point and the region
// successor.
auto emitRegionEdgeError = [&]() {
InFlightDiagnostic diag =
regionInterface->emitOpError("along control flow edge from ");
if (branchPoint.isParent()) {
diag << "parent";
diag.attachNote(op->getLoc()) << "region branch point";
} else {
diag << "Operation "
<< branchPoint.getTerminatorPredecessorOrNull()->getName();
diag.attachNote(
branchPoint.getTerminatorPredecessorOrNull()->getLoc())
<< "region branch point";
}
diag << " to ";
if (Region *region = successor.getSuccessor()) {
diag << "Region #" << region->getRegionNumber();
} else {
diag << "parent";
}
return diag;
};
// Verify number of successor operands and successor inputs.
OperandRange succOperands =
regionInterface.getSuccessorOperands(branchPoint, successor);
ValueRange succInputs = successor.getSuccessorInputs();
if (succOperands.size() != succInputs.size()) {
return emitRegionEdgeError()
<< ": region branch point has " << succOperands.size()
<< " operands, but region successor needs " << succInputs.size()
<< " inputs";
}
// Verify that the types are compatible.
TypeRange succInputTypes = succInputs.getTypes();
TypeRange succOperandTypes = succOperands.getTypes();
for (const auto &typesIdx :
llvm::enumerate(llvm::zip(succOperandTypes, succInputTypes))) {
Type succOperandType = std::get<0>(typesIdx.value());
Type succInputType = std::get<1>(typesIdx.value());
if (!regionInterface.areTypesCompatible(succOperandType, succInputType))
return emitRegionEdgeError()
<< ": successor operand type #" << typesIdx.index() << " "
<< succOperandType << " should match successor input type #"
<< typesIdx.index() << " " << succInputType;
}
}
}
return success();
}
/// Stop condition for `traverseRegionGraph`. The traversal is interrupted if
/// this function returns "true" for a successor region. The first parameter is
/// the successor region. The second parameter indicates all already visited
/// regions.
using StopConditionFn = function_ref<bool(Region *, ArrayRef<bool> visited)>;
/// Traverse the region graph starting at `begin`. The traversal is interrupted
/// if `stopCondition` evaluates to "true" for a successor region. In that case,
/// this function returns "true". Otherwise, if the traversal was not
/// interrupted, this function returns "false".
static bool traverseRegionGraph(Region *begin,
StopConditionFn stopConditionFn) {
auto op = cast<RegionBranchOpInterface>(begin->getParentOp());
LDBG() << "Starting region graph traversal from region #"
<< begin->getRegionNumber() << " in operation " << op->getName();
SmallVector<bool> visited(op->getNumRegions(), false);
visited[begin->getRegionNumber()] = true;
LDBG() << "Initialized visited array with " << op->getNumRegions()
<< " regions";
// Retrieve all successors of the region and enqueue them in the worklist.
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
SmallVector<Attribute> operandAttributes(op->getNumOperands());
for (Block &block : *region) {
if (block.empty())
continue;
auto terminator =
dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
if (!terminator)
continue;
SmallVector<RegionSuccessor> successors;
operandAttributes.resize(terminator->getNumOperands());
terminator.getSuccessorRegions(operandAttributes, successors);
LDBG() << "Found " << successors.size()
<< " successors from terminator in block";
for (RegionSuccessor successor : successors) {
if (!successor.isParent()) {
worklist.push_back(successor.getSuccessor());
LDBG() << "Added region #"
<< successor.getSuccessor()->getRegionNumber()
<< " to worklist";
} else {
LDBG() << "Skipping parent successor";
}
}
}
};
enqueueAllSuccessors(begin);
LDBG() << "Initial worklist size: " << worklist.size();
// Process all regions in the worklist via DFS.
while (!worklist.empty()) {
Region *nextRegion = worklist.pop_back_val();
LDBG() << "Processing region #" << nextRegion->getRegionNumber()
<< " from worklist (remaining: " << worklist.size() << ")";
if (stopConditionFn(nextRegion, visited)) {
LDBG() << "Stop condition met for region #"
<< nextRegion->getRegionNumber() << ", returning true";
return true;
}
if (!nextRegion->getParentOp()) {
llvm::errs() << "Region " << *nextRegion << " has no parent op\n";
return false;
}
if (visited[nextRegion->getRegionNumber()]) {
LDBG() << "Region #" << nextRegion->getRegionNumber()
<< " already visited, skipping";
continue;
}
visited[nextRegion->getRegionNumber()] = true;
LDBG() << "Marking region #" << nextRegion->getRegionNumber()
<< " as visited";
enqueueAllSuccessors(nextRegion);
}
LDBG() << "Traversal completed, returning false";
return false;
}
/// Return `true` if region `r` is reachable from region `begin` according to
/// the RegionBranchOpInterface (by taking a branch).
static bool isRegionReachable(Region *begin, Region *r) {
assert(begin->getParentOp() == r->getParentOp() &&
"expected that both regions belong to the same op");
return traverseRegionGraph(begin,
[&](Region *nextRegion, ArrayRef<bool> visited) {
// Interrupt traversal if `r` was reached.
return nextRegion == r;
});
}
/// Return `true` if `a` and `b` are in mutually exclusive regions.
///
/// 1. Find the first common of `a` and `b` (ancestor) that implements
/// RegionBranchOpInterface.
/// 2. Determine the regions `regionA` and `regionB` in which `a` and `b` are
/// contained.
/// 3. Check if `regionA` and `regionB` are mutually exclusive. They are
/// mutually exclusive if they are not reachable from each other as per
/// RegionBranchOpInterface::getSuccessorRegions.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
LDBG() << "Checking if operations are in mutually exclusive regions: "
<< a->getName() << " and " << b->getName();
assert(a && "expected non-empty operation");
assert(b && "expected non-empty operation");
auto branchOp = a->getParentOfType<RegionBranchOpInterface>();
while (branchOp) {
LDBG() << "Checking branch operation " << branchOp->getName();
// Check if b is inside branchOp. (We already know that a is.)
if (!branchOp->isProperAncestor(b)) {
LDBG() << "Operation b is not inside branchOp, checking next ancestor";
// Check next enclosing RegionBranchOpInterface.
branchOp = branchOp->getParentOfType<RegionBranchOpInterface>();
continue;
}
LDBG() << "Both operations are inside branchOp, finding their regions";
// b is contained in branchOp. Retrieve the regions in which `a` and `b`
// are contained.
Region *regionA = nullptr, *regionB = nullptr;
for (Region &r : branchOp->getRegions()) {
if (r.findAncestorOpInRegion(*a)) {
assert(!regionA && "already found a region for a");
regionA = &r;
LDBG() << "Found region #" << r.getRegionNumber() << " for operation a";
}
if (r.findAncestorOpInRegion(*b)) {
assert(!regionB && "already found a region for b");
regionB = &r;
LDBG() << "Found region #" << r.getRegionNumber() << " for operation b";
}
}
assert(regionA && regionB && "could not find region of op");
LDBG() << "Region A: #" << regionA->getRegionNumber() << ", Region B: #"
<< regionB->getRegionNumber();
// `a` and `b` are in mutually exclusive regions if both regions are
// distinct and neither region is reachable from the other region.
bool regionsAreDistinct = (regionA != regionB);
bool aNotReachableFromB = !isRegionReachable(regionA, regionB);
bool bNotReachableFromA = !isRegionReachable(regionB, regionA);
LDBG() << "Regions distinct: " << regionsAreDistinct
<< ", A not reachable from B: " << aNotReachableFromB
<< ", B not reachable from A: " << bNotReachableFromA;
bool mutuallyExclusive =
regionsAreDistinct && aNotReachableFromB && bNotReachableFromA;
LDBG() << "Operations are mutually exclusive: " << mutuallyExclusive;
return mutuallyExclusive;
}
// Could not find a common RegionBranchOpInterface among a's and b's
// ancestors.
LDBG() << "No common RegionBranchOpInterface found, operations are not "
"mutually exclusive";
return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
LDBG() << "Checking if region #" << index << " is repetitive in operation "
<< getOperation()->getName();
Region *region = &getOperation()->getRegion(index);
bool isRepetitive = isRegionReachable(region, region);
LDBG() << "Region #" << index << " is repetitive: " << isRepetitive;
return isRepetitive;
}
bool RegionBranchOpInterface::hasLoop() {
LDBG() << "Checking if operation " << getOperation()->getName()
<< " has loops";
SmallVector<RegionSuccessor> entryRegions;
getSuccessorRegions(RegionBranchPoint::parent(), entryRegions);
LDBG() << "Found " << entryRegions.size() << " entry regions";
for (RegionSuccessor successor : entryRegions) {
if (!successor.isParent()) {
LDBG() << "Checking entry region #"
<< successor.getSuccessor()->getRegionNumber() << " for loops";
bool hasLoop =
traverseRegionGraph(successor.getSuccessor(),
[](Region *nextRegion, ArrayRef<bool> visited) {
// Interrupt traversal if the region was already
// visited.
return visited[nextRegion->getRegionNumber()];
});
if (hasLoop) {
LDBG() << "Found loop in entry region #"
<< successor.getSuccessor()->getRegionNumber();
return true;
}
} else {
LDBG() << "Skipping parent successor";
}
}
LDBG() << "No loops found in operation";
return false;
}
OperandRange
RegionBranchOpInterface::getSuccessorOperands(RegionBranchPoint src,
RegionSuccessor dest) {
if (src.isParent())
return getEntrySuccessorOperands(dest);
auto terminator = cast<RegionBranchTerminatorOpInterface>(
src.getTerminatorPredecessorOrNull());
return terminator.getSuccessorOperands(dest);
}
static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
return MutableArrayRef<OpOperand>(operands.getBase(), operands.size());
}
static void
getSuccessorOperandInputMapping(RegionBranchOpInterface branchOp,
RegionBranchSuccessorMapping &mapping,
RegionBranchPoint src) {
SmallVector<RegionSuccessor> successors;
branchOp.getSuccessorRegions(src, successors);
for (RegionSuccessor dst : successors) {
OperandRange operands = branchOp.getSuccessorOperands(src, dst);
assert(operands.size() == dst.getSuccessorInputs().size() &&
"expected the same number of operands and inputs");
for (const auto &[operand, input] : llvm::zip_equal(
operandsToOpOperands(operands), dst.getSuccessorInputs()))
mapping[&operand].push_back(input);
}
}
void RegionBranchOpInterface::getSuccessorOperandInputMapping(
RegionBranchSuccessorMapping &mapping,
std::optional<RegionBranchPoint> src) {
if (src.has_value()) {
::getSuccessorOperandInputMapping(*this, mapping, src.value());
} else {
// No region branch point specified: populate the mapping for all possible
// region branch points.
for (RegionBranchPoint branchPoint : getAllRegionBranchPoints())
::getSuccessorOperandInputMapping(*this, mapping, branchPoint);
}
}
static RegionBranchInverseSuccessorMapping invertRegionBranchSuccessorMapping(
const RegionBranchSuccessorMapping &operandToInputs) {
RegionBranchInverseSuccessorMapping inputToOperands;
for (const auto &[operand, inputs] : operandToInputs) {
for (Value input : inputs)
inputToOperands[input].push_back(operand);
}
return inputToOperands;
}
void RegionBranchOpInterface::getSuccessorInputOperandMapping(
RegionBranchInverseSuccessorMapping &mapping) {
RegionBranchSuccessorMapping operandToInputs;
getSuccessorOperandInputMapping(operandToInputs);
mapping = invertRegionBranchSuccessorMapping(operandToInputs);
}
SmallVector<RegionBranchPoint>
RegionBranchOpInterface::getAllRegionBranchPoints() {
SmallVector<RegionBranchPoint> branchPoints;
branchPoints.push_back(RegionBranchPoint::parent());
for (Region &region : getOperation()->getRegions()) {
for (Block &block : region) {
if (block.empty())
continue;
if (auto terminator =
dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
branchPoints.push_back(RegionBranchPoint(terminator));
}
}
return branchPoints;
}
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
LDBG() << "Finding enclosing repetitive region for operation "
<< op->getName();
while (Region *region = op->getParentRegion()) {
LDBG() << "Checking region #" << region->getRegionNumber()
<< " in operation " << region->getParentOp()->getName();
op = region->getParentOp();
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
LDBG()
<< "Found RegionBranchOpInterface, checking if region is repetitive";
if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
}
} else {
LDBG() << "Parent operation does not implement RegionBranchOpInterface";
}
}
LDBG() << "No enclosing repetitive region found";
return nullptr;
}
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
LDBG() << "Finding enclosing repetitive region for value";
Region *region = value.getParentRegion();
while (region) {
LDBG() << "Checking region #" << region->getRegionNumber()
<< " in operation " << region->getParentOp()->getName();
Operation *op = region->getParentOp();
if (auto branchOp = dyn_cast<RegionBranchOpInterface>(op)) {
LDBG()
<< "Found RegionBranchOpInterface, checking if region is repetitive";
if (branchOp.isRepetitiveRegion(region->getRegionNumber())) {
LDBG() << "Found repetitive region #" << region->getRegionNumber();
return region;
}
} else {
LDBG() << "Parent operation does not implement RegionBranchOpInterface";
}
region = op->getParentRegion();
}
LDBG() << "No enclosing repetitive region found for value";
return nullptr;
}
/// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
/// successor input and `a` is a "reachable value" of `b`. Reachable values
/// are successor operand values that are (maybe transitively) forwarded to
/// `b`.
static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
assert((b.getDefiningOp() == regionBranchOp ||
b.getParentRegion()->getParentOp() == regionBranchOp) &&
"b must be a region successor input");
// Case 1: `a` is defined inside of the region branch op. `a` must be
// directly nested in the region branch op. Otherwise, it could not have
// been among the reachable values for a region successor input.
if (a.getParentRegion()->getParentOp() == regionBranchOp) {
// Case 1.1: If `b` is a result of the region branch op, `a` is not in
// scope for `b`.
// Example:
// %b = region_op({
// ^bb0(%a1: ...):
// %a2 = ...
// })
if (isa<OpResult>(b))
return false;
// Case 1.2: `b` is an entry block argument of a region. `a` is in scope
// for `b` only if it is also an entry block argument of the same region.
// Example:
// region_op({
// ^bb0(%b: ..., %a: ...):
// ...
// })
assert(isa<BlockArgument>(b) && "b must be a block argument");
return isa<BlockArgument>(a) && cast<BlockArgument>(a).getOwner() ==
cast<BlockArgument>(b).getOwner();
}
// Case 2: `a` is defined outside of the region branch op. In that case, we
// can safely assume that `a` was defined before `b`. Otherwise, it could not
// be among the reachable values for a region successor input.
// Example:
// { <- %a1 parent region begins here.
// ^bb0(%a1: ...):
// %a2 = ...
// %b1 = reigon_op({
// ^bb1(%b2: ...):
// ...
// })
// }
return true;
}
/// Compute all non-successor-input values that a successor input could have
/// based on the given successor input to successor operand mapping.
///
/// Example 1:
/// %r = scf.if ... {
/// scf.yield %a : ...
/// } else {
/// scf.yield %b : ...
/// }
/// reachableValues(%r) = {%a, %b}
///
/// Example 2:
/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
/// scf.yield %arg0 : ...
/// }
/// reachableValues(%arg0) = {%0}
/// reachableValues(%r) = {%0}
///
/// Example 3:
/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
/// ...
/// scf.yield %1 : ...
/// }
/// reachableValues(%arg0) = {%0, %1}
/// reachableValues(%r) = {%0, %1}
static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
assert(inputToOperands.contains(value) && "value must be a successor input");
// Starting with the given value, trace back all predecessor values (i.e.,
// preceding successor operands) and add them to the set of reachable values.
// If the successor operand is again a successor input, do not add it to
// result set, but instead continue the traversal.
llvm::SmallDenseSet<Value> reachableValues;
llvm::SmallDenseSet<Value> visited;
SmallVector<Value> worklist;
worklist.push_back(value);
while (!worklist.empty()) {
Value next = worklist.pop_back_val();
auto it = inputToOperands.find(next);
if (it == inputToOperands.end()) {
reachableValues.insert(next);
continue;
}
for (OpOperand *operand : it->second)
if (visited.insert(operand->get()).second)
worklist.push_back(operand->get());
}
// Note: The result does not contain any successor inputs. (Therefore,
// `value` is also guaranteed to be excluded.)
return reachableValues;
}
namespace {
/// Try to make successor inputs dead by replacing their uses with values that
/// are not successor inputs. This pattern enables additional canonicalization
/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
///
/// Example:
///
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
/// scf.yield %arg1, %arg1 : ...
/// }
/// use(%r0, %r1)
///
/// reachableValues(%r0) = {%0, %1}
/// reachableValues(%r1) = {%1} ==> replace uses of %r1 with %1.
/// reachableValues(%arg0) = {%0, %1}
/// reachableValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
///
/// IR after pattern application:
///
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
/// scf.yield %1, %1 : ...
/// }
/// use(%r0, %1)
///
/// Note that %r1 and %arg1 are dead now. The IR can now be further
/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
PatternBenefit benefit = 1)
: RewritePattern(name, benefit, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"isolated-from-above ops are not supported");
// Compute the mapping of successor inputs to successor operands.
auto regionBranchOp = cast<RegionBranchOpInterface>(op);
RegionBranchInverseSuccessorMapping inputToOperands;
regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
// Try to replace the uses of each successor input one-by-one.
bool changed = false;
for (Value value : inputToOperands.keys()) {
// Nothing to do for successor inputs that are already dead.
if (value.use_empty())
continue;
// Nothing to do for successor inputs that may have multiple reachable
// values.
llvm::SmallDenseSet<Value> reachableValues =
computeReachableValuesFromSuccessorInput(value, inputToOperands);
if (reachableValues.size() != 1)
continue;
assert(*reachableValues.begin() != value &&
"successor inputs are supposed to be excluded");
// Do not replace `value` with the found reachable value if doing so
// would violate dominance. Example:
// %r = scf.execute_region ... {
// %a = ...
// scf.yield %a : ...
// }
// use(%r)
// In the above example, reachableValues(%r) = {%a}, but %a cannot be
// used as a replacement for %r due to dominance / scope.
if (!isDefinedBefore(regionBranchOp, *reachableValues.begin(), value))
continue;
rewriter.replaceAllUsesWith(value, *reachableValues.begin());
changed = true;
}
return success(changed);
}
};
/// Lookup a bit vector in the given mapping (DenseMap). If the key was not
/// found, create a new bit vector with the given size and initialize it with
/// false.
template <typename MappingTy, typename KeyTy>
static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
unsigned size) {
return mapping.try_emplace(key, size, false).first->second;
}
/// Compute tied successor inputs. Tied successor inputs are successor inputs
/// that come as a set. If you erase one value from a set, you must erase all
/// values from the set. Otherwise, the op would become structurally invalid.
/// Each successor input appears in exactly one set.
///
/// Example:
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
/// ...
/// }
/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
const RegionBranchSuccessorMapping &operandToInputs) {
llvm::EquivalenceClasses<Value> tiedSuccessorInputs;
for (const auto &[operand, inputs] : operandToInputs) {
assert(!inputs.empty() && "expected non-empty inputs");
Value firstInput = inputs.front();
tiedSuccessorInputs.insert(firstInput);
for (Value nextInput : llvm::drop_begin(inputs)) {
// As we explore more successor operand to successor input mappings,
// existing sets may get merged.
tiedSuccessorInputs.unionSets(firstInput, nextInput);
}
}
return tiedSuccessorInputs;
}
/// Remove dead successor inputs from region branch ops. A successor input is
/// dead if it has no uses. Successor inputs come in sets of tied values: if
/// you remove one value from a set, you must remove all values from the set.
/// Furthermore, successor operands must also be removed. (Op operands are not
/// part of the set, but the set is built based on the successor operand to
/// successor input mapping.)
///
/// Example 1:
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
/// scf.yield %0, %arg1 : ...
/// }
/// use(%0, %1)
///
/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
/// resulting IR is as follows:
///
/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
/// scf.yield %arg1 : ...
/// }
/// use(%0, %1)
///
/// Example 2:
/// %r0, %r1 = scf.while (%arg0 = %0) {
/// scf.condition(...) %arg0, %arg0 : ...
/// } do {
/// ^bb0(%arg1: ..., %arg2: ...):
/// scf.yield %arg1 : ...
/// }
/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%r0}}.
///
/// Example 3:
/// %r1, %r2 = scf.if ... {
/// scf.yield %0, %1 : ...
/// } else {
/// scf.yield %2, %3 : ...
/// }
/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so there each
/// value can be removed independently of the other values.
struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
PatternBenefit benefit = 1)
: RewritePattern(name, benefit, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"isolated-from-above ops are not supported");
// Compute tied values: values that must come as a set. If you remove one,
// you must remove all. If a successor op operand is forwarded to two
// successor inputs %a and %b, both %a and %b are in the same set.
auto regionBranchOp = cast<RegionBranchOpInterface>(op);
RegionBranchSuccessorMapping operandToInputs;
regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
computeTiedSuccessorInputs(operandToInputs);
// Determine which values to remove and group them by block and operation.
SmallVector<Value> valuesToRemove;
DenseMap<Block *, BitVector> blockArgsToRemove;
BitVector resultsToRemove(regionBranchOp->getNumResults(), false);
// Iterate over all sets of tied successor inputs.
for (auto it = tiedSuccessorInputs.begin(), e = tiedSuccessorInputs.end();
it != e; ++it) {
if (!(*it)->isLeader())
continue;
// Value can be removed if it is dead and all other tied values are also
// dead.
bool allDead = true;
for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
// Iterate over all values in the set and check their liveness.
if (!memberIt->use_empty()) {
allDead = false;
break;
}
}
if (!allDead)
continue;
// The entire set is dead. Group values by block and operation to
// simplify removal.
for (auto memberIt = tiedSuccessorInputs.member_begin(**it);
memberIt != tiedSuccessorInputs.member_end(); ++memberIt) {
if (auto arg = dyn_cast<BlockArgument>(*memberIt)) {
// Set blockArgsToRemove[block][arg_number] = true.
BitVector &vector =
lookupOrCreateBitVector(blockArgsToRemove, arg.getOwner(),
arg.getOwner()->getNumArguments());
vector.set(arg.getArgNumber());
} else {
// Set resultsToRemove[result_number] = true.
OpResult result = cast<OpResult>(*memberIt);
assert(result.getDefiningOp() == regionBranchOp &&
"result must be a region branch op result");
resultsToRemove.set(result.getResultNumber());
}
valuesToRemove.push_back(*memberIt);
}
}
if (valuesToRemove.empty())
return rewriter.notifyMatchFailure(op, "no values to remove");
// Find operands that must be removed together with the values.
RegionBranchInverseSuccessorMapping inputsToOperands =
invertRegionBranchSuccessorMapping(operandToInputs);
DenseMap<Operation *, llvm::BitVector> operandsToRemove;
for (Value value : valuesToRemove) {
for (OpOperand *operand : inputsToOperands[value]) {
// Set operandsToRemove[op][operand_number] = true.
BitVector &vector =
lookupOrCreateBitVector(operandsToRemove, operand->getOwner(),
operand->getOwner()->getNumOperands());
vector.set(operand->getOperandNumber());
}
}
// Erase operands.
for (auto &pair : operandsToRemove) {
Operation *op = pair.first;
BitVector &operands = pair.second;
rewriter.modifyOpInPlace(op, [&]() { op->eraseOperands(operands); });
}
// Erase block arguments.
for (auto &pair : blockArgsToRemove) {
Block *block = pair.first;
BitVector &blockArg = pair.second;
rewriter.modifyOpInPlace(block->getParentOp(),
[&]() { block->eraseArguments(blockArg); });
}
// Erase op results.
if (resultsToRemove.any())
rewriter.eraseOpResults(regionBranchOp, resultsToRemove);
return success();
}
};
/// Return "true" if the two values are owned by the same operation or block.
static bool haveSameOwner(Value a, Value b) {
void *aOwner, *bOwner;
if (auto arg = dyn_cast<BlockArgument>(a))
aOwner = arg.getOwner();
else
aOwner = a.getDefiningOp();
if (auto arg = dyn_cast<BlockArgument>(b))
bOwner = arg.getOwner();
else
bOwner = b.getDefiningOp();
return aOwner == bOwner;
}
/// Get the block argument or op result number of the given value.
static unsigned getArgOrResultNumber(Value value) {
if (auto opResult = llvm::dyn_cast<OpResult>(value))
return opResult.getResultNumber();
return llvm::cast<BlockArgument>(value).getArgNumber();
}
/// Find duplicate successor inputs and make all dead except for one. Two
/// successor inputs are "duplicate" if their corresponding successor operands
/// have the same values. This pattern enables additional canonicalization
/// opportunities for RemoveDeadRegionBranchOpSuccessorInputs.
///
/// Example:
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
/// use(%arg0, %arg1)
/// ...
/// scf.yield %x, %x : ...
/// }
/// use(%r0, %r1)
///
/// Operands of successor input %r0: [%0, %x]
/// Operands of successor input %r1: [%0, %x] ==> DUPLICATE!
/// Replace %r1 with %r0.
///
/// Operands of successor input %arg0: [%0, %x]
/// Operands of successor input %arg1: [%0, %x] ==> DUPLICATE!
/// Replace %arg1 with %arg0. (We have to make sure that we make same decision
/// as for the other tied successor inputs above. Otherwise, a set of tied
/// successor inputs may not become entirely dead.)
///
/// The resulting IR is as follows:
/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %0) -> ... {
/// use(%arg0, %arg0)
/// ...
/// scf.yield %x, %x : ...
/// }
/// use(%r0, %r0) // Note: We don't want use(%r1, %r1), which is also correct,
/// // but does not help with further canonicalizations.
struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
RemoveDuplicateSuccessorInputUses(MLIRContext *context, StringRef name,
PatternBenefit benefit = 1)
: RewritePattern(name, benefit, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
"isolated-from-above ops are not supported");
// Collect all successor inputs and sort them. When dropping the uses of a
// successor input, we'd like to also drop the uses of the same tied
// successor inputs. Otherwise, a set of tied successor inputs may not
// become entirely dead, which is required for
// RemoveDeadRegionBranchOpSuccessorInputs to be able to erase them.
// (Sorting is not required for correctness.)
auto regionBranchOp = cast<RegionBranchOpInterface>(op);
RegionBranchInverseSuccessorMapping inputsToOperands;
regionBranchOp.getSuccessorInputOperandMapping(inputsToOperands);
SmallVector<Value> inputs = llvm::to_vector(inputsToOperands.keys());
llvm::sort(inputs, [](Value a, Value b) {
return getArgOrResultNumber(a) < getArgOrResultNumber(b);
});
// Check every distinct pair of successor inputs for duplicates. Replace
// `input2` with `input1` if they are duplicates.
bool changed = false;
unsigned numInputs = inputs.size();
for (auto i : llvm::seq<unsigned>(0, numInputs)) {
Value input1 = inputs[i];
for (auto j : llvm::seq<unsigned>(i + 1, numInputs)) {
Value input2 = inputs[j];
// Nothing to do if input2 is already dead.
if (input2.use_empty())
continue;
// Replace only values that belong to the same block / operation.
// This implies that the two values are either both block arguments or
// both op results.
if (!haveSameOwner(input1, input2))
continue;
// Gather the predecessor value for each predecessor (region branch
// point). The two inputs are duplicates if each predecessor forwards
// the same value.
llvm::SmallDenseMap<Operation *, Value> operands1, operands2;
for (OpOperand *operand : inputsToOperands[input1]) {
assert(!operands1.contains(operand->getOwner()));
operands1[operand->getOwner()] = operand->get();
}
for (OpOperand *operand : inputsToOperands[input2]) {
assert(!operands2.contains(operand->getOwner()));
operands2[operand->getOwner()] = operand->get();
}
if (operands1 == operands2) {
rewriter.replaceAllUsesWith(input2, input1);
changed = true;
}
}
}
return success(changed);
}
};
} // namespace
void mlir::populateRegionBranchOpInterfaceCanonicalizationPatterns(
RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
patterns.add<MakeRegionBranchOpSuccessorInputsDead,
RemoveDuplicateSuccessorInputUses,
RemoveDeadRegionBranchOpSuccessorInputs>(patterns.getContext(),
opName, benefit);
}