[SimplifyCFG] Set branch weights when merging conditional store to address

This commit is contained in:
Mircea Trofin 2025-08-21 13:54:49 -07:00
parent 18cf46f5b8
commit 975b3e3f70
2 changed files with 54 additions and 13 deletions

View File

@ -15,6 +15,7 @@
#ifndef LLVM_IR_PROFDATAUTILS_H
#define LLVM_IR_PROFDATAUTILS_H
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/IR/Metadata.h"
@ -186,5 +187,31 @@ LLVM_ABI bool hasExplicitlyUnknownBranchWeights(const Instruction &I);
/// Scaling the profile data attached to 'I' using the ratio of S/T.
LLVM_ABI void scaleProfData(Instruction &I, uint64_t S, uint64_t T);
/// get the branch weights of a branch conditioned on b1 || b2, where b1 and b2
/// are 2 booleans that are the condition of 2 branches for which we have the
/// branch weights B1 and B2, respectivelly.
inline SmallVector<uint64_t, 2>
getDisjunctionWeights(const SmallVector<uint32_t, 2> &B1,
const SmallVector<uint32_t, 2> &B2) {
// for the first conditional branch, the probability the "true" case is taken
// is p(b1) = B1[0] / (B1[0] + B2[0]). The "false" case's probability is
// p(not b1) = B1[1] / (B1[0] + B1[1]).
// Similarly for the second conditional branch and B2.
//
// the probability of the new branch NOT being taken is:
// not P = p((not b1) and (not b2)) =
// = B1[1] / (B1[0]+B1[1]) * B2[1] / (B2[0]+B2[1]) =
// = B1[1] * B2[1] / (B1[0] + B1[1]) * (B2[0] + B2[1])
// then the probability of it being taken is: P = 1 - (not P).
// The denominator will be the same as above, and the numerator of P will be
// (B1[0] + B1[1]) * (B2[0] + B2[1]) - B1[1]*B2[1]
// Which then reduces to what's shown below (out of the 4 terms coming out of
// the product of sums, the subtracted one cancels out)
assert(B1.size() == 2);
assert(B2.size() == 2);
auto FalseWeight = B1[1] * B2[1];
auto TrueWeight = B1[0] * B2[0] + B1[0] * B2[1] + B1[1] * B2[0];
return {TrueWeight, FalseWeight};
}
} // namespace llvm
#endif

View File

@ -1182,7 +1182,7 @@ static void cloneInstructionsIntoPredecessorBlockAndUpdateSSAUses(
// only given the branch precondition.
// Similarly strip attributes on call parameters that may cause UB in
// location the call is moved to.
NewBonusInst->dropUBImplyingAttrsAndMetadata();
NewBonusInst->dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
NewBonusInst->insertInto(PredBlock, PTI->getIterator());
auto Range = NewBonusInst->cloneDebugInfoFrom(&BonusInst);
@ -1808,7 +1808,8 @@ static void hoistConditionalLoadsStores(
// !annotation: Not impact semantics. Keep it.
if (const MDNode *Ranges = I->getMetadata(LLVMContext::MD_range))
MaskedLoadStore->addRangeRetAttr(getConstantRangeFromMetadata(*Ranges));
I->dropUBImplyingAttrsAndUnknownMetadata({LLVMContext::MD_annotation});
I->dropUBImplyingAttrsAndUnknownMetadata(
{LLVMContext::MD_annotation, LLVMContext::MD_prof});
// FIXME: DIAssignID is not supported for masked store yet.
// (Verifier::visitDIAssignIDMetadata)
at::deleteAssignmentMarkers(I);
@ -3366,7 +3367,7 @@ bool SimplifyCFGOpt::speculativelyExecuteBB(BranchInst *BI,
if (!SpeculatedStoreValue || &I != SpeculatedStore) {
I.setDebugLoc(DebugLoc::getDropped());
}
I.dropUBImplyingAttrsAndMetadata();
I.dropUBImplyingAttrsAndMetadata({LLVMContext::MD_prof});
// Drop ephemeral values.
if (EphTracker.contains(&I)) {
@ -4404,10 +4405,12 @@ static bool mergeConditionalStoreToAddress(
// OK, we're going to sink the stores to PostBB. The store has to be
// conditional though, so first create the predicate.
Value *PCond = cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator())
->getCondition();
Value *QCond = cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator())
->getCondition();
BranchInst *const PBranch =
cast<BranchInst>(PFB->getSinglePredecessor()->getTerminator());
BranchInst *const QBranch =
cast<BranchInst>(QFB->getSinglePredecessor()->getTerminator());
Value *const PCond = PBranch->getCondition();
Value *const QCond = QBranch->getCondition();
Value *PPHI = ensureValueAvailableInSuccessor(PStore->getValueOperand(),
PStore->getParent());
@ -4418,19 +4421,30 @@ static bool mergeConditionalStoreToAddress(
IRBuilder<> QB(PostBB, PostBBFirst);
QB.SetCurrentDebugLocation(PostBBFirst->getStableDebugLoc());
Value *PPred = PStore->getParent() == PTB ? PCond : QB.CreateNot(PCond);
Value *QPred = QStore->getParent() == QTB ? QCond : QB.CreateNot(QCond);
InvertPCond = (PStore->getParent() == PTB) ^ InvertPCond;
InvertQCond = (QStore->getParent() == QTB) ^ InvertQCond;
Value *const PPred = InvertPCond ? PCond : QB.CreateNot(PCond);
Value *const QPred = InvertQCond ? QCond : QB.CreateNot(QCond);
if (InvertPCond)
PPred = QB.CreateNot(PPred);
if (InvertQCond)
QPred = QB.CreateNot(QPred);
Value *CombinedPred = QB.CreateOr(PPred, QPred);
BasicBlock::iterator InsertPt = QB.GetInsertPoint();
auto *T = SplitBlockAndInsertIfThen(CombinedPred, InsertPt,
/*Unreachable=*/false,
/*BranchWeights=*/nullptr, DTU);
if (hasBranchWeightMD(*PBranch) && hasBranchWeightMD(*QBranch)) {
SmallVector<uint32_t, 2> PWeights, QWeights;
extractBranchWeights(*PBranch, PWeights);
extractBranchWeights(*QBranch, QWeights);
if (InvertPCond)
std::swap(PWeights[0], PWeights[1]);
if (InvertQCond)
std::swap(QWeights[0], QWeights[1]);
auto CombinedWeights = getDisjunctionWeights(PWeights, QWeights);
setBranchWeights(PostBB->getTerminator(), CombinedWeights[0],
CombinedWeights[1],
/*IsExpected=*/false);
}
QB.SetInsertPoint(T);
StoreInst *SI = cast<StoreInst>(QB.CreateStore(QPHI, Address));