378 lines
14 KiB
C++
378 lines
14 KiB
C++
//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "llvm/Transforms/Scalar/LoopTermFold.h"
|
|
#include "llvm/ADT/Statistic.h"
|
|
#include "llvm/Analysis/LoopAnalysisManager.h"
|
|
#include "llvm/Analysis/LoopInfo.h"
|
|
#include "llvm/Analysis/LoopPass.h"
|
|
#include "llvm/Analysis/MemorySSA.h"
|
|
#include "llvm/Analysis/MemorySSAUpdater.h"
|
|
#include "llvm/Analysis/ScalarEvolution.h"
|
|
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
|
|
#include "llvm/Analysis/TargetLibraryInfo.h"
|
|
#include "llvm/Analysis/TargetTransformInfo.h"
|
|
#include "llvm/Analysis/ValueTracking.h"
|
|
#include "llvm/Config/llvm-config.h"
|
|
#include "llvm/IR/BasicBlock.h"
|
|
#include "llvm/IR/Dominators.h"
|
|
#include "llvm/IR/IRBuilder.h"
|
|
#include "llvm/IR/InstrTypes.h"
|
|
#include "llvm/IR/Instruction.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/IR/Value.h"
|
|
#include "llvm/InitializePasses.h"
|
|
#include "llvm/Pass.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
#include "llvm/Transforms/Scalar.h"
|
|
#include "llvm/Transforms/Utils.h"
|
|
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
|
#include "llvm/Transforms/Utils/Local.h"
|
|
#include "llvm/Transforms/Utils/LoopUtils.h"
|
|
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
|
|
#include <cassert>
|
|
#include <optional>
|
|
|
|
using namespace llvm;
|
|
|
|
#define DEBUG_TYPE "loop-term-fold"
|
|
|
|
// Incremented once for every loop whose latch exit test this file rewrites.
STATISTIC(NumTermFold,
          "Number of terminating condition fold recognized and performed");
