llvm-project/mlir/lib/ExecutionEngine/APFloatWrappers.cpp
Maksim Levental 222f4e494a
[mlir] Add FP software implementation lowering pass: arith-to-apfloat (#166618)
This commit adds a new pass that lowers floating-point `arith`
operations to calls into the execution engine runtime library. Currently
supported operations: `addf`, `subf`, `mulf`, `divf`, `remf`.

All floating-point types that have an APFloat semantics are supported.
This includes low-precision floating-point types such as `f4E2M1FN` that
cannot execute natively on CPUs.

This commit also improves the `vector.print` lowering pattern to call
into the runtime library for floating-point types that are not supported
by LLVM. This is necessary to write a meaningful integration test.

The way it works is as follows:

```mlir
func.func @full_example() {
  %a = arith.constant 1.4 : f8E4M3FN
  %b = func.call @foo() : () -> (f8E4M3FN)
  %c = arith.addf %a, %b : f8E4M3FN
  vector.print %c : f8E4M3FN
  return
}
```

gets transformed to

```mlir
func.func private @__mlir_apfloat_add(i32, i64, i64) -> i64
func.func @full_example() {
  %cst = arith.constant 1.375000e+00 : f8E4M3FN
  %0 = call @foo() : () -> f8E4M3FN
  // bitcast operand A to integer of equal width
  %1 = arith.bitcast %cst : f8E4M3FN to i8
  // zext A to i64
  %2 = arith.extui %1 : i8 to i64
  // same for operand B
  %3 = arith.bitcast %0 : f8E4M3FN to i8
  %4 = arith.extui %3 : i8 to i64
  // get the llvm::fltSemantics(f8E4M3FN) as an enum
  %c10_i32 = arith.constant 10 : i32
  // call the impl against APFloat in mlir_apfloat_wrappers
  %5 = call @__mlir_apfloat_add(%c10_i32, %2, %4) : (i32, i64, i64) -> i64
  // "cast" back to the original fp type
  %6 = arith.trunci %5 : i64 to i8
  %7 = arith.bitcast %6 : i8 to f8E4M3FN
  vector.print %7 : f8E4M3FN
}
```

Note, `llvm::fltSemantics(f8E4M3FN)` is emitted by the pattern each time
an `arith` op is transformed, thereby making the call to
`__mlir_apfloat_add` correct (i.e., no name mangling on type necessary).


RFC:
https://discourse.llvm.org/t/rfc-software-implementation-for-unsupported-fp-types-in-convert-arith-to-llvm/88785

---------

Co-authored-by: Matthias Springer <me@m-sp.org>
2025-11-10 16:21:39 -08:00

82 lines
3.9 KiB
C++

//===- APFloatWrappers.cpp - Software Implementation of FP Arithmetics --- ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file exposes the APFloat infrastructure to MLIR programs as a runtime
// library. APFloat is a software implementation of floating point arithmetics.
//
// On the MLIR side, floating-point values must be bitcast to 64-bit integers
// before calling a runtime function. If a floating-point type has fewer than
// 64 bits, it must be zero-extended to 64 bits after bitcasting it to an
// integer.
//
// Runtime functions receive the floating-point operands of the arithmetic
// operation in the form of 64-bit integers, along with the APFloat semantics
// in the form of a 32-bit integer, which will be interpreted as an
// APFloatBase::Semantics enum value.
//
#include "llvm/ADT/APFloat.h"
// Attribute placed on the runtime entry points defined in this file so that
// they are exported from the library: `dllexport` when building for Windows
// or Cygwin, default symbol visibility on every other platform.
#if !defined(_WIN32) && !defined(__CYGWIN__)
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __attribute__((visibility("default")))
#else
#define MLIR_APFLOAT_WRAPPERS_EXPORTED __declspec(dllexport)
#endif
/// Defines the exported runtime function `__mlir_apfloat_<OP>` for a binary
/// APFloat member operation that does not take a rounding mode.
///
/// The generated function interprets `semantics` as an
/// `APFloatBase::Semantics` enum value, rebuilds both operands from the bit
/// patterns passed in `a` and `b`, applies `lhs.OP(rhs)` in place and returns
/// the bit pattern of the result, zero-extended to 64 bits. The status
/// returned by the APFloat operation is discarded.
#define APFLOAT_BINARY_OP(OP)                                                  \
  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP(                  \
      int32_t semantics, uint64_t a, uint64_t b) {                             \
    const auto semanticsKind =                                                 \
        static_cast<llvm::APFloatBase::Semantics>(semantics);                  \
    const llvm::fltSemantics &fltSem =                                         \
        llvm::APFloatBase::EnumToSemantics(semanticsKind);                     \
    const unsigned numBits = llvm::APFloatBase::semanticsSizeInBits(fltSem);   \
    llvm::APFloat result(fltSem, llvm::APInt(numBits, a));                     \
    const llvm::APFloat operand(fltSem, llvm::APInt(numBits, b));              \
    result.OP(operand);                                                        \
    const llvm::APInt resultBits = result.bitcastToAPInt();                    \
    return resultBits.getZExtValue();                                          \
  }
/// Defines the exported runtime function `__mlir_apfloat_<OP>` for a binary
/// APFloat member operation that takes an explicit rounding mode.
///
/// The generated function interprets `semantics` as an
/// `APFloatBase::Semantics` enum value, rebuilds both operands from the bit
/// patterns passed in `a` and `b`, applies `lhs.OP(rhs, ROUNDING_MODE)` in
/// place and returns the bit pattern of the result, zero-extended to 64 bits.
/// The status returned by the APFloat operation is discarded.
#define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE)                     \
  int64_t MLIR_APFLOAT_WRAPPERS_EXPORTED __mlir_apfloat_##OP(                  \
      int32_t semantics, uint64_t a, uint64_t b) {                             \
    const auto semanticsKind =                                                 \
        static_cast<llvm::APFloatBase::Semantics>(semantics);                  \
    const llvm::fltSemantics &fltSem =                                         \
        llvm::APFloatBase::EnumToSemantics(semanticsKind);                     \
    const unsigned numBits = llvm::APFloatBase::semanticsSizeInBits(fltSem);   \
    llvm::APFloat result(fltSem, llvm::APInt(numBits, a));                     \
    const llvm::APFloat operand(fltSem, llvm::APInt(numBits, b));              \
    result.OP(operand, ROUNDING_MODE);                                         \
    const llvm::APInt resultBits = result.bitcastToAPInt();                    \
    return resultBits.getZExtValue();                                          \
  }
