Fix case where nested loops share latch

This commit is contained in:
Joel E. Denny 2025-08-19 15:25:39 -04:00
parent 680bdc22b4
commit 59cd1847a7
3 changed files with 97 additions and 19 deletions

View File

@ -324,12 +324,12 @@ LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
unsigned V = 0);
/// Return either:
/// - \c std::nullopt, if the implementation is unable to handle the loop form
/// of \p L (e.g., \p L must have a latch block that controls the loop exit).
/// - The value of \c llvm.loop.estimated_trip_count from the loop metadata of
/// \p L, if that metadata is present.
/// - Else, a new estimate of the trip count from the latch branch weights of
/// \p L, if the estimation's implementation is able to handle the loop form
/// of \p L (e.g., \p L must have a latch block that controls the loop exit).
/// - Else, \c std::nullopt.
/// \p L.
///
/// An estimated trip count is always a valid positive trip count, saturated at
/// \c UINT_MAX.
@ -350,17 +350,15 @@ getLoopEstimatedTripCount(Loop *L,
unsigned *EstimatedLoopInvocationWeight = nullptr);
/// Set \c llvm.loop.estimated_trip_count with the value \p EstimatedTripCount
/// in the loop metadata of \p L.
/// in the loop metadata of \p L. Return false if the implementation is unable
/// to handle the loop form of \p L (e.g., \p L must have a latch block that
/// controls the loop exit). Otherwise, return true.
///
/// In addition, if \p EstimatedLoopInvocationWeight, set the branch weight
/// metadata of \p L to reflect that \p L has an estimated
/// \p EstimatedTripCount iterations and has \c *EstimatedLoopInvocationWeight
/// exit weight through the loop's latch.
///
/// Return false if \p EstimatedLoopInvocationWeight and if branch weight
/// metadata could not be successfully updated (e.g., if \p L does not have a
/// latch block that controls the loop exit). Otherwise, return true.
///
/// TODO: Eventually, once all passes have migrated away from setting branch
/// weights to indicate estimated trip counts, this function will drop the
/// \p EstimatedLoopInvocationWeight parameter.

View File