|
|
|
|
/// Decide whether the latch exit test of \p L can be re-expressed in terms of
/// a different header PHI than the one it currently uses.
///
/// \param L   Loop to inspect; must be innermost and in loop-simplify form or
///            the query fails.
/// \param SE  Used for the backedge-taken count and for the SCEV of each
///            candidate PHI.
/// \param DT  Forwarded to the poison-safety queries.
/// \param LI  Currently unused by the body.
/// \param TTI Forwarded to the expansion cost query.
///
/// \returns std::nullopt if any check fails. Otherwise a tuple of:
///   - the PHI of the recurrence that feeds the current exit test,
///   - the alternate header PHI chosen to drive the new exit test,
///   - the SCEV of that alternate IV's post-increment expression evaluated at
///     the backedge-taken count, and
///   - whether poison-generating flags have to be dropped from the alternate
///     IV's backedge instruction before a new use of it is added.
static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                      const LoopInfo &LI, const TargetTransformInfo &TTI) {
  if (!L->isInnermost()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
    return std::nullopt;
  }
  // Only inspect on simple loop structure. The latch and the preheader are
  // dereferenced below without null checks.
  if (!L->isLoopSimplifyForm()) {
    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
    return std::nullopt;
  }

  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
    return std::nullopt;
  }

  BasicBlock *LoopLatch = L->getLoopLatch();
  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
  if (!BI || BI->isUnconditional())
    return std::nullopt;
  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
  if (!TermCond) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot fold on branching condition that is not an ICmpInst\n");
    return std::nullopt;
  }
  if (!TermCond->hasOneUse()) {
    LLVM_DEBUG(
        dbgs()
        << "Cannot replace terminating condition with more than one use\n");
    return std::nullopt;
  }

  BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
  Value *RHS = TermCond->getOperand(1);
  if (!LHS || !L->isLoopInvariant(RHS))
    // We could pattern match the inverse form of the icmp, but that is
    // non-canonical, and this pass is running *very* late in the pipeline.
    return std::nullopt;

  // Find the IV used by the current exit condition.
  PHINode *ToFold = nullptr;
  Value *ToFoldStart = nullptr, *ToFoldStep = nullptr;
  if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
    return std::nullopt;

  // Ensure the simple recurrence is a part of the current loop.
  if (ToFold->getParent() != L->getHeader())
    return std::nullopt;

  // If that IV isn't dead after we rewrite the exit condition in terms of
  // another IV, there's no point in doing the transform.
  if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
    return std::nullopt;

  // Inserting instructions in the preheader has a runtime cost, scale
  // the allowed cost with the loop's trip count as best we can.
  const unsigned ExpansionBudget = [&]() {
    unsigned Budget = 2 * SCEVCheapExpansionBudget;
    if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
      return std::min(Budget, SmallTC);
    if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
      return std::min(Budget, *SmallTC);
    // Unknown trip count, assume long running by default.
    return Budget;
  }();

  const SCEV *BECount = SE.getBackedgeTakenCount(L);
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  PHINode *ToHelpFold = nullptr;
  const SCEV *TermValueS = nullptr;
  bool MustDropPoison = false;
  auto InsertPt = L->getLoopPreheader()->getTerminator();
  for (PHINode &PN : L->getHeader()->phis()) {
    if (ToFold == &PN)
      continue;

    if (!SE.isSCEVable(PN.getType())) {
      LLVM_DEBUG(dbgs() << "IV of phi '" << PN
                        << "' is not SCEV-able, not qualified for the "
                           "terminating condition folding.\n");
      continue;
    }
    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
    // Only speculate on affine AddRec
    if (!AddRec || !AddRec->isAffine()) {
      LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
                        << "' is not an affine add recursion, not qualified "
                           "for the terminating condition folding.\n");
      continue;
    }

    // Check that we can compute the value of AddRec on the exiting iteration
    // without soundness problems.  evaluateAtIteration internally needs
    // to multiply the stride of the iteration number - which may wrap around.
    // The issue here is subtle because computing the result accounting for
    // wrap is insufficient. In order to use the result in an exit test, we
    // must also know that AddRec doesn't take the same value on any previous
    // iteration. The simplest case to consider is a candidate IV which is
    // narrower than the trip count (and thus original IV), but this can
    // also happen due to non-unit strides on the candidate IVs.
    if (!AddRec->hasNoSelfWrap() ||
        !SE.isKnownNonZero(AddRec->getStepRecurrence(SE)))
      continue;

    const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
    const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
    if (!Expander.isSafeToExpand(TermValueSLocal)) {
      LLVM_DEBUG(
          dbgs() << "Is not safe to expand terminating value for phi node" << PN
                 << "\n");
      continue;
    }

    if (Expander.isHighCostExpansion(TermValueSLocal, L, ExpansionBudget, &TTI,
                                     InsertPt)) {
      LLVM_DEBUG(
          dbgs() << "Is too expensive to expand terminating value for phi node"
                 << PN << "\n");
      continue;
    }

    // The candidate IV may have been otherwise dead and poison from the
    // very first iteration.  If we can't disprove that, we can't use the IV.
    if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
      continue;
    }

    // The candidate IV may become poison on the last iteration.  If this
    // value is not branched on, this is a well defined program.  We're
    // about to add a new use to this IV, and we have to ensure we don't
    // insert UB which didn't previously exist.
    bool MustDropPoisonLocal = false;
    Instruction *PostIncV =
        cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
    if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
                                       &DT)) {
      LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
                        << "\n");

      // If this is a complex recurrence with multiple instructions computing
      // the backedge value, we might need to strip poison flags from all of
      // them.
      if (PostIncV->getOperand(0) != &PN)
        continue;

      // In order to perform the transform, we need to drop the poison
      // generating flags on this instruction (if any).
      MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
    }

    // We pick the last legal alternate IV.  We could explore choosing an
    // optimal alternate IV if we had a decent heuristic to do so.
    ToHelpFold = &PN;
    TermValueS = TermValueSLocal;
    MustDropPoison = MustDropPoisonLocal;
  }

  // ToFold is known to be non-null here: matchSimpleRecurrence succeeded and
  // the pointer has already been dereferenced above. Only the alternate IV
  // can still be missing.
  if (!ToHelpFold) {
    LLVM_DEBUG(dbgs() << "Cannot find other AddRec IV to help folding\n");
    return std::nullopt;
  }

  LLVM_DEBUG(dbgs() << "\nFound loop that can fold terminating condition\n"
                    << " BECount (SCEV): " << *BECount << "\n"
                    << " TermCond: " << *TermCond << "\n"
                    << " BranchInst: " << *BI << "\n"
                    << " ToFold: " << *ToFold << "\n"
                    << " ToHelpFold: " << *ToHelpFold << "\n");

  return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}