extern "C" {
// Entry points for the basic arithmetic operations. Each one rounds to
// nearest, ties to even.
APFLOAT_BINARY_OP_ROUNDING_MODE(add, llvm::RoundingMode::NearestTiesToEven)
APFLOAT_BINARY_OP_ROUNDING_MODE(subtract, llvm::RoundingMode::NearestTiesToEven)
APFLOAT_BINARY_OP_ROUNDING_MODE(multiply, llvm::RoundingMode::NearestTiesToEven)
APFLOAT_BINARY_OP_ROUNDING_MODE(divide, llvm::RoundingMode::NearestTiesToEven)
#undef APFLOAT_BINARY_OP_ROUNDING_MODE

// Entry point for the remainder operation, which takes no rounding mode.
// NOTE(review): APFloat::remainder computes the IEEE-754 remainder, whereas
// APFloat::mod matches C fmod / LLVM frem. Confirm that this is the semantics
// the `arith.remf` lowering expects.
APFLOAT_BINARY_OP(remainder)
#undef APFLOAT_BINARY_OP
/// Prints a floating-point value to stdout.
///
/// `semantics` is interpreted as an `APFloatBase::Semantics` enum value and
/// `a` carries the bit pattern of the value. The value is converted to a
/// `double` and written with the `%lg` format; no newline is appended.
void MLIR_APFLOAT_WRAPPERS_EXPORTED printApFloat(int32_t semantics,
                                                 uint64_t a) {
  const auto semanticsKind =
      static_cast<llvm::APFloatBase::Semantics>(semantics);
  const llvm::fltSemantics &fltSem =
      llvm::APFloatBase::EnumToSemantics(semanticsKind);
  const llvm::APInt rawBits(llvm::APFloatBase::semanticsSizeInBits(fltSem), a);
  const llvm::APFloat value(fltSem, rawBits);
  fprintf(stdout, "%lg", value.convertToDouble());
}
}