[flang] Fix spurious NaN result from infinite Kahan summation (#175373)
There are six instances of Kahan's extended precision summation algorithm in flang/flang-rt, and they share a bug: the calculation of the correction value produces a Nan due to the subtraction Inf-Inf after the accumulation saturates to Inf. This leads to the surprising Nan result from SUM([Inf, 0.]). This bug doesn't affect run-time calculation of SUM when optimization is enabled -- lowering emits an open-coded SUM that lacks Kahan summation -- but it does affect compilation-time folding and -O0 runtime results. Fix the one instance of Kahan summation in the runtime, and consolidate the other five instances in Evaluate into one new member function, also corrected. Fixes https://github.com/llvm/llvm-project/issues/89528.
This commit is contained in:
parent
3874c4541a
commit
a8ba9c4fa2
@ -54,9 +54,15 @@ public:
|
||||
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
|
||||
// Kahan summation
|
||||
auto next{x - correction_};
|
||||
auto oldSum{sum_};
|
||||
sum_ += next;
|
||||
correction_ = (sum_ - oldSum) - next; // algebraically zero
|
||||
if (next != next) {
|
||||
// Avoid propagating an accidental Nan from Inf-Inf in corrections
|
||||
sum_ += x;
|
||||
correction_ = 0;
|
||||
} else {
|
||||
auto oldSum{sum_};
|
||||
sum_ += next;
|
||||
correction_ = (sum_ - oldSum) - next; // algebraically zero
|
||||
}
|
||||
return true;
|
||||
}
|
||||
template <typename A>
|
||||
|
||||
@ -672,3 +672,19 @@ TEST(Reductions, ReduceInt4Dim) {
|
||||
EXPECT_EQ(*sums.ZeroBasedIndexedElement<std::int32_t>(1), 6);
|
||||
sums.Destroy();
|
||||
}
|
||||
|
||||
TEST(Reductions, InfSums) {
|
||||
float inf{1.0f / 0.0f};
|
||||
auto inf0{MakeArray<TypeCategory::Real, 4>(
|
||||
std::vector<int>{2, 3}, std::vector<float>{inf, 0.0f})};
|
||||
auto t1{RTNAME(SumReal4)(*inf0, __FILE__, __LINE__)};
|
||||
EXPECT_EQ(t1, inf) << t1;
|
||||
auto infMinusInf{MakeArray<TypeCategory::Real, 4>(
|
||||
std::vector<int>{2, 3}, std::vector<float>{inf, -inf})};
|
||||
auto t2{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)};
|
||||
EXPECT_NE(t2, t2) << t2;
|
||||
auto minusInfInf{MakeArray<TypeCategory::Real, 4>(
|
||||
std::vector<int>{2, 3}, std::vector<float>{-inf, inf})};
|
||||
auto t3{RTNAME(SumReal4)(*infMinusInf, __FILE__, __LINE__)};
|
||||
EXPECT_NE(t3, t3) << t3;
|
||||
}
|
||||
|
||||
@ -77,6 +77,9 @@ public:
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
ValueWithRealFlags<Complex> Divide(const Complex &,
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
ValueWithRealFlags<Complex> KahanSummation(const Complex &,
|
||||
Complex &correction,
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
|
||||
// ABS/CABS = HYPOT(re_, imag_) = SQRT(re_**2 + im_**2)
|
||||
ValueWithRealFlags<Part> ABS(
|
||||
|
||||
@ -175,6 +175,8 @@ public:
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
ValueWithRealFlags<Real> MODULO(const Real &,
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
ValueWithRealFlags<Real> KahanSummation(const Real &, Real &correction,
|
||||
Rounding rounding = TargetCharacteristics::defaultRounding) const;
|
||||
|
||||
template <typename INT> constexpr INT EXPONENT() const {
|
||||
if (Exponent() == maxExponent) {
|
||||
|
||||
@ -100,6 +100,17 @@ ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
|
||||
return {Complex{re, im}, flags};
|
||||
}
|
||||
|
||||
template <typename R>
|
||||
ValueWithRealFlags<Complex<R>> Complex<R>::KahanSummation(
|
||||
const Complex &that, Complex &correction, Rounding rounding) const {
|
||||
RealFlags flags;
|
||||
Part reSum{re_.KahanSummation(that.re_, correction.re_, rounding)
|
||||
.AccumulateFlags(flags)};
|
||||
Part imSum{im_.KahanSummation(that.im_, correction.im_, rounding)
|
||||
.AccumulateFlags(flags)};
|
||||
return {Complex{reSum, imSum}, flags};
|
||||
}
|
||||
|
||||
template <typename R> std::string Complex<R>::DumpHexadecimal() const {
|
||||
std::string result{'('};
|
||||
result += re_.DumpHexadecimal();
|
||||
|
||||
@ -61,18 +61,13 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
|
||||
auto product{aElt.Multiply(bElt)};
|
||||
overflow |= product.flags.test(RealFlag::Overflow);
|
||||
if constexpr (useKahanSummation) {
|
||||
auto next{product.value.Subtract(correction, rounding)};
|
||||
overflow |= next.flags.test(RealFlag::Overflow);
|
||||
auto added{sum.Add(next.value, rounding)};
|
||||
auto added{sum.KahanSummation(product.value, correction)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
correction = added.value.Subtract(sum, rounding)
|
||||
.value.Subtract(next.value, rounding)
|
||||
.value;
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
} else {
|
||||
auto added{sum.Add(product.value)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
}
|
||||
} else if constexpr (T::category == TypeCategory::Integer) {
|
||||
auto product{aElt.MultiplySigned(bElt)};
|
||||
|
||||
@ -77,13 +77,8 @@ public:
|
||||
auto scaled{item.Divide(scale).value};
|
||||
auto square{scaled.Multiply(scaled).value};
|
||||
if constexpr (useKahanSummation) {
|
||||
auto next{square.Subtract(correction_, rounding_)};
|
||||
overflow_ |= next.flags.test(RealFlag::Overflow);
|
||||
auto sum{element.Add(next.value, rounding_)};
|
||||
auto sum{element.KahanSummation(square, correction_, rounding_)};
|
||||
overflow_ |= sum.flags.test(RealFlag::Overflow);
|
||||
correction_ = sum.value.Subtract(element, rounding_)
|
||||
.value.Subtract(next.value, rounding_)
|
||||
.value;
|
||||
element = sum.value;
|
||||
} else {
|
||||
auto sum{element.Add(square, rounding_)};
|
||||
|
||||
@ -47,18 +47,13 @@ static Expr<T> FoldDotProduct(
|
||||
const auto &rounding{context.targetCharacteristics().roundingMode()};
|
||||
for (const Element &x : cProducts.values()) {
|
||||
if constexpr (useKahanSummation) {
|
||||
auto next{x.Subtract(correction, rounding)};
|
||||
overflow |= next.flags.test(RealFlag::Overflow);
|
||||
auto added{sum.Add(next.value, rounding)};
|
||||
auto added{sum.KahanSummation(x, correction, rounding)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
correction = added.value.Subtract(sum, rounding)
|
||||
.value.Subtract(next.value, rounding)
|
||||
.value;
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
} else {
|
||||
auto added{sum.Add(x, rounding)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
}
|
||||
}
|
||||
} else if constexpr (T::category == TypeCategory::Logical) {
|
||||
@ -97,18 +92,13 @@ static Expr<T> FoldDotProduct(
|
||||
const auto &rounding{context.targetCharacteristics().roundingMode()};
|
||||
for (const Element &x : cProducts.values()) {
|
||||
if constexpr (useKahanSummation) {
|
||||
auto next{x.Subtract(correction, rounding)};
|
||||
overflow |= next.flags.test(RealFlag::Overflow);
|
||||
auto added{sum.Add(next.value, rounding)};
|
||||
auto added{sum.KahanSummation(x, correction, rounding)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
correction = added.value.Subtract(sum, rounding)
|
||||
.value.Subtract(next.value, rounding)
|
||||
.value;
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
} else {
|
||||
auto added{sum.Add(x, rounding)};
|
||||
overflow |= added.flags.test(RealFlag::Overflow);
|
||||
sum = std::move(added.value);
|
||||
sum = added.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -357,14 +347,8 @@ public:
|
||||
} else if constexpr (T::category == TypeCategory::Unsigned) {
|
||||
element = element.AddUnsigned(array_.At(at)).value;
|
||||
} else { // Real & Complex: use Kahan summation
|
||||
auto next{array_.At(at).Subtract(correction_, rounding_)};
|
||||
overflow_ |= next.flags.test(RealFlag::Overflow);
|
||||
auto sum{element.Add(next.value, rounding_)};
|
||||
auto sum{element.KahanSummation(array_.At(at), correction_, rounding_)};
|
||||
overflow_ |= sum.flags.test(RealFlag::Overflow);
|
||||
// correction = (sum - element) - next; algebraically zero
|
||||
correction_ = sum.value.Subtract(element, rounding_)
|
||||
.value.Subtract(next.value, rounding_)
|
||||
.value;
|
||||
element = sum.value;
|
||||
}
|
||||
}
|
||||
|
||||
@ -465,6 +465,24 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::MODULO(
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename W, int P>
|
||||
ValueWithRealFlags<Real<W, P>> Real<W, P>::KahanSummation(
|
||||
const Real &y, Real &correction, Rounding rounding) const {
|
||||
Real next{y.Subtract(correction, rounding).value};
|
||||
if (next.IsNotANumber()) {
|
||||
// Avoid propagating an accidental NaN from Inf-Inf in corrections
|
||||
correction = Real{}; // 0.
|
||||
return Add(y, rounding);
|
||||
} else {
|
||||
auto sum{Add(next, rounding)};
|
||||
// correction = (sum - *this) - next; algebraically zero
|
||||
correction = sum.value.Subtract(*this, rounding)
|
||||
.value.Subtract(next, rounding)
|
||||
.value;
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename W, int P>
|
||||
ValueWithRealFlags<Real<W, P>> Real<W, P>::DIM(
|
||||
const Real &y, Rounding rounding) const {
|
||||
|
||||
8
flang/test/Evaluate/bug89528.f90
Normal file
8
flang/test/Evaluate/bug89528.f90
Normal file
@ -0,0 +1,8 @@
|
||||
!RUN: %flang_fc1 -fdebug-unparse %s 2>&1 | FileCheck %s
|
||||
!CHECK: REAL :: avoidkahannan = (1._4/0.)
|
||||
real :: avoidKahanNaN = sum([1./0., 0.]) ! Inf, not NaN
|
||||
!CHECK: REAL :: expectnan1 = (0._4/0.)
|
||||
real :: expectNaN1 = sum([1./0., -1./0.])
|
||||
!CHECK: REAL :: expectnan2 = (0._4/0.)
|
||||
real :: expectNaN2 = sum([-1./0., 1./0.])
|
||||
end
|
||||
Loading…
x
Reference in New Issue
Block a user