|
|
|
|
/// Shared driver for both pass-manager entry points: query
/// canFoldTermCondOfLoop and, if it succeeds, rewrite the latch branch of
/// \p L to compare the alternate IV's backedge value against a terminating
/// value expanded at the end of the preheader.
///
/// \param MSSA May be null. When non-null, a MemorySSAUpdater built on it is
///             handed to DeleteDeadPHIs.
/// \returns true if the loop was modified, false if the fold was rejected.
static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                        LoopInfo &LI, const TargetTransformInfo &TTI,
                        TargetLibraryInfo &TLI, MemorySSA *MSSA) {
  auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
  if (!Opt)
    return false;

  auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;

  NumTermFold++;

  BasicBlock *LoopPreheader = L->getLoopPreheader();
  BasicBlock *LoopLatch = L->getLoopLatch();

  // ToFold is only referenced from debug output; keep builds without
  // LLVM_DEBUG free of unused-variable warnings.
  (void)ToFold;
  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
                    << *ToFold << "\n"
                    << "New term-cond phi-node:\n"
                    << *ToHelpFold << "\n");

  Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
  (void)StartValue; // Debug output only, same reason as above.
  Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);

  // See comment in canFoldTermCondOfLoop on why this is sufficient.
  if (MustDrop)
    cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();

  // SCEVExpander for both use in preheader and latch
  const DataLayout &DL = L->getHeader()->getDataLayout();
  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");

  assert(Expander.isSafeToExpand(TermValueS) &&
         "Terminating value was checked safe in canFoldTermCondOfLoop");

  // Create new terminating value at loop preheader
  Value *TermValue = Expander.expandCodeFor(TermValueS, ToHelpFold->getType(),
                                            LoopPreheader->getTerminator());

  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
                    << *StartValue << "\n"
                    << "Terminating value of new term-cond phi-node:\n"
                    << *TermValue << "\n");

  // Create new terminating condition at loop latch
  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
  Value *NewTermCond =
      LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
                              "lsr_fold_term_cond.replaced_term_cond");
  // Swap successors to exit loop body if IV equals to new TermValue
  if (BI->getSuccessor(0) == L->getHeader())
    BI->swapSuccessors();

  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
                    << *OldTermCond << "\n"
                    << "New term-cond:\n"
                    << *NewTermCond << "\n");

  BI->setCondition(NewTermCond);

  Expander.clear();
  OldTermCond->eraseFromParent();

  // The updater is only needed for the PHI cleanup, so build it here rather
  // than allocating one for every loop that ends up being rejected.
  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
  DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
  return true;
}
|
|
|
|
namespace {
|
|
|
|
class LoopTermFold : public LoopPass {
|
|
public:
|
|
static char ID; // Pass ID, replacement for typeid
|
|
|
|
LoopTermFold();
|
|
|
|
private:
|
|
bool runOnLoop(Loop *L, LPPassManager &LPM) override;
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override;
|
|
};
|
|
|
|
} // end anonymous namespace
|
|
|
|
/// Construct the pass and run its initializer against the global pass
/// registry.
LoopTermFold::LoopTermFold() : LoopPass(ID) {
  PassRegistry &Registry = *PassRegistry::getPassRegistry();
  initializeLoopTermFoldPass(Registry);
}
|
|
|
|
/// Declare which analyses this pass requires and which it preserves.
void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
  // Analyses that must be available before the pass runs.
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequiredID(LoopSimplifyID);
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();

  // Analyses declared as still valid after the pass has run.
  AU.addPreserved<LoopInfoWrapperPass>();
  AU.addPreservedID(LoopSimplifyID);
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<MemorySSAWrapperPass>();
}
|
|
|
|
/// LoopPass entry point: collect the analyses for the function containing
/// \p L and hand them to RunTermFold. Returns false without doing anything
/// when skipLoop(L) is true.
bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
  if (skipLoop(L))
    return false;

  Function &F = *L->getHeader()->getParent();

  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  const TargetTransformInfo &TTI =
      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  TargetLibraryInfo &TLI =
      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);

  // MemorySSA is optional: use it only if it happens to be available.
  MemorySSA *MSSA = nullptr;
  if (auto *MSSAWrapper = getAnalysisIfAvailable<MemorySSAWrapperPass>())
    MSSA = &MSSAWrapper->getMSSA();

  return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
}
|
|
|
|
/// New pass-manager entry point. Reports every analysis as preserved when
/// RunTermFold leaves the loop unchanged; otherwise reports the standard
/// loop-pass preserved set, plus MemorySSA when AR carries one.
PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
                                        LoopStandardAnalysisResults &AR,
                                        LPMUpdater &) {
  const bool Changed =
      RunTermFold(&L, AR.SE, AR.DT, AR.LI, AR.TTI, AR.TLI, AR.MSSA);
  if (!Changed)
    return PreservedAnalyses::all();

  PreservedAnalyses PA = getLoopPassPreservedAnalyses();
  if (AR.MSSA)
    PA.preserve<MemorySSAAnalysis>();
  return PA;
}
|
|
|
|
char LoopTermFold::ID = 0;
|
|
|
|
INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
|
|
false, false)
|
|
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
|
|
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
|
|
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
|
|
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
|
|
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
|
|
INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
|
|
false, false)
|
|
|
|
/// Factory returning a newly allocated LoopTermFold pass instance.
Pass *llvm::createLoopTermFoldPass() {
  return new LoopTermFold();
}
|