diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 699edb188a70..9aee85bc7e9e 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -40,6 +40,29 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, return funcOp; } +/// Helper function to look up or create the symbol for a runtime library +/// function with the given parameter types. Always returns an int64_t. +static FailureOr +lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, + StringRef name, TypeRange paramTypes, + SymbolTableCollection *symbolTables = nullptr) { + auto i64Type = IntegerType::get(symTable->getContext(), 64); + + std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type}); + FailureOr func = + lookupFnDecl(symTable, funcName, funcT, symbolTables); + // Failed due to type mismatch. + if (failed(func)) + return func; + // Successfully matched existing decl. + if (*func) + return *func; + + return createFnDecl(b, symTable, funcName, funcT, + /*setPrivate=*/true, symbolTables); +} + /// Helper function to look up or create the symbol for a runtime library /// function for a binary arithmetic operation. /// @@ -55,21 +78,14 @@ lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables = nullptr) { auto i32Type = IntegerType::get(symTable->getContext(), 32); auto i64Type = IntegerType::get(symTable->getContext(), 64); + return lookupOrCreateApFloatFn(b, symTable, name, {i32Type, i64Type, i64Type}, + symbolTables); +} - std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); - FunctionType funcT = - FunctionType::get(b.getContext(), {i32Type, i64Type, i64Type}, {i64Type}); - FailureOr func = - lookupFnDecl(symTable, funcName, funcT, symbolTables); - // Failed due to type mismatch. - if (failed(func)) - return func; - // Successfully matched existing decl. - if (*func) - return *func; - - return createFnDecl(b, symTable, funcName, funcT, - /*setPrivate=*/true, symbolTables); +static Value getSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy) { + int32_t sem = llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); + return arith::ConstantOp::create(b, loc, b.getI32Type(), + b.getIntegerAttr(b.getI32Type(), sem)); } /// Rewrite a binary arithmetic operation to an APFloat function call. @@ -104,11 +120,7 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); // Call APFloat function. - int32_t sem = - llvm::APFloatBase::SemanticsToEnum(floatTy.getFloatSemantics()); - Value semValue = arith::ConstantOp::create( - rewriter, loc, rewriter.getI32Type(), - rewriter.getIntegerAttr(rewriter.getI32Type(), sem)); + Value semValue = getSemanticsValue(rewriter, loc, floatTy); SmallVector params = {semValue, lhsBits, rhsBits}; auto resultOp = func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), @@ -126,6 +138,53 @@ struct BinaryArithOpToAPFloatConversion final : OpRewritePattern { const char *APFloatName; }; +template +struct FpToFpConversion final : OpRewritePattern { + FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr fn = lookupOrCreateApFloatFn( + rewriter, symTable, "convert", {i32Type, i32Type, i64Type}); + if (failed(fn)) + return fn; + + rewriter.setInsertionPoint(op); + // Cast operands to 64-bit integers. + Location loc = op.getLoc(); + auto inFloatTy = cast(op.getOperand().getType()); + auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth()); + auto int64Type = rewriter.getI64Type(); + Value operandBits = arith::ExtUIOp::create( + rewriter, loc, int64Type, + arith::BitcastOp::create(rewriter, loc, inIntWType, op.getOperand())); + + // Call APFloat function. + Value inSemValue = getSemanticsValue(rewriter, loc, inFloatTy); + auto outFloatTy = cast(op.getType()); + Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy); + std::array params = {inSemValue, outSemValue, operandBits}; + auto resultOp = + func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()), + SymbolRefAttr::get(*fn), params); + + // Truncate result to the original width. + auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth()); + Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType, + resultOp->getResult(0)); + rewriter.replaceOp( + op, arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits)); + return success(); + } + + SymbolOpInterface symTable; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase { @@ -147,6 +206,9 @@ void ArithToAPFloatConversionPass::runOnOperation() { context, "divide", getOperation()); patterns.add>( context, "remainder", getOperation()); + patterns + .add, FpToFpConversion>( + context, getOperation()); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 0a05f7369e55..511b05ea380f 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -51,7 +51,7 @@ /// Binary operations with rounding mode. #define APFLOAT_BINARY_OP_ROUNDING_MODE(OP, ROUNDING_MODE) \ - MLIR_APFLOAT_WRAPPERS_EXPORT int64_t _mlir_apfloat_##OP( \ + MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_##OP( \ int32_t semantics, uint64_t a, uint64_t b) { \ const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( \ static_cast(semantics)); \ @@ -86,4 +86,19 @@ MLIR_APFLOAT_WRAPPERS_EXPORT void printApFloat(int32_t semantics, uint64_t a) { double d = x.convertToDouble(); fprintf(stdout, "%lg", d); } + +MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t +_mlir_apfloat_convert(int32_t inSemantics, int32_t outSemantics, uint64_t a) { + const llvm::fltSemantics &inSem = llvm::APFloatBase::EnumToSemantics( + static_cast(inSemantics)); + const llvm::fltSemantics &outSem = llvm::APFloatBase::EnumToSemantics( + static_cast(outSemantics)); + unsigned bitWidthIn = llvm::APFloatBase::semanticsSizeInBits(inSem); + llvm::APFloat val(inSem, llvm::APInt(bitWidthIn, a)); + // TODO: Custom rounding modes are not supported yet. + bool losesInfo; + val.convert(outSem, llvm::RoundingMode::NearestTiesToEven, &losesInfo); + llvm::APInt result = val.bitcastToAPInt(); + return result.getZExtValue(); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index 797f42c37a26..038acbfc965a 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -126,3 +126,25 @@ func.func @remf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { %0 = arith.remf %arg0, %arg1 : f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 18 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 2 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @extf(%arg0: f4E2M1FN) { + %0 = arith.extf %arg0 : f4E2M1FN to f32 + return +} + +// ----- + +// CHECK: func.func private @_mlir_apfloat_convert(i32, i32, i64) -> i64 +// CHECK: %[[sem_in:.*]] = arith.constant 1 : i32 +// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32 +// CHECK: call @_mlir_apfloat_convert(%[[sem_in]], %[[sem_out]], %{{.*}}) : (i32, i32, i64) -> i64 +func.func @truncf(%arg0: bf16) { + %0 = arith.truncf %arg0 : bf16 to f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index dbaa20346a03..51976434d2be 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -27,14 +27,21 @@ func.func @entry() { %a1 = arith.constant 1.4 : f8E4M3FN %a2 = arith.constant 1.4 : f32 %b1, %b2 = func.call @foo() : () -> (f8E4M3FN, f32) - %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM - %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM - // CHECK: 3.5 + // CHECK: 2.2 + vector.print %b2 : f32 + + // CHECK-NEXT: 3.5 + %c1 = arith.addf %a1, %b1 : f8E4M3FN // not supported by LLVM vector.print %c1 : f8E4M3FN - // CHECK: 3.6 + // CHECK-NEXT: 3.6 + %c2 = arith.addf %a2, %b2 : f32 // supported by LLVM vector.print %c2 : f32 + // CHECK-NEXT: 2.25 + %cvt = arith.truncf %b2 : f32 to f8E4M3FN + vector.print %cvt : f8E4M3FN + return }