To simplify the output of the reduction-tree pass, this PR introduces `eraseRedundantBlocksInRegion`. For regions that contain multiple execution paths, this function keeps only the shortest 'interesting' path. Additionally, this PR adds the `getSuccessorForwardOperands` API to `BranchOpInterface`. It returns the operands forwarded to a specific successor chosen from several alternatives, which makes it possible to create a `cf.br` operation for the redirected jump.
1067 lines
37 KiB
C++
//===- ControlFlowOps.cpp - ControlFlow Operations ------------------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
|
|
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
|
|
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
|
|
#include "mlir/Dialect/UB/IR/UBOps.h"
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/BuiltinTypes.h"
|
|
#include "mlir/IR/IRMapping.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "mlir/IR/Value.h"
|
|
#include "mlir/Transforms/InliningUtils.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include <numeric>
|
|
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOpsDialect.cpp.inc"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::cf;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ControlFlowDialect Interfaces
|
|
//===----------------------------------------------------------------------===//
|
|
namespace {
|
|
/// This class defines the interface for handling inlining with control flow
|
|
/// operations.
|
|
struct ControlFlowInlinerInterface : public DialectInlinerInterface {
|
|
using DialectInlinerInterface::DialectInlinerInterface;
|
|
~ControlFlowInlinerInterface() override = default;
|
|
|
|
/// All control flow operations can be inlined.
|
|
bool isLegalToInline(Operation *call, Operation *callable,
|
|
bool wouldBeCloned) const final {
|
|
return true;
|
|
}
|
|
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
|
|
return true;
|
|
}
|
|
|
|
/// ControlFlow terminator operations don't really need any special handing.
|
|
void handleTerminator(Operation *op, Block *newDest) const final {}
|
|
};
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// ControlFlowDialect
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void ControlFlowDialect::initialize() {
  // Register every operation generated from the ODS definitions.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
      >();
  addInterfaces<ControlFlowInlinerInterface>();

  // Interfaces whose implementations live outside this file are declared as
  // promised, one (interface, concrete entity) pair per call.
  declarePromisedInterface<ConvertToLLVMPatternInterface, ControlFlowDialect>();
  declarePromisedInterface<bufferization::BufferizableOpInterface,
                           BranchOp>();
  declarePromisedInterface<bufferization::BufferizableOpInterface,
                           CondBranchOp>();
  declarePromisedInterface<bufferization::BufferDeallocationOpInterface,
                           CondBranchOp>();
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// AssertOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Removes `cf.assert` when its argument is the constant `true`, since such
/// an assertion can never fire. Any other assertion is left untouched.
LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
  if (!matchPattern(op.getArg(), m_One()))
    return failure();

  rewriter.eraseOp(op);
  return success();
}
|
|
|
|
// This side effect models "program termination".
|
|
void AssertOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
effects.emplace_back(MemoryEffects::Write::get());
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// BranchOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Given a successor, try to collapse it to a new destination if it only
|
|
/// contains a passthrough unconditional branch. If the successor is
|
|
/// collapsable, `successor` and `successorOperands` are updated to reference
|
|
/// the new destination and values. `argStorage` is used as storage if operands
|
|
/// to the collapsed successor need to be remapped. It must outlive uses of
|
|
/// successorOperands.
|
|
static LogicalResult collapseBranch(Block *&successor,
|
|
ValueRange &successorOperands,
|
|
SmallVectorImpl<Value> &argStorage) {
|
|
// Check that the successor only contains a unconditional branch.
|
|
if (std::next(successor->begin()) != successor->end())
|
|
return failure();
|
|
// Check that the terminator is an unconditional branch.
|
|
BranchOp successorBranch = dyn_cast<BranchOp>(successor->getTerminator());
|
|
if (!successorBranch)
|
|
return failure();
|
|
// Check that the arguments are only used within the terminator.
|
|
for (BlockArgument arg : successor->getArguments()) {
|
|
for (Operation *user : arg.getUsers())
|
|
if (user != successorBranch)
|
|
return failure();
|
|
}
|
|
// Don't try to collapse branches to infinite loops.
|
|
Block *successorDest = successorBranch.getDest();
|
|
if (successorDest == successor)
|
|
return failure();
|
|
// Don't try to collapse branches which participate in a cycle.
|
|
BranchOp nextBranch = dyn_cast<BranchOp>(successorDest->getTerminator());
|
|
llvm::DenseSet<Block *> visited{successor, successorDest};
|
|
while (nextBranch) {
|
|
Block *nextBranchDest = nextBranch.getDest();
|
|
if (visited.contains(nextBranchDest))
|
|
return failure();
|
|
visited.insert(nextBranchDest);
|
|
nextBranch = dyn_cast<BranchOp>(nextBranchDest->getTerminator());
|
|
}
|
|
|
|
// Update the operands to the successor. If the branch parent has no
|
|
// arguments, we can use the branch operands directly.
|
|
OperandRange operands = successorBranch.getOperands();
|
|
if (successor->args_empty()) {
|
|
successor = successorDest;
|
|
successorOperands = operands;
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, we need to remap any argument operands.
|
|
for (Value operand : operands) {
|
|
BlockArgument argOperand = llvm::dyn_cast<BlockArgument>(operand);
|
|
if (argOperand && argOperand.getOwner() == successor)
|
|
argStorage.push_back(successorOperands[argOperand.getArgNumber()]);
|
|
else
|
|
argStorage.push_back(operand);
|
|
}
|
|
successor = successorDest;
|
|
successorOperands = argStorage;
|
|
return success();
|
|
}
|
|
|
|
/// Simplify a branch to a block that has a single predecessor. This effectively
|
|
/// merges the two blocks.
|
|
static LogicalResult
|
|
simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
|
|
// Check that the successor block has a single predecessor.
|
|
Block *succ = op.getDest();
|
|
Block *opParent = op->getBlock();
|
|
if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
|
|
return failure();
|
|
|
|
// If any branch operand is itself a block argument of the successor, merging
|
|
// would call replaceAllUsesWith(arg, arg) — a no-op — leaving dangling uses
|
|
// of that argument after the successor block is erased.
|
|
for (Value operand : op.getOperands())
|
|
if (auto ba = dyn_cast<BlockArgument>(operand))
|
|
if (ba.getOwner() == succ)
|
|
return failure();
|
|
|
|
// Merge the successor into the current block and erase the branch.
|
|
SmallVector<Value> brOperands(op.getOperands());
|
|
rewriter.eraseOp(op);
|
|
rewriter.mergeBlocks(succ, opParent, brOperands);
|
|
return success();
|
|
}
|
|
|
|
/// br ^bb1
|
|
/// ^bb1
|
|
/// br ^bbN(...)
|
|
///
|
|
/// -> br ^bbN(...)
|
|
///
|
|
static LogicalResult simplifyPassThroughBr(BranchOp op,
|
|
PatternRewriter &rewriter) {
|
|
Block *dest = op.getDest();
|
|
ValueRange destOperands = op.getOperands();
|
|
SmallVector<Value, 4> destOperandStorage;
|
|
|
|
// Try to collapse the successor if it points somewhere other than this
|
|
// block.
|
|
if (dest == op->getBlock() ||
|
|
failed(collapseBranch(dest, destOperands, destOperandStorage)))
|
|
return failure();
|
|
|
|
// Create a new branch with the collapsed successor.
|
|
rewriter.replaceOpWithNewOp<BranchOp>(op, dest, destOperands);
|
|
return success();
|
|
}
|
|
|
|
/// If all incoming values for a block argument from all predecessors are the
|
|
/// same SSA value, replace uses of the block argument with that value. This
|
|
/// allows the block argument to be removed by dead code elimination.
|
|
///
|
|
/// %c = arith.constant 0 : i32
|
|
/// cf.br ^bb1(%c : i32) // pred 1
|
|
/// cf.br ^bb1(%c : i32) // pred 2
|
|
/// ^bb1(%arg0: i32):
|
|
/// use(%arg0)
|
|
/// ->
|
|
/// ^bb1(%arg0: i32):
|
|
/// use(%c) // %arg0 has no uses and can be removed
|
|
///
|
|
static LogicalResult simplifyUniformBlockArgs(Block *dest,
|
|
PatternRewriter &rewriter) {
|
|
if (dest->hasNoPredecessors() ||
|
|
llvm::hasSingleElement(dest->getPredecessors()))
|
|
return failure();
|
|
|
|
bool changed = false;
|
|
for (BlockArgument arg : dest->getArguments()) {
|
|
if (arg.use_empty())
|
|
continue;
|
|
|
|
Value commonValue;
|
|
for (Block *pred : dest->getPredecessors()) {
|
|
auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator());
|
|
if (!branch) {
|
|
commonValue = Value();
|
|
break;
|
|
}
|
|
|
|
for (auto [i, succ] : llvm::enumerate(branch->getSuccessors())) {
|
|
if (succ != dest)
|
|
continue;
|
|
|
|
// Produced operands are modeled by BranchOpInterface as null Values.
|
|
Value val = branch.getSuccessorOperands(i)[arg.getArgNumber()];
|
|
if (commonValue && commonValue != val) {
|
|
commonValue = Value();
|
|
break;
|
|
}
|
|
commonValue = val;
|
|
}
|
|
|
|
if (!commonValue)
|
|
break;
|
|
}
|
|
|
|
if (commonValue && commonValue != arg) {
|
|
rewriter.replaceAllUsesWith(arg, commonValue);
|
|
changed = true;
|
|
}
|
|
}
|
|
return success(changed);
|
|
}
|
|
|
|
namespace {
|
|
/// Replaces block arguments with a uniform incoming value across all
|
|
/// predecessors, for any op implementing BranchOpInterface.
|
|
struct SimplifyUniformBlockArguments
|
|
: public OpInterfaceRewritePattern<BranchOpInterface> {
|
|
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
|
|
LogicalResult matchAndRewrite(BranchOpInterface op,
|
|
PatternRewriter &rewriter) const override {
|
|
bool changed = false;
|
|
for (Block *succ : op->getSuccessors())
|
|
changed |= succeeded(simplifyUniformBlockArgs(succ, rewriter));
|
|
return success(changed);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Canonicalizes `cf.br` with the first simplification that applies. The
/// first two rewrites erase or replace `op`, so nothing may touch `op` after
/// one of them succeeds.
LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
  if (succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)))
    return success();
  if (succeeded(simplifyPassThroughBr(op, rewriter)))
    return success();
  return simplifyUniformBlockArgs(op.getDest(), rewriter);
}
|
|
|
|
void BranchOp::setDest(Block *block) { return setSuccessor(block); }
|
|
|
|
void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }
|
|
|
|
SuccessorOperands BranchOp::getSuccessorOperands(unsigned index) {
|
|
assert(index == 0 && "invalid successor index");
|
|
return SuccessorOperands(getDestOperandsMutable());
|
|
}
|
|
|
|
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) {
|
|
return getDest();
|
|
}
|
|
|
|
ValueRange BranchOp::getSuccessorForwardOperands(Block *successor) {
|
|
if (successor == getDest())
|
|
return getDestOperands();
|
|
return {};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// CondBranchOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// cf.cond_br true, ^bb1, ^bb2
|
|
/// -> br ^bb1
|
|
/// cf.cond_br false, ^bb1, ^bb2
|
|
/// -> br ^bb2
|
|
///
|
|
struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
if (matchPattern(condbr.getCondition(), m_NonZero())) {
|
|
// True branch taken.
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
|
condbr.getTrueOperands());
|
|
return success();
|
|
}
|
|
if (matchPattern(condbr.getCondition(), m_Zero())) {
|
|
// False branch taken.
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
|
condbr.getFalseOperands());
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
};
|
|
|
|
/// cf.cond_br %cond, ^bb1, ^bb2
|
|
/// ^bb1
|
|
/// br ^bbN(...)
|
|
/// ^bb2
|
|
/// br ^bbK(...)
|
|
///
|
|
/// -> cf.cond_br %cond, ^bbN(...), ^bbK(...)
|
|
///
|
|
struct SimplifyPassThroughCondBranch : public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
Block *trueDest = condbr.getTrueDest(), *falseDest = condbr.getFalseDest();
|
|
ValueRange trueDestOperands = condbr.getTrueOperands();
|
|
ValueRange falseDestOperands = condbr.getFalseOperands();
|
|
SmallVector<Value, 4> trueDestOperandStorage, falseDestOperandStorage;
|
|
|
|
// Try to collapse one of the current successors.
|
|
LogicalResult collapsedTrue =
|
|
collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage);
|
|
LogicalResult collapsedFalse =
|
|
collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage);
|
|
if (failed(collapsedTrue) && failed(collapsedFalse))
|
|
return failure();
|
|
|
|
// Create a new branch with the collapsed successors.
|
|
rewriter.replaceOpWithNewOp<CondBranchOp>(
|
|
condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest,
|
|
falseDestOperands, condbr.getWeights());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// cf.cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N)
|
|
/// -> br ^bb1(A, ..., N)
|
|
///
|
|
/// cf.cond_br %cond, ^bb1(A), ^bb1(B)
|
|
/// -> %select = arith.select %cond, A, B
|
|
/// br ^bb1(%select)
|
|
///
|
|
struct SimplifyCondBranchIdenticalSuccessors
|
|
: public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
// Check that the true and false destinations are the same and have the same
|
|
// operands.
|
|
Block *trueDest = condbr.getTrueDest();
|
|
if (trueDest != condbr.getFalseDest())
|
|
return failure();
|
|
|
|
// If all of the operands match, no selects need to be generated.
|
|
OperandRange trueOperands = condbr.getTrueOperands();
|
|
OperandRange falseOperands = condbr.getFalseOperands();
|
|
if (trueOperands == falseOperands) {
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, trueOperands);
|
|
return success();
|
|
}
|
|
|
|
// Otherwise, if the current block is the only predecessor insert selects
|
|
// for any mismatched branch operands.
|
|
if (trueDest->getUniquePredecessor() != condbr->getBlock())
|
|
return failure();
|
|
|
|
// Generate a select for any operands that differ between the two.
|
|
SmallVector<Value, 8> mergedOperands;
|
|
mergedOperands.reserve(trueOperands.size());
|
|
Value condition = condbr.getCondition();
|
|
for (auto it : llvm::zip(trueOperands, falseOperands)) {
|
|
if (std::get<0>(it) == std::get<1>(it))
|
|
mergedOperands.push_back(std::get<0>(it));
|
|
else
|
|
mergedOperands.push_back(
|
|
arith::SelectOp::create(rewriter, condbr.getLoc(), condition,
|
|
std::get<0>(it), std::get<1>(it)));
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest, mergedOperands);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// ...
|
|
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
|
/// ...
|
|
/// ^bb1: // has single predecessor
|
|
/// ...
|
|
/// cf.cond_br %cond, ^bb3(...), ^bb4(...)
|
|
///
|
|
/// ->
|
|
///
|
|
/// ...
|
|
/// cf.cond_br %cond, ^bb1(...), ^bb2(...)
|
|
/// ...
|
|
/// ^bb1: // has single predecessor
|
|
/// ...
|
|
/// br ^bb3(...)
|
|
///
|
|
struct SimplifyCondBranchFromCondBranchOnSameCondition
|
|
: public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
// Check that we have a single distinct predecessor.
|
|
Block *currentBlock = condbr->getBlock();
|
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
if (!predecessor)
|
|
return failure();
|
|
|
|
// Check that the predecessor terminates with a conditional branch to this
|
|
// block and that it branches on the same condition.
|
|
auto predBranch = dyn_cast<CondBranchOp>(predecessor->getTerminator());
|
|
if (!predBranch || condbr.getCondition() != predBranch.getCondition())
|
|
return failure();
|
|
|
|
// Fold this branch to an unconditional branch.
|
|
if (currentBlock == predBranch.getTrueDest())
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getTrueDest(),
|
|
condbr.getTrueDestOperands());
|
|
else
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, condbr.getFalseDest(),
|
|
condbr.getFalseDestOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// cf.cond_br %arg0, ^trueB, ^falseB
|
|
///
|
|
/// ^trueB:
|
|
/// "test.consumer1"(%arg0) : (i1) -> ()
|
|
/// ...
|
|
///
|
|
/// ^falseB:
|
|
/// "test.consumer2"(%arg0) : (i1) -> ()
|
|
/// ...
|
|
///
|
|
/// ->
|
|
///
|
|
/// cf.cond_br %arg0, ^trueB, ^falseB
|
|
/// ^trueB:
|
|
/// "test.consumer1"(%true) : (i1) -> ()
|
|
/// ...
|
|
///
|
|
/// ^falseB:
|
|
/// "test.consumer2"(%false) : (i1) -> ()
|
|
/// ...
|
|
struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
// Check that we have a single distinct predecessor.
|
|
bool replaced = false;
|
|
Type ty = rewriter.getI1Type();
|
|
|
|
// These variables serve to prevent creating duplicate constants
|
|
// and hold constant true or false values.
|
|
Value constantTrue = nullptr;
|
|
Value constantFalse = nullptr;
|
|
|
|
// TODO These checks can be expanded to encompas any use with only
|
|
// either the true of false edge as a predecessor. For now, we fall
|
|
// back to checking the single predecessor is given by the true/fasle
|
|
// destination, thereby ensuring that only that edge can reach the
|
|
// op.
|
|
if (condbr.getTrueDest()->getSinglePredecessor()) {
|
|
for (OpOperand &use :
|
|
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
|
if (use.getOwner()->getBlock() == condbr.getTrueDest()) {
|
|
replaced = true;
|
|
|
|
if (!constantTrue)
|
|
constantTrue = arith::ConstantOp::create(
|
|
rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(true));
|
|
|
|
rewriter.modifyOpInPlace(use.getOwner(),
|
|
[&] { use.set(constantTrue); });
|
|
}
|
|
}
|
|
}
|
|
if (condbr.getFalseDest()->getSinglePredecessor()) {
|
|
for (OpOperand &use :
|
|
llvm::make_early_inc_range(condbr.getCondition().getUses())) {
|
|
if (use.getOwner()->getBlock() == condbr.getFalseDest()) {
|
|
replaced = true;
|
|
|
|
if (!constantFalse)
|
|
constantFalse = arith::ConstantOp::create(
|
|
rewriter, condbr.getLoc(), ty, rewriter.getBoolAttr(false));
|
|
|
|
rewriter.modifyOpInPlace(use.getOwner(),
|
|
[&] { use.set(constantFalse); });
|
|
}
|
|
}
|
|
}
|
|
return success(replaced);
|
|
}
|
|
};
|
|
|
|
/// If the destination block of a conditional branch contains only
|
|
/// ub.unreachable, unconditionally branch to the other destination.
|
|
struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
|
|
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CondBranchOp condbr,
|
|
PatternRewriter &rewriter) const override {
|
|
// If the "true" destination is unreachable, branch to the "false"
|
|
// destination.
|
|
Block *trueDest = condbr.getTrueDest();
|
|
Block *falseDest = condbr.getFalseDest();
|
|
if (llvm::hasSingleElement(*trueDest) &&
|
|
isa<ub::UnreachableOp>(trueDest->getTerminator())) {
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
|
|
condbr.getFalseOperands());
|
|
return success();
|
|
}
|
|
|
|
// If the "false" destination is unreachable, branch to the "true"
|
|
// destination.
|
|
if (llvm::hasSingleElement(*falseDest) &&
|
|
isa<ub::UnreachableOp>(falseDest->getTerminator())) {
|
|
rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
|
|
condbr.getTrueOperands());
|
|
return success();
|
|
}
|
|
|
|
return failure();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
/// Registers the `cf.cond_br` canonicalization patterns. They are added in
/// two calls, but the overall registration order is the same as a single
/// list.
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
              SimplifyCondBranchIdenticalSuccessors,
              SimplifyCondBranchFromCondBranchOnSameCondition>(context);
  results.add<CondBranchTruthPropagation, DropUnreachableCondBranch,
              SimplifyUniformBlockArguments>(context);
}
|
|
|
|
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
|
|
assert(index < getNumSuccessors() && "invalid successor index");
|
|
return SuccessorOperands(index == trueIndex ? getTrueDestOperandsMutable()
|
|
: getFalseDestOperandsMutable());
|
|
}
|
|
|
|
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
|
if (IntegerAttr condAttr =
|
|
llvm::dyn_cast_or_null<IntegerAttr>(operands.front()))
|
|
return condAttr.getValue().isOne() ? getTrueDest() : getFalseDest();
|
|
return nullptr;
|
|
}
|
|
|
|
ValueRange CondBranchOp::getSuccessorForwardOperands(Block *successor) {
|
|
if (successor == getTrueDest())
|
|
return getTrueOperands();
|
|
else if (successor == getFalseDest())
|
|
return getFalseOperands();
|
|
return {};
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// SwitchOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Builds a `cf.switch` from a prebuilt case-value attribute by forwarding to
/// the builder that takes operands first, then the case-value attribute, then
/// the successors.
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
                     Block *defaultDestination, ValueRange defaultOperands,
                     DenseIntElementsAttr caseValues,
                     BlockRange caseDestinations,
                     ArrayRef<ValueRange> caseOperands) {
  build(builder, result, value, defaultOperands, caseOperands, caseValues,
        defaultDestination, caseDestinations);
}
|
|
|
|
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
|
|
Block *defaultDestination, ValueRange defaultOperands,
|
|
ArrayRef<APInt> caseValues, BlockRange caseDestinations,
|
|
ArrayRef<ValueRange> caseOperands) {
|
|
DenseIntElementsAttr caseValuesAttr;
|
|
if (!caseValues.empty()) {
|
|
ShapedType caseValueType = VectorType::get(
|
|
static_cast<int64_t>(caseValues.size()), value.getType());
|
|
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
|
|
}
|
|
build(builder, result, value, defaultDestination, defaultOperands,
|
|
caseValuesAttr, caseDestinations, caseOperands);
|
|
}
|
|
|
|
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
|
|
Block *defaultDestination, ValueRange defaultOperands,
|
|
ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
|
|
ArrayRef<ValueRange> caseOperands) {
|
|
DenseIntElementsAttr caseValuesAttr;
|
|
if (!caseValues.empty()) {
|
|
ShapedType caseValueType = VectorType::get(
|
|
static_cast<int64_t>(caseValues.size()), value.getType());
|
|
caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
|
|
}
|
|
build(builder, result, value, defaultDestination, defaultOperands,
|
|
caseValuesAttr, caseDestinations, caseOperands);
|
|
}
|
|
|
|
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
|
|
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
|
|
static ParseResult parseSwitchOpCases(
|
|
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
|
|
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
|
|
SmallVectorImpl<Type> &defaultOperandTypes,
|
|
DenseIntElementsAttr &caseValues,
|
|
SmallVectorImpl<Block *> &caseDestinations,
|
|
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>> &caseOperands,
|
|
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
|
|
if (parser.parseKeyword("default") || parser.parseColon() ||
|
|
parser.parseSuccessor(defaultDestination))
|
|
return failure();
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
|
|
/*allowResultNumber=*/false) ||
|
|
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
SmallVector<APInt> values;
|
|
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
|
|
while (succeeded(parser.parseOptionalComma())) {
|
|
int64_t value = 0;
|
|
if (failed(parser.parseInteger(value)))
|
|
return failure();
|
|
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
|
|
|
|
Block *destination;
|
|
SmallVector<OpAsmParser::UnresolvedOperand> operands;
|
|
SmallVector<Type> operandTypes;
|
|
if (failed(parser.parseColon()) ||
|
|
failed(parser.parseSuccessor(destination)))
|
|
return failure();
|
|
if (succeeded(parser.parseOptionalLParen())) {
|
|
if (failed(parser.parseOperandList(operands,
|
|
OpAsmParser::Delimiter::None)) ||
|
|
failed(parser.parseColonTypeList(operandTypes)) ||
|
|
failed(parser.parseRParen()))
|
|
return failure();
|
|
}
|
|
caseDestinations.push_back(destination);
|
|
caseOperands.emplace_back(operands);
|
|
caseOperandTypes.emplace_back(operandTypes);
|
|
}
|
|
|
|
if (!values.empty()) {
|
|
ShapedType caseValueType =
|
|
VectorType::get(static_cast<int64_t>(values.size()), flagType);
|
|
caseValues = DenseIntElementsAttr::get(caseValueType, values);
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Prints the `default` successor and then one `value: successor` entry per
/// case on its own line. Case values are printed with APInt::getLimitedValue
/// (unsigned, capped at UINT64_MAX). Without a case-value attribute, printing
/// stops right after the default successor, with no trailing newline.
static void printSwitchOpCases(
    OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
    OperandRange defaultOperands, TypeRange defaultOperandTypes,
    DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
    OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes) {
  p << " default: ";
  p.printSuccessorAndUseList(defaultDestination, defaultOperands);

  if (!caseValues)
    return;

  size_t caseIdx = 0;
  for (const APInt &caseValue : caseValues.getValues<APInt>()) {
    p << ',';
    p.printNewline();
    p << " ";
    p << caseValue.getLimitedValue();
    p << ": ";
    p.printSuccessorAndUseList(caseDestinations[caseIdx],
                               caseOperands[caseIdx]);
    ++caseIdx;
  }
  p.printNewline();
}
|
|
|
|
/// Verifies that the case-value attribute, if any, has the flag's element
/// type and one entry per case destination. A switch with neither case
/// values nor case destinations is valid.
LogicalResult SwitchOp::verify() {
  auto caseValues = getCaseValues();
  auto caseDestinations = getCaseDestinations();

  // Without a case-value attribute there must be no case destinations. This
  // also guards the `caseValues->` dereferences below, which previously ran
  // on an empty optional when destinations were present.
  if (!caseValues) {
    if (caseDestinations.empty())
      return success();
    return emitOpError() << "number of case values (0) should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  }

  Type flagType = getFlag().getType();
  Type caseValueType = caseValues->getType().getElementType();
  if (caseValueType != flagType)
    return emitOpError() << "'flag' type (" << flagType
                         << ") should match case value type (" << caseValueType
                         << ")";

  if (caseValues->size() != static_cast<int64_t>(caseDestinations.size()))
    return emitOpError() << "number of case values (" << caseValues->size()
                         << ") should match number of "
                            "case destinations ("
                         << caseDestinations.size() << ")";
  return success();
}
|
|
|
|
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
|
|
assert(index < getNumSuccessors() && "invalid successor index");
|
|
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
|
|
: getCaseOperandsMutable(index - 1));
|
|
}
|
|
|
|
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
|
|
std::optional<DenseIntElementsAttr> caseValues = getCaseValues();
|
|
|
|
if (!caseValues)
|
|
return getDefaultDestination();
|
|
|
|
SuccessorRange caseDests = getCaseDestinations();
|
|
if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
|
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>()))
|
|
if (it.value() == value.getValue())
|
|
return caseDests[it.index()];
|
|
return getDefaultDestination();
|
|
}
|
|
return nullptr;
|
|
}
|
|
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1
|
|
/// ]
|
|
/// -> br ^bb1
|
|
static LogicalResult simplifySwitchWithOnlyDefault(SwitchOp op,
|
|
PatternRewriter &rewriter) {
|
|
if (!op.getCaseDestinations().empty())
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
op.getDefaultOperands());
|
|
return success();
|
|
}
|
|
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb1,
|
|
/// 43: ^bb2
|
|
/// ]
|
|
/// ->
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 43: ^bb2
|
|
/// ]
|
|
static LogicalResult
|
|
dropSwitchCasesThatMatchDefault(SwitchOp op, PatternRewriter &rewriter) {
|
|
SmallVector<Block *> newCaseDestinations;
|
|
SmallVector<ValueRange> newCaseOperands;
|
|
SmallVector<APInt> newCaseValues;
|
|
bool requiresChange = false;
|
|
auto caseValues = op.getCaseValues();
|
|
auto caseDests = op.getCaseDestinations();
|
|
|
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
if (caseDests[it.index()] == op.getDefaultDestination() &&
|
|
op.getCaseOperands(it.index()) == op.getDefaultOperands()) {
|
|
requiresChange = true;
|
|
continue;
|
|
}
|
|
newCaseDestinations.push_back(caseDests[it.index()]);
|
|
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
|
newCaseValues.push_back(it.value());
|
|
}
|
|
|
|
if (!requiresChange)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<SwitchOp>(
|
|
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
|
newCaseValues, newCaseDestinations, newCaseOperands);
|
|
return success();
|
|
}
|
|
|
|
/// Helper for folding a switch with a constant value.
|
|
/// switch %c_42 : i32, [
|
|
/// default: ^bb1 ,
|
|
/// 42: ^bb2,
|
|
/// 43: ^bb3
|
|
/// ]
|
|
/// -> br ^bb2
|
|
static void foldSwitch(SwitchOp op, PatternRewriter &rewriter,
|
|
const APInt &caseValue) {
|
|
auto caseValues = op.getCaseValues();
|
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
if (it.value() == caseValue) {
|
|
rewriter.replaceOpWithNewOp<BranchOp>(
|
|
op, op.getCaseDestinations()[it.index()],
|
|
op.getCaseOperands(it.index()));
|
|
return;
|
|
}
|
|
}
|
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
op.getDefaultOperands());
|
|
}
|
|
|
|
/// switch %c_42 : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// 43: ^bb3
|
|
/// ]
|
|
/// -> br ^bb2
|
|
static LogicalResult simplifyConstSwitchValue(SwitchOp op,
|
|
PatternRewriter &rewriter) {
|
|
APInt caseValue;
|
|
if (!matchPattern(op.getFlag(), m_ConstantInt(&caseValue)))
|
|
return failure();
|
|
|
|
foldSwitch(op, rewriter, caseValue);
|
|
return success();
|
|
}
|
|
|
|
/// switch %c_42 : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb2:
|
|
/// br ^bb3
|
|
/// ->
|
|
/// switch %c_42 : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb3,
|
|
/// ]
|
|
static LogicalResult simplifyPassThroughSwitch(SwitchOp op,
|
|
PatternRewriter &rewriter) {
|
|
SmallVector<Block *> newCaseDests;
|
|
SmallVector<ValueRange> newCaseOperands;
|
|
SmallVector<SmallVector<Value>> argStorage;
|
|
auto caseValues = op.getCaseValues();
|
|
argStorage.reserve(caseValues->size() + 1);
|
|
auto caseDests = op.getCaseDestinations();
|
|
bool requiresChange = false;
|
|
for (int64_t i = 0, size = caseValues->size(); i < size; ++i) {
|
|
Block *caseDest = caseDests[i];
|
|
ValueRange caseOperands = op.getCaseOperands(i);
|
|
argStorage.emplace_back();
|
|
if (succeeded(collapseBranch(caseDest, caseOperands, argStorage.back())))
|
|
requiresChange = true;
|
|
|
|
newCaseDests.push_back(caseDest);
|
|
newCaseOperands.push_back(caseOperands);
|
|
}
|
|
|
|
Block *defaultDest = op.getDefaultDestination();
|
|
ValueRange defaultOperands = op.getDefaultOperands();
|
|
argStorage.emplace_back();
|
|
|
|
if (succeeded(
|
|
collapseBranch(defaultDest, defaultOperands, argStorage.back())))
|
|
requiresChange = true;
|
|
|
|
if (!requiresChange)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<SwitchOp>(op, op.getFlag(), defaultDest,
|
|
defaultOperands, *caseValues,
|
|
newCaseDests, newCaseOperands);
|
|
return success();
|
|
}
|
|
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb2:
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb3,
|
|
/// 42: ^bb4
|
|
/// ]
|
|
/// ->
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb2:
|
|
/// br ^bb4
|
|
///
|
|
/// and
|
|
///
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb2:
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb3,
|
|
/// 43: ^bb4
|
|
/// ]
|
|
/// ->
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb2:
|
|
/// br ^bb3
|
|
static LogicalResult
|
|
simplifySwitchFromSwitchOnSameCondition(SwitchOp op,
|
|
PatternRewriter &rewriter) {
|
|
// Check that we have a single distinct predecessor.
|
|
Block *currentBlock = op->getBlock();
|
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
if (!predecessor)
|
|
return failure();
|
|
|
|
// Check that the predecessor terminates with a switch branch to this block
|
|
// and that it branches on the same condition and that this branch isn't the
|
|
// default destination.
|
|
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
|
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
|
predSwitch.getDefaultDestination() == currentBlock)
|
|
return failure();
|
|
|
|
// Fold this switch to an unconditional branch.
|
|
SuccessorRange predDests = predSwitch.getCaseDestinations();
|
|
auto it = llvm::find(predDests, currentBlock);
|
|
if (it != predDests.end()) {
|
|
std::optional<DenseIntElementsAttr> predCaseValues =
|
|
predSwitch.getCaseValues();
|
|
foldSwitch(op, rewriter,
|
|
predCaseValues->getValues<APInt>()[it - predDests.begin()]);
|
|
} else {
|
|
rewriter.replaceOpWithNewOp<BranchOp>(op, op.getDefaultDestination(),
|
|
op.getDefaultOperands());
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2
|
|
/// ]
|
|
/// ^bb1:
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb3,
|
|
/// 42: ^bb4,
|
|
/// 43: ^bb5
|
|
/// ]
|
|
/// ->
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb1,
|
|
/// 42: ^bb2,
|
|
/// ]
|
|
/// ^bb1:
|
|
/// switch %flag : i32, [
|
|
/// default: ^bb3,
|
|
/// 43: ^bb5
|
|
/// ]
|
|
static LogicalResult
|
|
simplifySwitchFromDefaultSwitchOnSameCondition(SwitchOp op,
|
|
PatternRewriter &rewriter) {
|
|
// Check that we have a single distinct predecessor.
|
|
Block *currentBlock = op->getBlock();
|
|
Block *predecessor = currentBlock->getSinglePredecessor();
|
|
if (!predecessor)
|
|
return failure();
|
|
|
|
// Check that the predecessor terminates with a switch branch to this block
|
|
// and that it branches on the same condition and that this branch is the
|
|
// default destination.
|
|
auto predSwitch = dyn_cast<SwitchOp>(predecessor->getTerminator());
|
|
if (!predSwitch || op.getFlag() != predSwitch.getFlag() ||
|
|
predSwitch.getDefaultDestination() != currentBlock)
|
|
return failure();
|
|
|
|
// Delete case values that are not possible here.
|
|
DenseSet<APInt> caseValuesToRemove;
|
|
auto predDests = predSwitch.getCaseDestinations();
|
|
auto predCaseValues = predSwitch.getCaseValues();
|
|
for (int64_t i = 0, size = predCaseValues->size(); i < size; ++i)
|
|
if (currentBlock != predDests[i])
|
|
caseValuesToRemove.insert(predCaseValues->getValues<APInt>()[i]);
|
|
|
|
SmallVector<Block *> newCaseDestinations;
|
|
SmallVector<ValueRange> newCaseOperands;
|
|
SmallVector<APInt> newCaseValues;
|
|
bool requiresChange = false;
|
|
|
|
auto caseValues = op.getCaseValues();
|
|
auto caseDests = op.getCaseDestinations();
|
|
for (const auto &it : llvm::enumerate(caseValues->getValues<APInt>())) {
|
|
if (caseValuesToRemove.contains(it.value())) {
|
|
requiresChange = true;
|
|
continue;
|
|
}
|
|
newCaseDestinations.push_back(caseDests[it.index()]);
|
|
newCaseOperands.push_back(op.getCaseOperands(it.index()));
|
|
newCaseValues.push_back(it.value());
|
|
}
|
|
|
|
if (!requiresChange)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<SwitchOp>(
|
|
op, op.getFlag(), op.getDefaultDestination(), op.getDefaultOperands(),
|
|
newCaseValues, newCaseDestinations, newCaseOperands);
|
|
return success();
|
|
}
|
|
|
|
/// Registers the canonicalization patterns for `cf.switch`: the function-based
/// simplifications defined in this file, followed by the
/// SimplifyUniformBlockArguments pattern.
void SwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(&simplifySwitchWithOnlyDefault);
  results.add(&dropSwitchCasesThatMatchDefault);
  results.add(&simplifyConstSwitchValue);
  results.add(&simplifyPassThroughSwitch);
  results.add(&simplifySwitchFromSwitchOnSameCondition);
  results.add(&simplifySwitchFromDefaultSwitchOnSameCondition);
  results.add<SimplifyUniformBlockArguments>(context);
}
|
|
|
|
/// Returns the operands this switch forwards to `successor`.
///
/// Lookup order: if `successor` is the default destination, the default
/// operands are returned even when the same block is also a case target.
/// Otherwise the operands of the first case targeting `successor` are
/// returned. If `successor` is not a destination of this switch, an empty
/// range is returned.
ValueRange SwitchOp::getSuccessorForwardOperands(Block *successor) {
  if (getDefaultDestination() == successor)
    return getDefaultOperands();

  for (auto [caseIdx, caseDest] : llvm::enumerate(getCaseDestinations()))
    if (caseDest == successor)
      return getCaseOperands(caseIdx);

  return {};
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TableGen'd op method definitions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.cpp.inc"
|