[mlir][spirv] Add 8-bit float type emulation (#148811)

8-bit floats are not supported in SPIR-V. They are emulated as 8-bit
integers during conversion.
This commit is contained in:
Md Abdullah Shahneous Bari 2025-07-30 17:39:49 -05:00 committed by GitHub
parent c8b6ddf3a3
commit b9a627e6fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 185 additions and 7 deletions

View File

@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
"bool", /*default=*/"true", "bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by " "Emulate narrower scalar types with 32-bit ones if not supported by "
"the target">, "the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
]; ];
} }
@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true", "bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by" "Emulate narrower scalar types with 32-bit ones if not supported by"
" the target"> " the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
]; ];
} }
@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true", "bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by" "Emulate narrower scalar types with 32-bit ones if not supported by"
" the target"> " the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
]; ];
} }
@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types", Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
"bool", /*default=*/"true", "bool", /*default=*/"true",
"Emulate narrower scalar types with 32-bit ones if not supported by" "Emulate narrower scalar types with 32-bit ones if not supported by"
" the target"> " the target">,
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
"bool", /*default=*/"true",
"Emulate unsupported float types by representing them with integer "
"types of same bit width">
]; ];
} }

View File

@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
/// The number of bits to store a boolean value. /// The number of bits to store a boolean value.
unsigned boolNumBits{8}; unsigned boolNumBits{8};
/// Whether to emulate unsupported floats with integer types of same bit
/// width.
bool emulateUnsupportedFloatTypes{true};
/// How sub-byte values are storaged in memory. /// How sub-byte values are storaged in memory.
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed}; SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};

View File

@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat()); return builder.getF32FloatAttr(dstVal.convertToFloat());
} }
/// Builds an IntegerAttr of type `dstType` holding the raw bit pattern of
/// `floatAttr`. No numeric conversion takes place: the float's bits are
/// reinterpreted as an integer, which is what is needed when a float constant
/// is emulated by an integer constant of the same width.
///
/// Returns a null attribute when `dstType` is not an integer type whose width
/// equals the bit width of the float value. Building an IntegerAttr from an
/// APInt whose width differs from the integer type's width is invalid, so the
/// mismatch is reported to the caller (which checks for null) instead.
static IntegerAttr
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
                            ConversionPatternRewriter &rewriter) {
  APInt intVal = floatAttr.getValue().bitcastToAPInt();
  auto dstIntType = dyn_cast<IntegerType>(dstType);
  if (!dstIntType || dstIntType.getWidth() != intVal.getBitWidth())
    return {};
  return rewriter.getIntegerAttr(dstType, intVal);
}
/// Returns true if the given `type` is a boolean scalar or vector type. /// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) { static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type"); assert(type && "Not a valid type");
@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements; SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) { if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
FloatAttr dstAttr = Attribute dstAttr = nullptr;
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); // Handle 8-bit float conversion to 8-bit integer.
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcElemType.getIntOrFloatBitWidth() == 8 &&
isa<IntegerType>(dstElemType)) {
dstAttr =
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
} else {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
rewriter);
}
if (!dstAttr) if (!dstAttr)
return failure(); return failure();
elements.push_back(dstAttr); elements.push_back(dstAttr);
@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types. // Floating-point types.
if (isa<FloatType>(srcType)) { if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr); auto srcAttr = cast<FloatAttr>(cstAttr);
auto dstAttr = srcAttr; Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all // Floating-point types not supported in the target environment are all
// converted to float type. // converted to float type.
if (srcType != dstType) { auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
dstType.getIntOrFloatBitWidth() == 8) {
// If the source is an 8-bit float, convert it to a 8-bit integer.
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
if (!dstAttr)
return failure();
} else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr) if (!dstAttr)
return failure(); return failure();
@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options; SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options); SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull // Use UnrealizedConversionCast as the bridge so that we don't need to pull

View File

@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options; SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options); SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion. // TODO: We should also take care of block argument type conversion.

View File

@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options; SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options); SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);

View File

@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options; SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options); SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context); RewritePatternSet patterns(context);

View File

