[mlir][spirv] Add 8-bit float type emulation (#148811)
8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.
This commit is contained in:
parent
c8b6ddf3a3
commit
b9a627e6fb
@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate narrower scalar types with 32-bit ones if not supported by "
|
||||
"the target">,
|
||||
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate unsupported float types by representing them with integer "
|
||||
"types of same bit width">
|
||||
];
|
||||
}
|
||||
|
||||
@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
|
||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||
" the target">
|
||||
" the target">,
|
||||
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate unsupported float types by representing them with integer "
|
||||
"types of same bit width">
|
||||
];
|
||||
}
|
||||
|
||||
@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
|
||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||
" the target">
|
||||
" the target">,
|
||||
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate unsupported float types by representing them with integer "
|
||||
"types of same bit width">
|
||||
];
|
||||
}
|
||||
|
||||
@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
|
||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||
" the target">
|
||||
" the target">,
|
||||
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||
"bool", /*default=*/"true",
|
||||
"Emulate unsupported float types by representing them with integer "
|
||||
"types of same bit width">
|
||||
];
|
||||
}
|
||||
|
||||
|
@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
|
||||
/// The number of bits to store a boolean value.
|
||||
unsigned boolNumBits{8};
|
||||
|
||||
/// Whether to emulate unsupported floats with integer types of same bit
|
||||
/// width.
|
||||
bool emulateUnsupportedFloatTypes{true};
|
||||
|
||||
/// How sub-byte values are storaged in memory.
|
||||
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
|
||||
|
||||
|
@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
||||
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
||||
}
|
||||
|
||||
// Get in IntegerAttr from FloatAttr while preserving the bits.
|
||||
// Useful for converting float constants to integer constants while preserving
|
||||
// the bits.
|
||||
static IntegerAttr
|
||||
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
|
||||
ConversionPatternRewriter &rewriter) {
|
||||
APFloat floatVal = floatAttr.getValue();
|
||||
APInt intVal = floatVal.bitcastToAPInt();
|
||||
return rewriter.getIntegerAttr(dstType, intVal);
|
||||
}
|
||||
|
||||
/// Returns true if the given `type` is a boolean scalar or vector type.
|
||||
static bool isBoolScalarOrVector(Type type) {
|
||||
assert(type && "Not a valid type");
|
||||
@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
|
||||
SmallVector<Attribute, 8> elements;
|
||||
if (isa<FloatType>(srcElemType)) {
|
||||
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
|
||||
FloatAttr dstAttr =
|
||||
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
|
||||
Attribute dstAttr = nullptr;
|
||||
// Handle 8-bit float conversion to 8-bit integer.
|
||||
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
|
||||
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
|
||||
srcElemType.getIntOrFloatBitWidth() == 8 &&
|
||||
isa<IntegerType>(dstElemType)) {
|
||||
dstAttr =
|
||||
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
|
||||
} else {
|
||||
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
|
||||
rewriter);
|
||||
}
|
||||
if (!dstAttr)
|
||||
return failure();
|
||||
elements.push_back(dstAttr);
|
||||
@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
|
||||
// Floating-point types.
|
||||
if (isa<FloatType>(srcType)) {
|
||||
auto srcAttr = cast<FloatAttr>(cstAttr);
|
||||
auto dstAttr = srcAttr;
|
||||
Attribute dstAttr = srcAttr;
|
||||
|
||||
// Floating-point types not supported in the target environment are all
|
||||
// converted to float type.
|
||||
if (srcType != dstType) {
|
||||
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
|
||||
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
|
||||
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
|
||||
dstType.getIntOrFloatBitWidth() == 8) {
|
||||
// If the source is an 8-bit float, convert it to a 8-bit integer.
|
||||
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
|
||||
if (!dstAttr)
|
||||
return failure();
|
||||
} else if (srcType != dstType) {
|
||||
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
|
||||
if (!dstAttr)
|
||||
return failure();
|
||||
@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
|
||||
|
||||
SPIRVConversionOptions options;
|
||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
|
||||
|
@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
|
||||
|
||||
SPIRVConversionOptions options;
|
||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
// TODO: We should also take care of block argument type conversion.
|
||||
|
@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
|
||||
|
||||
SPIRVConversionOptions options;
|
||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
|
||||
|
||||
SPIRVConversionOptions options;
|
||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
|
@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
|
||||
return bitWidth / 8;
|
||||
}
|
||||
|
||||
// Handle 8-bit floats.
|
||||
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
|
||||
auto bitWidth = type.getIntOrFloatBitWidth();
|
||||
if (bitWidth == 8)
|
||||
return bitWidth / 8;
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (auto complexType = dyn_cast<ComplexType>(type)) {
|
||||
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
|
||||
if (!elementSize)
|
||||
@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
|
||||
type.getSignedness());
|
||||
}
|
||||
|
||||
/// Converts 8-bit float types to integer types with the same bit width.
|
||||
/// Returns a nullptr for unsupported 8-bit float types.
|
||||
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
|
||||
FloatType type) {
|
||||
if (!options.emulateUnsupportedFloatTypes)
|
||||
return nullptr;
|
||||
// F8 types are converted to integer types with the same bit width.
|
||||
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
|
||||
Float8E8M0FNUType>(type))
|
||||
return IntegerType::get(type.getContext(), type.getWidth());
|
||||
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns a type with the same shape but with any 8-bit float element type
|
||||
/// converted to the same bit width integer type. This is a noop when the
|
||||
/// element type is not the 8-bit float type or emulation flag is set to false.
|
||||
static ShapedType
|
||||
convertShaped8BitFloatType(ShapedType type,
|
||||
const SPIRVConversionOptions &options) {
|
||||
if (!options.emulateUnsupportedFloatTypes)
|
||||
return type;
|
||||
Type srcElementType = type.getElementType();
|
||||
Type convertedElementType = nullptr;
|
||||
// F8 types are converted to integer types with the same bit width.
|
||||
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
|
||||
Float8E8M0FNUType>(srcElementType))
|
||||
convertedElementType = IntegerType::get(
|
||||
type.getContext(), srcElementType.getIntOrFloatBitWidth());
|
||||
|
||||
if (!convertedElementType)
|
||||
return type;
|
||||
|
||||
return type.clone(convertedElementType);
|
||||
}
|
||||
|
||||
/// Returns a type with the same shape but with any index element type converted
|
||||
/// to the matching integer type. This is a noop when the element type is not
|
||||
/// the index type.
|
||||
@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
|
||||
const SPIRVConversionOptions &options, VectorType type,
|
||||
std::optional<spirv::StorageClass> storageClass = {}) {
|
||||
type = cast<VectorType>(convertIndexElementType(type, options));
|
||||
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
|
||||
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
||||
if (!scalarType) {
|
||||
// If this is not a spec allowed scalar type, try to handle sub-byte integer
|
||||
@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
|
||||
}
|
||||
|
||||
type = cast<TensorType>(convertIndexElementType(type, options));
|
||||
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
|
||||
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
||||
if (!scalarType) {
|
||||
LLVM_DEBUG(llvm::dbgs()
|
||||
@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
||||
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
|
||||
type = cast<MemRefType>(convertIndexElementType(type, options));
|
||||
arrayElemType = type.getElementType();
|
||||
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||
// Hnadle 8 bit float types.
|
||||
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
|
||||
arrayElemType = type.getElementType();
|
||||
} else {
|
||||
LLVM_DEBUG(
|
||||
llvm::dbgs()
|
||||
@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
|
||||
addConversion([this](FloatType floatType) -> std::optional<Type> {
|
||||
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
|
||||
return convertScalarType(this->targetEnv, this->options, scalarType);
|
||||
if (floatType.getWidth() == 8)
|
||||
return convert8BitFloatType(this->options, floatType);
|
||||
return Type();
|
||||
});
|
||||
|
||||
|
@ -559,6 +559,23 @@ func.func @constant() {
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @constant_8bit_float
|
||||
func.func @constant_8bit_float() {
|
||||
// CHECK: spirv.Constant 56 : i8
|
||||
%cst = arith.constant 1.0 : f8E4M3
|
||||
// CHECK: spirv.Constant 56 : i8
|
||||
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
|
||||
// CHECK: spirv.Constant dense<56> : vector<4xi8>
|
||||
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
|
||||
// CHECK: spirv.Constant dense<56> : vector<4xi8>
|
||||
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
|
||||
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
|
||||
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
|
||||
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
|
||||
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @constant_16bit
|
||||
func.func @constant_16bit() {
|
||||
// CHECK: spirv.Constant 4 : i16
|
||||
|
@ -1,6 +1,8 @@
|
||||
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
|
||||
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
|
||||
// RUN: FileCheck %s --check-prefix=NOEMU
|
||||
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
|
||||
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Integer types
|
||||
@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
|
||||
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
|
||||
|
||||
} // end module
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// Check that 8-bit float types are emulated as i8.
|
||||
module attributes {
|
||||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
|
||||
} {
|
||||
|
||||
// CHECK: spirv.func @float8_to_integer8
|
||||
// CHECK-SAME: (%arg0: i8
|
||||
// CHECK-SAME: %arg1: i8
|
||||
// CHECK-SAME: %arg2: i8
|
||||
// CHECK-SAME: %arg3: i8
|
||||
// CHECK-SAME: %arg4: i8
|
||||
// CHECK-SAME: %arg5: i8
|
||||
// CHECK-SAME: %arg6: i8
|
||||
// CHECK-SAME: %arg7: i8
|
||||
// CHECK-SAME: %arg8: vector<4xi8>
|
||||
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
|
||||
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
|
||||
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
|
||||
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
|
||||
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
|
||||
// UNSUPPORTED_FLOAT-SAME: ) {
|
||||
|
||||
func.func @float8_to_integer8(
|
||||
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
|
||||
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
|
||||
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
|
||||
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
|
||||
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
|
||||
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
|
||||
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
|
||||
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
|
||||
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
|
||||
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
|
||||
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
|
||||
) {
|
||||
// CHECK: spirv.Return
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user