[flang] Fix spurious NaN result from infinite Kahan summation (#175373)

There are six instances of Kahan's extended precision summation
algorithm in flang/flang-rt, and they share a bug: the calculation of
the correction value produces a NaN due to the subtraction Inf-Inf after
the accumulation saturates to Inf. This leads to the surprising NaN
result from SUM([Inf, 0.]).

This bug doesn't affect run-time calculation of SUM when optimization is
enabled -- lowering emits an open-coded SUM that lacks Kahan summation
-- but it does affect compilation-time folding and -O0 runtime results.

Fix the one instance of Kahan summation in the runtime, and consolidate
the other five instances in Evaluate into one new member function, also
corrected.

Fixes https://github.com/llvm/llvm-project/issues/89528.
This commit is contained in:
Peter Klausler 2026-01-12 15:40:44 -08:00 committed by GitHub
parent 3874c4541a
commit a8ba9c4fa2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 78 additions and 40 deletions

View File

@ -54,9 +54,15 @@ public:
// Adds x to the running sum using Kahan (compensated) summation.
// sum_ holds the accumulated total and correction_ holds the rounding error
// recovered from the previous step, which is subtracted from the next addend.
// Always returns true.
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
  auto next{x - correction_};
  if (next != next) {
    // next is a NaN: either x itself is a NaN or the stored correction was
    // unusable.  Add x directly so that no accidental NaN is introduced, and
    // restart the compensation.
    sum_ += x;
    correction_ = 0;
  } else {
    auto oldSum{sum_};
    sum_ += next;
    correction_ = (sum_ - oldSum) - next; // algebraically zero
    // Once the sum is infinite the correction is meaningless: it is a NaN
    // (Inf-Inf) when an infinite addend was absorbed, and an infinity when
    // finite addends overflowed.  An infinite correction would turn the next
    // addend into the opposite infinity and make the sum a NaN, so drop any
    // non-finite correction here.  (c - c == 0 holds only for finite c.)
    if (!(correction_ - correction_ == 0)) {
      correction_ = 0;
    }
  }
  return true;
}
template <typename A>

View File

@ -672,3 +672,19 @@ TEST(Reductions, ReduceInt4Dim) {
EXPECT_EQ(*sums.ZeroBasedIndexedElement<std::int32_t>(1), 6);
sums.Destroy();
}
// Checks how SumReal4 treats IEEE infinities: Inf + 0 must stay Inf, while a
// sum containing both +Inf and -Inf must be a NaN in either operand order.
// NOTE(review): the first MakeArray argument {2, 3} looks like a 2x3 shape,
// yet only two values are supplied for each array -- confirm that the
// remaining elements are well defined, or shrink the shape to {2}.
TEST(Reductions, InfSums) {
  float inf{1.0f / 0.0f};
  auto inf0{MakeArray<TypeCategory::Real, 4>(
      std::vector<int>{2, 3}, std::vector<float>{inf, 0.0f})};
  auto t1{RTNAME(SumReal4)(*inf0, __FILE__, __LINE__)};
  EXPECT_EQ(t1, inf) << t1;
  auto infMinusInf{MakeArray<TypeCategory::Real, 4>(
      std::vector<int>{2, 3}, std::vector<float>{inf, -inf})};
  auto t2{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)};
  EXPECT_NE(t2, t2) << t2; // a NaN compares unequal to itself
  auto minusInfInf{MakeArray<TypeCategory::Real, 4>(
      std::vector<int>{2, 3}, std::vector<float>{-inf, inf})};
  // Sum the array built just above so that the -Inf, +Inf ordering is
  // exercised rather than repeating the +Inf, -Inf case.
  auto t3{RTNAME(SumReal4)(*minusInfInf, __FILE__, __LINE__)};
  EXPECT_NE(t3, t3) << t3; // a NaN compares unequal to itself
}

View File

