From 59cd1847a79938bfa4229e18a392c4113dc1d20a Mon Sep 17 00:00:00 2001 From: "Joel E. Denny" Date: Tue, 19 Aug 2025 15:25:39 -0400 Subject: [PATCH] Fix case where nested loops share latch --- .../include/llvm/Transforms/Utils/LoopUtils.h | 14 +++-- llvm/lib/Transforms/Utils/LoopUtils.cpp | 49 +++++++++++++---- .../Transforms/Utils/LoopUtilsTest.cpp | 53 +++++++++++++++++++ 3 files changed, 97 insertions(+), 19 deletions(-) diff --git a/llvm/include/llvm/Transforms/Utils/LoopUtils.h b/llvm/include/llvm/Transforms/Utils/LoopUtils.h index 0ee22a24e56c..da2c835e9e4b 100644 --- a/llvm/include/llvm/Transforms/Utils/LoopUtils.h +++ b/llvm/include/llvm/Transforms/Utils/LoopUtils.h @@ -324,12 +324,12 @@ LLVM_ABI void addStringMetadataToLoop(Loop *TheLoop, const char *MDString, unsigned V = 0); /// Return either: +/// - \c std::nullopt, if the implementation is unable to handle the loop form +/// of \p L (e.g., \p L must have a latch block that controls the loop exit). /// - The value of \c llvm.loop.estimated_trip_count from the loop metadata of /// \p L, if that metadata is present. /// - Else, a new estimate of the trip count from the latch branch weights of -/// \p L, if the estimation's implementation is able to handle the loop form -/// of \p L (e.g., \p L must have a latch block that controls the loop exit). -/// - Else, \c std::nullopt. +/// \p L. /// /// An estimated trip count is always a valid positive trip count, saturated at /// \c UINT_MAX. @@ -350,17 +350,15 @@ getLoopEstimatedTripCount(Loop *L, unsigned *EstimatedLoopInvocationWeight = nullptr); /// Set \c llvm.loop.estimated_trip_count with the value \p EstimatedTripCount -/// in the loop metadata of \p L. +/// in the loop metadata of \p L. Return false if the implementation is unable +/// to handle the loop form of \p L (e.g., \p L must have a latch block that +/// controls the loop exit). Otherwise, return true. 
/// /// In addition, if \p EstimatedLoopInvocationWeight, set the branch weight /// metadata of \p L to reflect that \p L has an estimated /// \p EstimatedTripCount iterations and has \c *EstimatedLoopInvocationWeight /// exit weight through the loop's latch. /// -/// Return false if \p EstimatedLoopInvocationWeight and if branch weight -/// metadata could not be successfully updated (e.g., if \p L does not have a -/// latch block that controls the loop exit). Otherwise, return true. -/// /// TODO: Eventually, once all passes have migrated away from setting branch /// weights to indicate estimated trip counts, this function will drop the /// \p EstimatedLoopInvocationWeight parameter. diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index 38d249a50b0a..b5238f64d99f 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -868,6 +868,27 @@ static std::optional estimateLoopTripCount(Loop *L) { std::optional llvm::getLoopEstimatedTripCount(Loop *L, unsigned *EstimatedLoopInvocationWeight) { + // If EstimatedLoopInvocationWeight, we do not support this loop if + // getExpectedExitLoopLatchBranch returns nullptr. + // + // FIXME: Also, this is a stop-gap solution for nested loops. It avoids + // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when + // it was created for an inner loop. The problem is that loop metadata is + // attached to the branch instruction in the loop latch block, but that can be + // shared by the loops. The solution is to attach loop metadata to loop + // headers instead, but that would be a large change to LLVM. + // + // Until that happens, we work around the problem as follows. + // getExpectedExitLoopLatchBranch (which also guards + // setLoopEstimatedTripCount) will not recognize the same latch for both loops + // unless the latch exits both loops and has only two successors. 
However, to + // exit both loops, the latch must have at least three successors: the inner + // loop header, the outer loop header (exit for the inner loop), and an exit + // for the outer loop. + BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L); + if (!ExitingBranch) + return std::nullopt; + // If requested, either compute *EstimatedLoopInvocationWeight or return // nullopt if cannot. // @@ -875,16 +896,14 @@ llvm::getLoopEstimatedTripCount(Loop *L, // weights to indicate estimated trip counts, this function will drop the // EstimatedLoopInvocationWeight parameter. if (EstimatedLoopInvocationWeight) { - if (BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L)) { - uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused. - if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) - return std::nullopt; - if (L->contains(ExitingBranch->getSuccessor(1))) - std::swap(LoopWeight, ExitWeight); - if (!ExitWeight) - return std::nullopt; - *EstimatedLoopInvocationWeight = ExitWeight; - } + uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused. + if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) + return std::nullopt; + if (L->contains(ExitingBranch->getSuccessor(1))) + std::swap(LoopWeight, ExitWeight); + if (!ExitWeight) + return std::nullopt; + *EstimatedLoopInvocationWeight = ExitWeight; } // Return the estimated trip count from metadata unless the metadata is @@ -903,6 +922,15 @@ llvm::getLoopEstimatedTripCount(Loop *L, bool llvm::setLoopEstimatedTripCount( Loop *L, unsigned EstimatedTripCount, std::optional EstimatedloopInvocationWeight) { + // If EstimatedLoopInvocationWeight, we do not support this loop if + // getExpectedExitLoopLatchBranch returns nullptr. + // + // FIXME: See comments in getLoopEstimatedTripCount for why this is required + // here regardless of EstimatedLoopInvocationWeight. 
+ BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L); + if (!LatchBranch) + return false; + // Set the metadata. addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount, EstimatedTripCount); @@ -915,7 +943,6 @@ bool llvm::setLoopEstimatedTripCount( // here at all. if (!EstimatedloopInvocationWeight) return true; - BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L); if (!LatchBranch) return false; diff --git a/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp b/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp index c22a3582bee8..ce002e923996 100644 --- a/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp +++ b/llvm/unittests/Transforms/Utils/LoopUtilsTest.cpp @@ -142,3 +142,56 @@ TEST(LoopUtils, IsKnownNonPositiveInLoopTest) { EXPECT_EQ(isKnownNonPositiveInLoop(ArgSCEV, L, SE), true); }); } + +// The inner and outer loop here share a latch. Because any loop metadata must +// be attached to that latch, loop metadata cannot distinguish between the two +// loops. Until that problem is solved (by moving loop metadata to loops' +// header blocks instead), getLoopEstimatedTripCount and +// setLoopEstimatedTripCount must refuse to operate on at least one of the two +// loops. They choose to reject the outer loop here because the latch does not +// exit it. 
+TEST(LoopUtils, nestedLoopSharedLatchEstimatedTripCount) {
+  LLVMContext C;
+  std::unique_ptr<Module> M =
+      parseIR(C, "declare i1 @f()\n"
+                 "declare i1 @g()\n"
+                 "define void @foo() {\n"
+                 "entry:\n"
+                 " br label %outer\n"
+                 "outer:\n"
+                 " %c0 = call i1 @f()\n"
+                 " br i1 %c0, label %inner, label %exit, !prof !0\n"
+                 "inner:\n"
+                 " %c1 = call i1 @g()\n"
+                 " br i1 %c1, label %inner, label %outer, !prof !1\n"
+                 "exit:\n"
+                 " ret void\n"
+                 "}\n"
+                 "!0 = !{!\"branch_weights\", i32 100, i32 1}\n"
+                 "!1 = !{!\"branch_weights\", i32 4, i32 1}\n"
+                 "\n");
+
+  run(*M, "foo",
+      [&](Function &F, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI) {
+        ASSERT_EQ(LI.getTopLevelLoops().size(), 1u) << "Expected one outer loop";
+        Loop *Outer = *LI.begin();
+        ASSERT_EQ(Outer->getSubLoops().size(), 1u) << "Expected one inner loop";
+        Loop *Inner = *Outer->begin();
+
+        // With no llvm.loop.estimated_trip_count metadata yet, Inner falls back
+        // to its latch branch weights (!1, 4:1); Outer is already rejected.
+        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 5);
+        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);
+
+        // setLoopEstimatedTripCount for the inner loop does not affect
+        // getLoopEstimatedTripCount for the outer loop.
+        EXPECT_EQ(setLoopEstimatedTripCount(Inner, 100), true);
+        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 100);
+        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);
+
+        // setLoopEstimatedTripCount rejects Outer without disturbing Inner.
+        EXPECT_EQ(setLoopEstimatedTripCount(Outer, 999), false);
+        EXPECT_EQ(getLoopEstimatedTripCount(Inner), 100);
+        EXPECT_EQ(getLoopEstimatedTripCount(Outer), std::nullopt);
+      });
+}