@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8; return bitWidth / 8;
} }
// Handle 8-bit floats.
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
auto bitWidth = type.getIntOrFloatBitWidth();
if (bitWidth == 8)
return bitWidth / 8;
return std::nullopt;
}
if (auto complexType = dyn_cast<ComplexType>(type)) { if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType()); auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize) if (!elementSize)
@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness()); type.getSignedness());
} }
/// Maps an 8-bit float type to the signless integer type of the same width so
/// that it can be emulated as an integer.
///
/// Yields a null type when `options.emulateUnsupportedFloatTypes` is off, or
/// when `type` is not one of the recognized 8-bit float formats (the latter is
/// also reported on the debug stream).
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
                                 FloatType type) {
  if (!options.emulateUnsupportedFloatTypes)
    return nullptr;

  const bool isRecognizedF8 =
      isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
          Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
          Float8E8M0FNUType>(type);
  if (!isRecognizedF8) {
    LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
    return nullptr;
  }

  // Same bit width, integer representation.
  return IntegerType::get(type.getContext(), type.getWidth());
}
/// Returns a type with the same shape as `type` but with a recognized 8-bit
/// float element type replaced by the integer type of the same bit width.
///
/// `type` is returned unchanged when its element type is not an 8-bit float,
/// when the 8-bit float format is not one that can be emulated, or when
/// `options.emulateUnsupportedFloatTypes` is false.
static ShapedType
convertShaped8BitFloatType(ShapedType type,
                           const SPIRVConversionOptions &options) {
  // Only 8-bit float element types are candidates for emulation.
  auto floatElementType = dyn_cast<FloatType>(type.getElementType());
  if (!floatElementType || floatElementType.getWidth() != 8)
    return type;

  // Reuse the scalar conversion so the set of emulated formats is defined in
  // one place. It yields a null type when emulation is disabled or the format
  // is not recognized.
  Type convertedElementType = convert8BitFloatType(options, floatElementType);
  if (!convertedElementType)
    return type;
  return type.clone(convertedElementType);
}
/// Returns a type with the same shape but with any index element type converted /// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not /// to the matching integer type. This is a noop when the element type is not
/// the index type. /// the index type.
@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type, const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) { std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options)); type = cast<VectorType>(convertIndexElementType(type, options));
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) { if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer // If this is not a spec allowed scalar type, try to handle sub-byte integer
@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
} }
type = cast<TensorType>(convertIndexElementType(type, options)); type = cast<TensorType>(convertIndexElementType(type, options));
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) { if (!scalarType) {
LLVM_DEBUG(llvm::dbgs() LLVM_DEBUG(llvm::dbgs()
@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) { } else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options)); type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType(); arrayElemType = type.getElementType();
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
// Handle 8-bit float types.
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
arrayElemType = type.getElementType();
} else { } else {
LLVM_DEBUG( LLVM_DEBUG(
llvm::dbgs() llvm::dbgs()
@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> { addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType); return convertScalarType(this->targetEnv, this->options, scalarType);
if (floatType.getWidth() == 8)
return convert8BitFloatType(this->options, floatType);
return Type(); return Type();
}); });

View File

@ -559,6 +559,23 @@ func.func @constant() {
return return
} }
// 8-bit float constants are expected to lower to i8 constants that carry the
// same bit pattern. The bitcasts to i8 are expected to produce the same
// constants as the float constants they are applied to.
//   1.0 in f8E4M3 is 0b0_0111_000 = 0x38 = 56 (exponent bias 7).
//   1.0 in f8E5M2 is 0b0_01111_00 = 0x3C = 60 (exponent bias 15).
// CHECK-LABEL: @constant_8bit_float
func.func @constant_8bit_float() {
// CHECK: spirv.Constant 56 : i8
%cst = arith.constant 1.0 : f8E4M3
// CHECK: spirv.Constant 56 : i8
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
// CHECK: spirv.Constant dense<56> : vector<4xi8>
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
return
}
// CHECK-LABEL: @constant_16bit // CHECK-LABEL: @constant_16bit
func.func @constant_16bit() { func.func @constant_16bit() {
// CHECK: spirv.Constant 4 : i16 // CHECK: spirv.Constant 4 : i16

View File

@ -1,6 +1,8 @@
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s // RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \ // RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
// RUN: FileCheck %s --check-prefix=NOEMU // RUN: FileCheck %s --check-prefix=NOEMU
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Integer types // Integer types
@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return } func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
} // end module } // end module
// -----
// Check that 8-bit float types are emulated as i8.
// The default run expects every f8 argument, including the vector, memref and
// tensor element types, to appear as i8 in the converted spirv.func signature.
// The run with emulate-unsupported-float-types=false (checked under the
// UNSUPPORTED_FLOAT prefix) expects the func.func to keep its original f8
// argument types.
module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
} {
// CHECK: spirv.func @float8_to_integer8
// CHECK-SAME: (%arg0: i8
// CHECK-SAME: %arg1: i8
// CHECK-SAME: %arg2: i8
// CHECK-SAME: %arg3: i8
// CHECK-SAME: %arg4: i8
// CHECK-SAME: %arg5: i8
// CHECK-SAME: %arg6: i8
// CHECK-SAME: %arg7: i8
// CHECK-SAME: %arg8: vector<4xi8>
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
// UNSUPPORTED_FLOAT-SAME: ) {
func.func @float8_to_integer8(
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
) {
// CHECK: spirv.Return
return
}
}