From 0bebee6782850d777cff6b10e53df86a45e7a934 Mon Sep 17 00:00:00 2001 From: Zorojuro Date: Fri, 13 Mar 2026 11:03:52 +0530 Subject: [PATCH] [libc][math][c++23] Add Fmabf16 math function (#182836) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit closes #180171 part of #177259 Here are some extra changes apart from the usual which were needed 1. `libc/src/__support/FPUtil/generic/add_sub.h` → +0 -0 error 2. `libc/src/__support/FPUtil/generic/FMA.h` → implemented to handle fmabf16(Normal,Normal,+/-INF) ```jsx /home/runner/work/llvm-project/llvm-project/libc/test/src/math/fmabf16_test.cpp:62: FAILURE Failed to match __llvm_libc_23_0_0_git::fmabf16(x, y, z) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher( input, __llvm_libc_23_0_0_git::fmabf16(x, y, z), 0.5, mpfr::RoundingMode::Nearest). Input decimal: x: 338953138925153547590470800371487866880.00000000000000000000000000000000000000000000000000 y: 338953138925153547590470800371487866880.00000000000000000000000000000000000000000000000000 z: -inf First input bits: 0x7F7F = (S: 0, E: 0x00FE, M: 0x007F) Second input bits: 0x7F7F = (S: 0, E: 0x00FE, M: 0x007F) Third input bits: (-Infinity) Libc result: nan MPFR result: -inf Libc floating point result bits: (NaN) MPFR rounded bits: (-Infinity) ``` 1. ~~`libc/src/__support/FPUtil/bfloat16.h` → to handle *= operator for Bfloat16 ( uses the already available mult operator)~~ moved to #182882 The exhaustive test currently includes subnormal range and for checking for specific edge cases . This is due to the repeated failure at <2^32 input space specifically only for ubuntu 24.04 and 24.04-arm The removed tests included -> PositiveRange and NegativeRange for Normals and an extra positive test for subnormals/Denormals Let me know if there are any changes expected or anything I missed in this . cc: @lntue @krishna2803 @overmighty --- libc/config/baremetal/aarch64/entrypoints.txt | 1 + libc/config/baremetal/arm/entrypoints.txt | 1 + libc/config/baremetal/riscv/entrypoints.txt | 1 + libc/config/darwin/aarch64/entrypoints.txt | 1 + libc/config/darwin/x86_64/entrypoints.txt | 1 + libc/config/gpu/amdgpu/entrypoints.txt | 1 + libc/config/gpu/nvptx/entrypoints.txt | 1 + libc/config/linux/aarch64/entrypoints.txt | 1 + libc/config/linux/arm/entrypoints.txt | 1 + libc/config/linux/riscv/entrypoints.txt | 1 + libc/config/linux/x86_64/entrypoints.txt | 1 + libc/config/windows/entrypoints.txt | 1 + libc/docs/headers/math/index.rst | 2 +- libc/shared/math.h | 1 + libc/shared/math/fmabf16.h | 23 ++++++ libc/src/__support/FPUtil/generic/FMA.h | 7 +- libc/src/__support/math/CMakeLists.txt | 11 +++ libc/src/__support/math/fmabf16.h | 27 +++++++ libc/src/math/CMakeLists.txt | 1 + libc/src/math/fmabf16.h | 21 +++++ libc/src/math/generic/CMakeLists.txt | 10 +++ libc/src/math/generic/fmabf16.cpp | 18 +++++ libc/test/shared/CMakeLists.txt | 1 + libc/test/shared/shared_math_test.cpp | 3 + libc/test/src/math/CMakeLists.txt | 13 ++++ libc/test/src/math/exhaustive/CMakeLists.txt | 19 +++++ .../test/src/math/exhaustive/fmabf16_test.cpp | 77 +++++++++++++++++++ libc/test/src/math/fmabf16_test.cpp | 66 ++++++++++++++++ libc/test/src/math/smoke/CMakeLists.txt | 14 ++++ libc/test/src/math/smoke/fmabf16_test.cpp | 14 ++++ libc/utils/MPFRWrapper/MPFRUtils.cpp | 5 ++ .../llvm-project-overlay/libc/BUILD.bazel | 18 +++++ 32 files changed, 361 insertions(+), 2 deletions(-) create mode 100644 libc/shared/math/fmabf16.h create mode 100644 libc/src/__support/math/fmabf16.h create mode 100644 libc/src/math/fmabf16.h create mode 100644 libc/src/math/generic/fmabf16.cpp create mode 100644 libc/test/src/math/exhaustive/fmabf16_test.cpp create mode 100644 libc/test/src/math/fmabf16_test.cpp create mode 100644 libc/test/src/math/smoke/fmabf16_test.cpp diff --git a/libc/config/baremetal/aarch64/entrypoints.txt b/libc/config/baremetal/aarch64/entrypoints.txt index 4e720a234d47..a6e492327451 100644 --- a/libc/config/baremetal/aarch64/entrypoints.txt +++ b/libc/config/baremetal/aarch64/entrypoints.txt @@ -400,6 +400,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/baremetal/arm/entrypoints.txt b/libc/config/baremetal/arm/entrypoints.txt index 7a7d78d28318..48fc612358e9 100644 --- a/libc/config/baremetal/arm/entrypoints.txt +++ b/libc/config/baremetal/arm/entrypoints.txt @@ -411,6 +411,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/baremetal/riscv/entrypoints.txt b/libc/config/baremetal/riscv/entrypoints.txt index 73235b0a33b0..2697cb1b5376 100644 --- a/libc/config/baremetal/riscv/entrypoints.txt +++ b/libc/config/baremetal/riscv/entrypoints.txt @@ -406,6 +406,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/darwin/aarch64/entrypoints.txt b/libc/config/darwin/aarch64/entrypoints.txt index 04b2de76aa5b..234ea3f42d6d 100644 --- a/libc/config/darwin/aarch64/entrypoints.txt +++ b/libc/config/darwin/aarch64/entrypoints.txt @@ -221,6 +221,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/darwin/x86_64/entrypoints.txt b/libc/config/darwin/x86_64/entrypoints.txt index 27e50b9e96e9..fcd874eb5ed9 100644 --- a/libc/config/darwin/x86_64/entrypoints.txt +++ b/libc/config/darwin/x86_64/entrypoints.txt @@ -147,6 +147,7 @@ set(TARGET_LIBM_ENTRYPOINTS #libc.src.math.floorf #libc.src.math.floorl #libc.src.math.fma + #libc.src.math.fmabf16 #libc.src.math.fmaf #libc.src.math.fmax #libc.src.math.fmaxf diff --git a/libc/config/gpu/amdgpu/entrypoints.txt b/libc/config/gpu/amdgpu/entrypoints.txt index 0441207ace96..c6960d31907a 100644 --- a/libc/config/gpu/amdgpu/entrypoints.txt +++ b/libc/config/gpu/amdgpu/entrypoints.txt @@ -340,6 +340,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/gpu/nvptx/entrypoints.txt b/libc/config/gpu/nvptx/entrypoints.txt index f127ba6358b4..c1927d8dc3c9 100644 --- a/libc/config/gpu/nvptx/entrypoints.txt +++ b/libc/config/gpu/nvptx/entrypoints.txt @@ -340,6 +340,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt index ea2e3021a01b..1c0abae76dbc 100644 --- a/libc/config/linux/aarch64/entrypoints.txt +++ b/libc/config/linux/aarch64/entrypoints.txt @@ -487,6 +487,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/linux/arm/entrypoints.txt b/libc/config/linux/arm/entrypoints.txt index e7764ad3c0a0..3ce816b9727e 100644 --- a/libc/config/linux/arm/entrypoints.txt +++ b/libc/config/linux/arm/entrypoints.txt @@ -311,6 +311,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/linux/riscv/entrypoints.txt b/libc/config/linux/riscv/entrypoints.txt index a202b41030bd..6c8ccb15e74b 100644 --- a/libc/config/linux/riscv/entrypoints.txt +++ b/libc/config/linux/riscv/entrypoints.txt @@ -495,6 +495,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt index c735d6443fe5..0602955dc1cb 100644 --- a/libc/config/linux/x86_64/entrypoints.txt +++ b/libc/config/linux/x86_64/entrypoints.txt @@ -541,6 +541,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmax libc.src.math.fmaxf diff --git a/libc/config/windows/entrypoints.txt b/libc/config/windows/entrypoints.txt index 4974ffcb4ae9..1ff79f8f8e89 100644 --- a/libc/config/windows/entrypoints.txt +++ b/libc/config/windows/entrypoints.txt @@ -192,6 +192,7 @@ set(TARGET_LIBM_ENTRYPOINTS libc.src.math.floorf libc.src.math.floorl libc.src.math.fma + libc.src.math.fmabf16 libc.src.math.fmaf libc.src.math.fmin libc.src.math.fminf diff --git a/libc/docs/headers/math/index.rst b/libc/docs/headers/math/index.rst index ee2b6d9b7bd1..6b987671546e 100644 --- a/libc/docs/headers/math/index.rst +++ b/libc/docs/headers/math/index.rst @@ -309,7 +309,7 @@ Higher Math Functions +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+------------------------+----------------------------+ | expm1 | |check| | |check| | | |check| | | | 7.12.6.6 | F.10.3.6 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+------------------------+----------------------------+ -| fma | |check| | |check| | | |check| | | | 7.12.13.1 | F.10.10.1 | +| fma | |check| | |check| | | |check| | | |check| | 7.12.13.1 | F.10.10.1 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+------------------------+----------------------------+ | f16sqrt | |check|\* | |check|\* | |check|\* | N/A | |check| | | 7.12.14.6 | F.10.11 | +-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+------------------------+----------------------------+ diff --git a/libc/shared/math.h b/libc/shared/math.h index ede0ebd5371a..bd70a8335a96 100644 --- a/libc/shared/math.h +++ b/libc/shared/math.h @@ -132,6 +132,7 @@ #include "math/floorf128.h" #include "math/floorf16.h" #include "math/floorl.h" +#include "math/fmabf16.h" #include "math/fmax.h" #include "math/fmaxbf16.h" #include "math/fmaxf.h" diff --git a/libc/shared/math/fmabf16.h b/libc/shared/math/fmabf16.h new file mode 100644 index 000000000000..c9c467706c1c --- /dev/null +++ b/libc/shared/math/fmabf16.h @@ -0,0 +1,23 @@ +//===-- Shared fmabf16 function ---------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SHARED_MATH_FMABF16_H +#define LLVM_LIBC_SHARED_MATH_FMABF16_H + +#include "shared/libc_common.h" +#include "src/__support/math/fmabf16.h" + +namespace LIBC_NAMESPACE_DECL { +namespace shared { + +using math::fmabf16; + +} // namespace shared +} // namespace LIBC_NAMESPACE_DECL + +#endif // LLVM_LIBC_SHARED_MATH_FMABF16_H diff --git a/libc/src/__support/FPUtil/generic/FMA.h b/libc/src/__support/FPUtil/generic/FMA.h index bec312e44b1b..9ca6d5f594c9 100644 --- a/libc/src/__support/FPUtil/generic/FMA.h +++ b/libc/src/__support/FPUtil/generic/FMA.h @@ -198,8 +198,13 @@ fma(InType x, InType y, InType z) { if (LIBC_UNLIKELY(x_exp == InFPBits::MAX_BIASED_EXPONENT || y_exp == InFPBits::MAX_BIASED_EXPONENT || - z_exp == InFPBits::MAX_BIASED_EXPONENT)) + z_exp == InFPBits::MAX_BIASED_EXPONENT)) { + if (LIBC_UNLIKELY(x_exp != InFPBits::MAX_BIASED_EXPONENT && + y_exp != InFPBits::MAX_BIASED_EXPONENT && + z_bits.is_inf())) + return cast(z); return cast(x * y + z); + } // Extract mantissa and append hidden leading bits. InStorageType x_mant = x_bits.get_explicit_mantissa(); diff --git a/libc/src/__support/math/CMakeLists.txt b/libc/src/__support/math/CMakeLists.txt index 315cc1069c1d..646fd5a93e2a 100644 --- a/libc/src/__support/math/CMakeLists.txt +++ b/libc/src/__support/math/CMakeLists.txt @@ -1265,6 +1265,17 @@ add_header_library( libc.src.__support.macros.config ) +add_header_library( + fmabf16 + HDRS + fmabf16.h + DEPENDS + libc.src.__support.FPUtil.fma + libc.src.__support.FPUtil.bfloat16 + libc.src.__support.common + libc.src.__support.macros.config +) + add_header_library( floor HDRS diff --git a/libc/src/__support/math/fmabf16.h b/libc/src/__support/math/fmabf16.h new file mode 100644 index 000000000000..f69de4bbd984 --- /dev/null +++ b/libc/src/__support/math/fmabf16.h @@ -0,0 +1,27 @@ +//===-- Implementation header for fmabf16 ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_FMABF16_H +#define LLVM_LIBC_SRC___SUPPORT_MATH_FMABF16_H + +#include "src/__support/FPUtil/FMA.h" +#include "src/__support/FPUtil/bfloat16.h" +#include "src/__support/common.h" +#include "src/__support/macros/config.h" + +namespace LIBC_NAMESPACE_DECL { +namespace math { + +LIBC_INLINE bfloat16 fmabf16(bfloat16 x, bfloat16 y, bfloat16 z) { + return fputil::fma(x, y, z); +} + +} // namespace math +} // namespace LIBC_NAMESPACE_DECL + +#endif // LLVM_LIBC_SRC___SUPPORT_MATH_FMAXBF16_H diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt index 766feb0606c6..418cf75dfa17 100644 --- a/libc/src/math/CMakeLists.txt +++ b/libc/src/math/CMakeLists.txt @@ -229,6 +229,7 @@ add_math_entrypoint_object(floorf128) add_math_entrypoint_object(floorbf16) add_math_entrypoint_object(fma) +add_math_entrypoint_object(fmabf16) add_math_entrypoint_object(fmaf) add_math_entrypoint_object(fmaf16) diff --git a/libc/src/math/fmabf16.h b/libc/src/math/fmabf16.h new file mode 100644 index 000000000000..5ba0876894f9 --- /dev/null +++ b/libc/src/math/fmabf16.h @@ -0,0 +1,21 @@ +//===-- Implementation header for fmabf16 -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIBC_SRC_MATH_FMABF16_H +#define LLVM_LIBC_SRC_MATH_FMABF16_H + +#include "src/__support/macros/config.h" +#include "src/__support/macros/properties/types.h" + +namespace LIBC_NAMESPACE_DECL { + +bfloat16 fmabf16(bfloat16 x, bfloat16 y, bfloat16 z); + +} // namespace LIBC_NAMESPACE_DECL + +#endif // LLVM_LIBC_SRC_MATH_FMABF16_H diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt index b9de548c8bed..8bc7525e4f52 100644 --- a/libc/src/math/generic/CMakeLists.txt +++ b/libc/src/math/generic/CMakeLists.txt @@ -4375,6 +4375,16 @@ add_entrypoint_object( libc.src.__support.FPUtil.fma ) +add_entrypoint_object( + fmabf16 + SRCS + fmabf16.cpp + HDRS + ../fmabf16.h + DEPENDS + libc.src.__support.math.fmabf16 +) + add_entrypoint_object( totalorder SRCS diff --git a/libc/src/math/generic/fmabf16.cpp b/libc/src/math/generic/fmabf16.cpp new file mode 100644 index 000000000000..495817ce1bd6 --- /dev/null +++ b/libc/src/math/generic/fmabf16.cpp @@ -0,0 +1,18 @@ +//===-- Implementation of fmabf16 function --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/math/fmabf16.h" +#include "src/__support/math/fmabf16.h" + +namespace LIBC_NAMESPACE_DECL { + +LLVM_LIBC_FUNCTION(bfloat16, fmabf16, (bfloat16 x, bfloat16 y, bfloat16 z)) { + return math::fmabf16(x, y, z); +} + +} // namespace LIBC_NAMESPACE_DECL diff --git a/libc/test/shared/CMakeLists.txt b/libc/test/shared/CMakeLists.txt index c90e5687d8c3..73da792b75f4 100644 --- a/libc/test/shared/CMakeLists.txt +++ b/libc/test/shared/CMakeLists.txt @@ -129,6 +129,7 @@ add_fp_unittest( libc.src.__support.math.floorf128 libc.src.__support.math.floorf16 libc.src.__support.math.floorl + libc.src.__support.math.fmabf16 libc.src.__support.math.fmax libc.src.__support.math.fmaxbf16 libc.src.__support.math.fmaxf diff --git a/libc/test/shared/shared_math_test.cpp b/libc/test/shared/shared_math_test.cpp index 17045ce5edfd..c2fa5f1eadd5 100644 --- a/libc/test/shared/shared_math_test.cpp +++ b/libc/test/shared/shared_math_test.cpp @@ -427,6 +427,9 @@ TEST(LlvmLibcSharedMathTest, AllBFloat16) { EXPECT_FP_EQ(bfloat16(0.0), LIBC_NAMESPACE::shared::floorbf16(bfloat16(0.0))); EXPECT_FP_EQ(bfloat16(0.0), LIBC_NAMESPACE::shared::fdimbf16(bfloat16(0.0), bfloat16(0.0))); + EXPECT_FP_EQ(bfloat16(10.0), + LIBC_NAMESPACE::shared::fmabf16(bfloat16(2.0), bfloat16(3.0), + bfloat16(4.0))); EXPECT_FP_EQ(bfloat16(0.0), LIBC_NAMESPACE::shared::fmaxbf16(bfloat16(0.0), bfloat16(0.0))); diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt index 90afe842e9de..c8460ee030db 100644 --- a/libc/test/src/math/CMakeLists.txt +++ b/libc/test/src/math/CMakeLists.txt @@ -2012,6 +2012,19 @@ add_fp_unittest( libc.src.stdlib.srand ) +add_fp_unittest( + fmabf16_test + NEED_MPFR + SUITE + libc-math-unittests + SRCS + fmabf16_test.cpp + DEPENDS + libc.src.math.fmabf16 + libc.src.__support.FPUtil.fp_bits + libc.src.__support.FPUtil.bfloat16 +) + add_fp_unittest( tan_test NEED_MPFR diff --git a/libc/test/src/math/exhaustive/CMakeLists.txt b/libc/test/src/math/exhaustive/CMakeLists.txt index a21e208312c5..afeb02469dcd 100644 --- a/libc/test/src/math/exhaustive/CMakeLists.txt +++ b/libc/test/src/math/exhaustive/CMakeLists.txt @@ -293,6 +293,25 @@ add_fp_unittest( -lpthread ) +add_fp_unittest( + fmabf16_test + NO_RUN_POSTBUILD + NEED_MPFR + SUITE + libc_math_exhaustive_tests + SRCS + fmabf16_test.cpp + COMPILE_OPTIONS + ${libc_opt_high_flag} + DEPENDS + .exhaustive_test + libc.src.math.fmabf16 + libc.src.__support.FPUtil.fp_bits + libc.src.__support.FPUtil.bfloat16 + LINK_LIBRARIES + -lpthread +) + add_fp_unittest( logf_test NO_RUN_POSTBUILD diff --git a/libc/test/src/math/exhaustive/fmabf16_test.cpp b/libc/test/src/math/exhaustive/fmabf16_test.cpp new file mode 100644 index 000000000000..82cb0290cd81 --- /dev/null +++ b/libc/test/src/math/exhaustive/fmabf16_test.cpp @@ -0,0 +1,77 @@ +//===-- Exhaustive tests for fmabf16 --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "exhaustive_test.h" +#include "src/__support/FPUtil/FPBits.h" +#include "src/__support/FPUtil/bfloat16.h" +#include "src/math/fmabf16.h" +#include "test/UnitTest/FPMatcher.h" +#include "utils/MPFRWrapper/MPCommon.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +namespace mpfr = LIBC_NAMESPACE::testing::mpfr; +using LIBC_NAMESPACE::fputil::BFloat16; + +struct FmaBf16Checker : public virtual LIBC_NAMESPACE::testing::Test { + + using FloatType = BFloat16; + using FPBits = LIBC_NAMESPACE::fputil::FPBits; + using StorageType = typename FPBits::StorageType; + + uint64_t check(uint16_t x_start, uint16_t x_stop, uint16_t y_start, + uint16_t y_stop, mpfr::RoundingMode rounding) { + + mpfr::ForceRoundingMode r(rounding); + if (!r.success) + return true; + uint16_t xbits = x_start; + uint64_t failed = 0; + do { + BFloat16 x = FPBits(xbits).get_val(); + uint16_t ybits = y_start; + do { + BFloat16 y = FPBits(ybits).get_val(); + BFloat16 z = FPBits(uint16_t(0x03E1)).get_val(); + mpfr::TernaryInput input{x, y, z}; + bool correct = TEST_MPFR_MATCH_ROUNDING_SILENTLY( + mpfr::Operation::Fma, input, LIBC_NAMESPACE::fmabf16(x, y, z), 0.5, + rounding); + failed += (!correct); + + } while (ybits++ < y_stop); + } while (xbits++ < x_stop); + return failed; + } +}; + +using LlvmLibcBfloat16ExhaustiveFmaTest = + LlvmLibcExhaustiveMathTest; + +// range: [0, inf] +static constexpr uint16_t POS_START = 0x0000U; +static constexpr uint16_t POS_STOP = 0x7f80U; + +// range: [-0, -inf] +static constexpr uint16_t NEG_START = 0x8000U; +static constexpr uint16_t NEG_STOP = 0xff80U; + +TEST_F(LlvmLibcBfloat16ExhaustiveFmaTest, PositiveRange) { + test_full_range_all_roundings(POS_START, POS_STOP, POS_START, POS_STOP); +} + +TEST_F(LlvmLibcBfloat16ExhaustiveFmaTest, PositiveNegative) { + test_full_range_all_roundings(POS_START, POS_STOP, NEG_START, NEG_STOP); +} + +TEST_F(LlvmLibcBfloat16ExhaustiveFmaTest, NegativePositive) { + test_full_range_all_roundings(NEG_START, NEG_STOP, POS_START, POS_STOP); +} + +TEST_F(LlvmLibcBfloat16ExhaustiveFmaTest, NegativeRange) { + test_full_range_all_roundings(NEG_START, NEG_STOP, NEG_START, NEG_STOP); +} diff --git a/libc/test/src/math/fmabf16_test.cpp b/libc/test/src/math/fmabf16_test.cpp new file mode 100644 index 000000000000..f89016a0ab86 --- /dev/null +++ b/libc/test/src/math/fmabf16_test.cpp @@ -0,0 +1,66 @@ +//===-- Unit test for fmabf16 ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "src/__support/FPUtil/bfloat16.h" +#include "src/math/fmabf16.h" +#include "test/UnitTest/FPMatcher.h" +#include "test/UnitTest/Test.h" +#include "utils/MPFRWrapper/MPFRUtils.h" + +using LlvmLibcFmaBf16Test = LIBC_NAMESPACE::testing::FPTest; + +namespace mpfr = LIBC_NAMESPACE::testing::mpfr; + +// subnormal range (negative) +static constexpr uint16_t SUBNORM_NEG_START = 0x8001U; +static constexpr uint16_t SUBNORM_NEG_STOP = 0x807FU; + +TEST_F(LlvmLibcFmaBf16Test, SubnormalNegativeRange) { + constexpr bfloat16 Z_VALUES[] = {zero, neg_zero, inf, + neg_inf, min_normal, max_normal}; + for (uint16_t v1 = SUBNORM_NEG_START; v1 <= SUBNORM_NEG_STOP; v1++) { + for (uint16_t v2 = v1; v2 <= SUBNORM_NEG_STOP; v2++) { + + bfloat16 x = FPBits(v1).get_val(); + bfloat16 y = FPBits(v2).get_val(); + for (const bfloat16 &z : Z_VALUES) { + mpfr::TernaryInput input{x, y, z}; + + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, + LIBC_NAMESPACE::fmabf16(x, y, z), 0.5); + } + bfloat16 neg_xy = -(x * y); + mpfr::TernaryInput input{x, y, neg_xy}; + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, + LIBC_NAMESPACE::fmabf16(x, y, neg_xy), + 0.5); + } + } +} + +TEST_F(LlvmLibcFmaBf16Test, SpecialNumbers) { + constexpr bfloat16 VALUES[] = {zero, neg_zero, inf, + neg_inf, min_normal, max_normal}; + for (size_t i = 0; i < 6; ++i) { + for (size_t j = i; j < 6; ++j) { + bfloat16 x = VALUES[i]; + bfloat16 y = VALUES[j]; + for (const bfloat16 &z : VALUES) { + mpfr::TernaryInput input{x, y, z}; + + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, + LIBC_NAMESPACE::fmabf16(x, y, z), 0.5); + } + bfloat16 neg_xy = -(x * y); + mpfr::TernaryInput input{x, y, neg_xy}; + EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Fma, input, + LIBC_NAMESPACE::fmabf16(x, y, neg_xy), + 0.5); + } + } +} diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt index 185f646e3aa9..ce419a322611 100644 --- a/libc/test/src/math/smoke/CMakeLists.txt +++ b/libc/test/src/math/smoke/CMakeLists.txt @@ -4281,6 +4281,20 @@ add_fp_unittest( libc.src.__support.macros.properties.types ) +add_fp_unittest( + fmabf16_test + SUITE + libc-math-smoke-tests + SRCS + fmabf16_test.cpp + HDRS + FmaTest.h + DEPENDS + libc.src.math.fmabf16 + libc.src.__support.FPUtil.bfloat16 + libc.src.__support.FPUtil.fp_bits +) + add_fp_unittest( expm1_test SUITE diff --git a/libc/test/src/math/smoke/fmabf16_test.cpp b/libc/test/src/math/smoke/fmabf16_test.cpp new file mode 100644 index 000000000000..4b1ececf40d7 --- /dev/null +++ b/libc/test/src/math/smoke/fmabf16_test.cpp @@ -0,0 +1,14 @@ +//===-- Unittests for fmabf16 ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "FmaTest.h" + +#include "src/__support/FPUtil/bfloat16.h" +#include "src/math/fmabf16.h" + +LIST_FMA_TESTS(bfloat16, LIBC_NAMESPACE::fmabf16) diff --git a/libc/utils/MPFRWrapper/MPFRUtils.cpp b/libc/utils/MPFRWrapper/MPFRUtils.cpp index a7d307b47c3e..4c5c891bcfe9 100644 --- a/libc/utils/MPFRWrapper/MPFRUtils.cpp +++ b/libc/utils/MPFRWrapper/MPFRUtils.cpp @@ -496,6 +496,8 @@ explain_ternary_operation_one_output_error(Operation, float16, double, RoundingMode); #endif +template void explain_ternary_operation_one_output_error( + Operation, const TernaryInput &, bfloat16, double, RoundingMode); template void explain_ternary_operation_one_output_error( Operation, const TernaryInput &, bfloat16, double, RoundingMode); template void explain_ternary_operation_one_output_error( @@ -762,6 +764,9 @@ compare_ternary_operation_one_output(Operation, double, RoundingMode); #endif +template bool +compare_ternary_operation_one_output(Operation, const TernaryInput &, + bfloat16, double, RoundingMode); template bool compare_ternary_operation_one_output(Operation, const TernaryInput &, bfloat16, double, diff --git a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel index 9a0a70b13cae..82708e0d1ba2 100644 --- a/utils/bazel/llvm-project-overlay/libc/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/libc/BUILD.bazel @@ -4044,6 +4044,17 @@ libc_support_library( ], ) +libc_support_library( + name = "__support_math_fmabf16", + hdrs = ["src/__support/math/fmabf16.h"], + deps = [ + ":__support_common", + ":__support_fputil_bfloat16", + ":__support_fputil_fma", + ":__support_macros_config", + ], +) + libc_support_library( name = "__support_math_fmax", hdrs = ["src/__support/math/fmax.h"], @@ -6831,6 +6842,13 @@ libc_math_function( ], ) +libc_math_function( + name = "fmabf16", + additional_deps = [ + ":__support_math_fmabf16", + ], +) + libc_math_function( name = "fmax", additional_deps = [