@ -868,6 +868,27 @@ static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
std::optional<unsigned>
llvm::getLoopEstimatedTripCount(Loop *L,
unsigned *EstimatedLoopInvocationWeight) {
// If EstimatedLoopInvocationWeight, we do not support this loop if
// getExpectedExitLoopLatchBranch returns nullptr.
//
// FIXME: Also, this is a stop-gap solution for nested loops. It avoids
// mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
// it was created for an inner loop. The problem is that loop metadata is
// attached to the branch instruction in the loop latch block, but that can be
// shared by the loops. The solution is to attach loop metadata to loop
// headers instead, but that would be a large change to LLVM.
//
// Until that happens, we work around the problem as follows.
// getExpectedExitLoopLatchBranch (which also guards
// setLoopEstimatedTripCount) will not recognize the same latch for both loops
// unless the latch exits both loops and has only two successors. However, to
// exit both loops, the latch must have at least three successors: the inner
// loop header, the outer loop header (exit for the inner loop), and an exit
// for the outer loop.
BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
if (!ExitingBranch)
return std::nullopt;
// If requested, either compute *EstimatedLoopInvocationWeight or return
// nullopt if it cannot be computed.
//
@ -875,16 +896,14 @@ llvm::getLoopEstimatedTripCount(Loop *L,
// weights to indicate estimated trip counts, this function will drop the
// EstimatedLoopInvocationWeight parameter.
if (EstimatedLoopInvocationWeight) {
if (BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L)) {
uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
return std::nullopt;
if (L->contains(ExitingBranch->getSuccessor(1)))
std::swap(LoopWeight, ExitWeight);
if (!ExitWeight)
return std::nullopt;
*EstimatedLoopInvocationWeight = ExitWeight;
}
uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
return std::nullopt;
if (L->contains(ExitingBranch->getSuccessor(1)))
std::swap(LoopWeight, ExitWeight);
if (!ExitWeight)
return std::nullopt;
*EstimatedLoopInvocationWeight = ExitWeight;
}
// Return the estimated trip count from metadata unless the metadata is
@ -903,6 +922,15 @@ llvm::getLoopEstimatedTripCount(Loop *L,
bool llvm::setLoopEstimatedTripCount(
Loop *L, unsigned EstimatedTripCount,
std::optional<unsigned> EstimatedloopInvocationWeight) {
// If EstimatedLoopInvocationWeight, we do not support this loop if
// getExpectedExitLoopLatchBranch returns nullptr.
//
// FIXME: See comments in getLoopEstimatedTripCount for why this is required
// here regardless of EstimatedLoopInvocationWeight.
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
if (!LatchBranch)
return false;
// Set the metadata.
addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount, EstimatedTripCount);
@ -915,7 +943,6 @@ bool llvm::setLoopEstimatedTripCount(
// here at all.
if (!EstimatedloopInvocationWeight)
return true;
BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
if (!LatchBranch)
return false;

View File

@ -142,3 +142,56 @@ TEST(LoopUtils, IsKnownNonPositiveInLoopTest) {
EXPECT_EQ(isKnownNonPositiveInLoop(ArgSCEV, L, SE), true);
});
}
// The inner and outer loop here share a latch.  Because any loop metadata must
// be attached to that latch, loop metadata cannot distinguish between the two
// loops.  Until that problem is solved (by moving loop metadata to loops'
// header blocks instead), getLoopEstimatedTripCount and
// setLoopEstimatedTripCount must refuse to operate on at least one of the two
// loops.  They choose to reject the outer loop here because the latch does not
// exit it.
//
// CFG of @foo below:
//   entry -> outer
//   outer -> inner | exit     (outer is the header of the outer loop)
//   inner -> inner | outer    (inner is the header of the inner loop and the
//                              block that branches back to both headers)
TEST(LoopUtils, nestedLoopSharedLatchEstimatedTripCount) {
  LLVMContext C;
  // Every IR line is newline-terminated so that each instruction sits on its
  // own line of the parsed text.
  std::unique_ptr<Module> M =
      parseIR(C, "declare i1 @f()\n"
                 "declare i1 @g()\n"
                 "define void @foo() {\n"
                 "entry:\n"
                 " br label %outer\n"
                 "outer:\n"
                 " %c0 = call i1 @f()\n"
                 " br i1 %c0, label %inner, label %exit, !prof !0\n"
                 "inner:\n"
                 " %c1 = call i1 @g()\n"
                 " br i1 %c1, label %inner, label %outer, !prof !1\n"
                 "exit:\n"
                 " ret void\n"
                 "}\n"
                 "!0 = !{!\"branch_weights\", i32 100, i32 1}\n"
                 "!1 = !{!\"branch_weights\", i32 4, i32 1}\n"
                 "\n");

  run(*M, "foo",
      [&](Function &F, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI) {
        // Use fatal gtest assertions rather than assert() so the loop
        // structure is still verified (and the dereferences below are
        // guarded) when the test is built with NDEBUG.
        ASSERT_EQ(LI.end() - LI.begin(), 1) << "Expected one outer loop";
        Loop *Outer = *LI.begin();
        ASSERT_EQ(Outer->end() - Outer->begin(), 1)
            << "Expected one inner loop";
        Loop *Inner = *Outer->begin();

        // Even before llvm.loop.estimated_trip_count is added to either loop,
        // getLoopEstimatedTripCount rejects the outer loop.  The inner loop's
        // estimate comes from the 4:1 branch weights in !1.
        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 5);
        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);

        // setLoopEstimatedTripCount for the inner loop does not affect
        // getLoopEstimatedTripCount for the outer loop.
        EXPECT_EQ(setLoopEstimatedTripCount(Inner, 100), true);
        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 100);
        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);

        // setLoopEstimatedTripCount rejects the outer loop, and the rejected
        // call leaves the inner loop's estimate untouched.
        EXPECT_EQ(setLoopEstimatedTripCount(Outer, 999), false);
        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 100);
        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);
      });
}