From a8ba9c4fa2ae1ae80738e7848c28c19a9aa564da Mon Sep 17 00:00:00 2001 From: Peter Klausler Date: Mon, 12 Jan 2026 15:40:44 -0800 Subject: [PATCH] [flang] Fix spurious NaN result from infinite Kahan summation (#175373) There are six instances of Kahan's extended precision summation algorithm in flang/flang-rt, and they share a bug: the calculation of the correction value produces a Nan due to the subtraction Inf-Inf after the accumulation saturates to Inf. This leads to the surprising Nan result from SUM([Inf, 0.]). This bug doesn't affect run-time calculation of SUM when optimization is enabled -- lowering emits an open-coded SUM that lacks Kahan summation -- but it does affect compilation-time folding and -O0 runtime results. Fix the one instance of Kahan summation in the runtime, and consolidate the other five instances in Evaluate into one new member function, also corrected. Fixes https://github.com/llvm/llvm-project/issues/89528. --- flang-rt/lib/runtime/sum.cpp | 12 +++++++--- flang-rt/unittests/Runtime/Reduction.cpp | 16 +++++++++++++ flang/include/flang/Evaluate/complex.h | 3 +++ flang/include/flang/Evaluate/real.h | 2 ++ flang/lib/Evaluate/complex.cpp | 11 +++++++++ flang/lib/Evaluate/fold-matmul.h | 11 +++------ flang/lib/Evaluate/fold-real.cpp | 7 +----- flang/lib/Evaluate/fold-reduction.h | 30 ++++++------------------ flang/lib/Evaluate/real.cpp | 18 ++++++++++++++ flang/test/Evaluate/bug89528.f90 | 8 +++++++ 10 files changed, 78 insertions(+), 40 deletions(-) create mode 100644 flang/test/Evaluate/bug89528.f90 diff --git a/flang-rt/lib/runtime/sum.cpp b/flang-rt/lib/runtime/sum.cpp index a76e228f18a4..0c540606f27c 100644 --- a/flang-rt/lib/runtime/sum.cpp +++ b/flang-rt/lib/runtime/sum.cpp @@ -54,9 +54,15 @@ public: template RT_API_ATTRS bool Accumulate(A x) { // Kahan summation auto next{x - correction_}; - auto oldSum{sum_}; - sum_ += next; - correction_ = (sum_ - oldSum) - next; // algebraically zero + if (next != next) { + // Avoid propagating an accidental Nan from Inf-Inf in corrections + sum_ += x; + correction_ = 0; + } else { + auto oldSum{sum_}; + sum_ += next; + correction_ = (sum_ - oldSum) - next; // algebraically zero + } return true; } template diff --git a/flang-rt/unittests/Runtime/Reduction.cpp b/flang-rt/unittests/Runtime/Reduction.cpp index 3701a32042c5..ac6fed1e34a9 100644 --- a/flang-rt/unittests/Runtime/Reduction.cpp +++ b/flang-rt/unittests/Runtime/Reduction.cpp @@ -672,3 +672,19 @@ TEST(Reductions, ReduceInt4Dim) { EXPECT_EQ(*sums.ZeroBasedIndexedElement(1), 6); sums.Destroy(); } + +TEST(Reductions, InfSums) { + float inf{1.0f / 0.0f}; + auto inf0{MakeArray( + std::vector{2, 3}, std::vector{inf, 0.0f})}; + auto t1{RTNAME(SumReal4)(*inf0, __FILE__, __LINE__)}; + EXPECT_EQ(t1, inf) << t1; + auto infMinusInf{MakeArray( + std::vector{2, 3}, std::vector{inf, -inf})}; + auto t2{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)}; + EXPECT_NE(t2, t2) << t2; + auto minusInfInf{MakeArray( + std::vector{2, 3}, std::vector{-inf, inf})}; + auto t3{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)}; + EXPECT_NE(t3, t3) << t3; +} diff --git a/flang/include/flang/Evaluate/complex.h b/flang/include/flang/Evaluate/complex.h index 720ccaf512df..9781db9a25a6 100644 --- a/flang/include/flang/Evaluate/complex.h +++ b/flang/include/flang/Evaluate/complex.h @@ -77,6 +77,9 @@ public: Rounding rounding = TargetCharacteristics::defaultRounding) const; ValueWithRealFlags Divide(const Complex &, Rounding rounding = TargetCharacteristics::defaultRounding) const; + ValueWithRealFlags KahanSummation(const Complex &, + Complex &correction, + Rounding rounding = TargetCharacteristics::defaultRounding) const; // ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2) ValueWithRealFlags ABS( diff --git a/flang/include/flang/Evaluate/real.h b/flang/include/flang/Evaluate/real.h index dcd74073a473..c0a966820d13 100644 --- a/flang/include/flang/Evaluate/real.h +++ b/flang/include/flang/Evaluate/real.h @@ -175,6 +175,8 @@ public: Rounding rounding = TargetCharacteristics::defaultRounding) const; ValueWithRealFlags MODULO(const Real &, Rounding rounding = TargetCharacteristics::defaultRounding) const; + ValueWithRealFlags KahanSummation(const Real &, Real &correction, + Rounding rounding = TargetCharacteristics::defaultRounding) const; template constexpr INT EXPONENT() const { if (Exponent() == maxExponent) { diff --git a/flang/lib/Evaluate/complex.cpp b/flang/lib/Evaluate/complex.cpp index ab83f193e3f3..a245fb38c82b 100644 --- a/flang/lib/Evaluate/complex.cpp +++ b/flang/lib/Evaluate/complex.cpp @@ -100,6 +100,17 @@ ValueWithRealFlags> Complex::Divide( return {Complex{re, im}, flags}; } +template +ValueWithRealFlags> Complex::KahanSummation( + const Complex &that, Complex &correction, Rounding rounding) const { + RealFlags flags; + Part reSum{re_.KahanSummation(that.re_, correction.re_, rounding) + .AccumulateFlags(flags)}; + Part imSum{im_.KahanSummation(that.im_, correction.im_, rounding) + .AccumulateFlags(flags)}; + return {Complex{reSum, imSum}, flags}; +} + template std::string Complex::DumpHexadecimal() const { std::string result{'('}; result += re_.DumpHexadecimal(); diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h index ae9221f9ce04..a8a24c09774e 100644 --- a/flang/lib/Evaluate/fold-matmul.h +++ b/flang/lib/Evaluate/fold-matmul.h @@ -61,18 +61,13 @@ static Expr FoldMatmul(FoldingContext &context, FunctionRef &&funcRef) { auto product{aElt.Multiply(bElt)}; overflow |= product.flags.test(RealFlag::Overflow); if constexpr (useKahanSummation) { - auto next{product.value.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(product.value, correction)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(product.value)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } else if constexpr (T::category == TypeCategory::Integer) { auto product{aElt.MultiplySigned(bElt)}; diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp index 907d01b005a0..9c591e2ef36e 100644 --- a/flang/lib/Evaluate/fold-real.cpp +++ b/flang/lib/Evaluate/fold-real.cpp @@ -77,13 +77,8 @@ public: auto scaled{item.Divide(scale).value}; auto square{scaled.Multiply(scaled).value}; if constexpr (useKahanSummation) { - auto next{square.Subtract(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; + auto sum{element.KahanSummation(square, correction_, rounding_)}; overflow_ |= sum.flags.test(RealFlag::Overflow); - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; element = sum.value; } else { auto sum{element.Add(square, rounding_)}; diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h index fe897393fe13..a06836413529 100644 --- a/flang/lib/Evaluate/fold-reduction.h +++ b/flang/lib/Evaluate/fold-reduction.h @@ -47,18 +47,13 @@ static Expr FoldDotProduct( const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { if constexpr (useKahanSummation) { - auto next{x.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(x, correction, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(x, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } } else if constexpr (T::category == TypeCategory::Logical) { @@ -97,18 +92,13 @@ static Expr FoldDotProduct( const auto &rounding{context.targetCharacteristics().roundingMode()}; for (const Element &x : cProducts.values()) { if constexpr (useKahanSummation) { - auto next{x.Subtract(correction, rounding)}; - overflow |= next.flags.test(RealFlag::Overflow); - auto added{sum.Add(next.value, rounding)}; + auto added{sum.KahanSummation(x, correction, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - correction = added.value.Subtract(sum, rounding) - .value.Subtract(next.value, rounding) - .value; - sum = std::move(added.value); + sum = added.value; } else { auto added{sum.Add(x, rounding)}; overflow |= added.flags.test(RealFlag::Overflow); - sum = std::move(added.value); + sum = added.value; } } } @@ -357,14 +347,8 @@ public: } else if constexpr (T::category == TypeCategory::Unsigned) { element = element.AddUnsigned(array_.At(at)).value; } else { // Real & Complex: use Kahan summation - auto next{array_.At(at).Subtract(correction_, rounding_)}; - overflow_ |= next.flags.test(RealFlag::Overflow); - auto sum{element.Add(next.value, rounding_)}; + auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)}; overflow_ |= sum.flags.test(RealFlag::Overflow); - // correction = (sum - element) - next; algebraically zero - correction_ = sum.value.Subtract(element, rounding_) - .value.Subtract(next.value, rounding_) - .value; element = sum.value; } } diff --git a/flang/lib/Evaluate/real.cpp b/flang/lib/Evaluate/real.cpp index 6e6b9f3ac77c..eb335ce32851 100644 --- a/flang/lib/Evaluate/real.cpp +++ b/flang/lib/Evaluate/real.cpp @@ -465,6 +465,24 @@ ValueWithRealFlags> Real::MODULO( return result; } +template +ValueWithRealFlags> Real::KahanSummation( + const Real &y, Real &correction, Rounding rounding) const { + Real next{y.Subtract(correction, rounding).value}; + if (next.IsNotANumber()) { + // Avoid propagating an accidental NaN from Inf-Inf in corrections + correction = Real{}; // 0. + return Add(y, rounding); + } else { + auto sum{Add(next, rounding)}; + // correction = (sum - *this) - next; algebraically zero + correction = sum.value.Subtract(*this, rounding) + .value.Subtract(next, rounding) + .value; + return sum; + } +} + template ValueWithRealFlags> Real::DIM( const Real &y, Rounding rounding) const { diff --git a/flang/test/Evaluate/bug89528.f90 b/flang/test/Evaluate/bug89528.f90 new file mode 100644 index 000000000000..281f35bf653c --- /dev/null +++ b/flang/test/Evaluate/bug89528.f90 @@ -0,0 +1,8 @@ +!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s +!CHECK: REAL :: avoidkahannan = (1._4/0.) +real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN +!CHECK: REAL :: expectnan1 = (0._4/0.) +real :: expectNaN1 = sum([1./0., -1./0.]) +!CHECK: REAL :: expectnan2 = (0._4/0.) +real :: expectNaN2 = sum([-1./0., 1./0.]) +end