@ -77,6 +77,9 @@ public:
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Complex> Divide(const Complex &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
// Kahan (compensated) summation step: returns the sum of *this and the
// argument, applied separately to the real and imaginary parts.  `correction`
// carries the compensation terms from one step to the next and is updated
// in place.
ValueWithRealFlags<Complex> KahanSummation(const Complex &,
Complex &correction,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
ValueWithRealFlags<Part> ABS(

View File

@ -175,6 +175,8 @@ public:
Rounding rounding = TargetCharacteristics::defaultRounding) const;
ValueWithRealFlags<Real> MODULO(const Real &,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
// Kahan (compensated) summation step: returns the sum of *this and the
// argument after first subtracting `correction` from the argument.
// `correction` is updated in place with the new compensation term, and is
// reset to zero when applying it would turn the addend into a NaN.
ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction,
Rounding rounding = TargetCharacteristics::defaultRounding) const;
template <typename INT> constexpr INT EXPONENT() const {
if (Exponent() == maxExponent) {

View File

@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
return {Complex{re, im}, flags};
}
// Kahan summation for complex values, performed independently on the real
// and imaginary components.  Each component of `correction` is passed to the
// corresponding part's KahanSummation and updated by it; the flags of both
// component results are gathered into the returned flags.
template <typename R>
ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation(
    const Complex &that, Complex &correction, Rounding rounding) const {
  RealFlags combinedFlags;
  // Real component first, then imaginary, as two separate steps.
  auto realStep{re_.KahanSummation(that.re_, correction.re_, rounding)};
  Part realPart{realStep.AccumulateFlags(combinedFlags)};
  auto imagStep{im_.KahanSummation(that.im_, correction.im_, rounding)};
  Part imagPart{imagStep.AccumulateFlags(combinedFlags)};
  return {Complex{realPart, imagPart}, combinedFlags};
}
template <typename R> std::string Complex<R>::DumpHexadecimal() const {
std::string result{'('};
result += re_.DumpHexadecimal();

View File

@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
auto next{product.value.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
auto added{sum.KahanSummation(product.value, correction)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
sum = added.value;
} else {
auto added{sum.Add(product.value)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
sum = added.value;
}
} else if constexpr (T::category == TypeCategory::Integer) {
auto product{aElt.MultiplySigned(bElt)};

View File

@ -77,13 +77,8 @@ public:
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
auto next{square.Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
auto sum{element.KahanSummation(square, correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
} else {
auto sum{element.Add(square, rounding_)};

View File

@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
sum = added.value;
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
auto added{sum.KahanSummation(x, correction, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
sum = added.value;
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
sum = added.value;
}
}
}
@ -357,14 +347,8 @@ public:
} else if constexpr (T::category == TypeCategory::Unsigned) {
element = element.AddUnsigned(array_.At(at)).value;
} else { // Real & Complex: use Kahan summation
auto next{array_.At(at).Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
// correction = (sum - element) - next; algebraically zero
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
}
}

View File

@ -465,6 +465,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO(
return result;
}
// One step of Kahan (compensated) summation: returns *this + y, where y is
// first adjusted by the compensation term carried in `correction`.
// `correction` is replaced with the rounding error recovered from this step,
// or with zero whenever no usable compensation exists (see below).
// The returned flags are those of the final addition.
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation(
    const Real &y, Real &correction, Rounding rounding) const {
  Real next{y.Subtract(correction, rounding).value};
  if (next.IsNotANumber()) {
    // Avoid propagating an accidental NaN from the correction: add y
    // itself and restart the compensation.
    correction = Real{}; // 0.
    return Add(y, rounding);
  }
  auto sum{Add(next, rounding)};
  // newCorrection = (sum - *this) - next; algebraically zero
  Real newCorrection{sum.value.Subtract(*this, rounding)
          .value.Subtract(next, rounding)
          .value};
  // Once the sum is infinite the correction is meaningless: it is a NaN
  // (Inf-Inf) when an infinite addend was absorbed, and an infinity when
  // finite addends overflowed.  An infinite correction would turn the next
  // addend into the opposite infinity and make the following sum a NaN, so
  // discard any non-finite correction.  (c - c is a NaN only when c is an
  // infinity or a NaN.)
  if (newCorrection.Subtract(newCorrection, rounding).value.IsNotANumber()) {
    newCorrection = Real{}; // 0.
  }
  correction = newCorrection;
  return sum;
}
template <typename W, int P>
ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
const Real &y, Rounding rounding) const {

View File

@ -0,0 +1,8 @@
!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
! Constant folding of SUM over arrays that contain infinities.  The unparsed
! initializers are matched below: (1._4/0.) is expected for the Inf plus zero
! case and (0._4/0.) for the two cases that add infinities of opposite sign.
!CHECK: REAL :: avoidkahannan = (1._4/0.)
real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN
!CHECK: REAL :: expectnan1 = (0._4/0.)
real :: expectNaN1 = sum([1./0., -1./0.]) ! +Inf then -Inf
!CHECK: REAL :: expectnan2 = (0._4/0.)
real :: expectNaN2 = sum([-1./0., 1./0.]) ! -Inf then +Inf
end