[mlir][spirv] Add 8-bit float type emulation (#148811)
8-bit floats are not supported in SPIR-V. They are emulated as 8-bit integer during conversion.
This commit is contained in:
parent
c8b6ddf3a3
commit
b9a627e6fb
@ -196,6 +196,10 @@ def ConvertArithToSPIRVPass : Pass<"convert-arith-to-spirv"> {
|
|||||||
"bool", /*default=*/"true",
|
"bool", /*default=*/"true",
|
||||||
"Emulate narrower scalar types with 32-bit ones if not supported by "
|
"Emulate narrower scalar types with 32-bit ones if not supported by "
|
||||||
"the target">,
|
"the target">,
|
||||||
|
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||||
|
"bool", /*default=*/"true",
|
||||||
|
"Emulate unsupported float types by representing them with integer "
|
||||||
|
"types of same bit width">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -416,7 +420,11 @@ def ConvertControlFlowToSPIRVPass : Pass<"convert-cf-to-spirv"> {
|
|||||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||||
"bool", /*default=*/"true",
|
"bool", /*default=*/"true",
|
||||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||||
" the target">
|
" the target">,
|
||||||
|
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||||
|
"bool", /*default=*/"true",
|
||||||
|
"Emulate unsupported float types by representing them with integer "
|
||||||
|
"types of same bit width">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -500,7 +508,11 @@ def ConvertFuncToSPIRVPass : Pass<"convert-func-to-spirv"> {
|
|||||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||||
"bool", /*default=*/"true",
|
"bool", /*default=*/"true",
|
||||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||||
" the target">
|
" the target">,
|
||||||
|
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||||
|
"bool", /*default=*/"true",
|
||||||
|
"Emulate unsupported float types by representing them with integer "
|
||||||
|
"types of same bit width">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1167,7 +1179,11 @@ def ConvertTensorToSPIRVPass : Pass<"convert-tensor-to-spirv"> {
|
|||||||
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
|
||||||
"bool", /*default=*/"true",
|
"bool", /*default=*/"true",
|
||||||
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
"Emulate narrower scalar types with 32-bit ones if not supported by"
|
||||||
" the target">
|
" the target">,
|
||||||
|
Option<"emulateUnsupportedFloatTypes", "emulate-unsupported-float-types",
|
||||||
|
"bool", /*default=*/"true",
|
||||||
|
"Emulate unsupported float types by representing them with integer "
|
||||||
|
"types of same bit width">
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,6 +39,10 @@ struct SPIRVConversionOptions {
|
|||||||
/// The number of bits to store a boolean value.
|
/// The number of bits to store a boolean value.
|
||||||
unsigned boolNumBits{8};
|
unsigned boolNumBits{8};
|
||||||
|
|
||||||
|
/// Whether to emulate unsupported floats with integer types of same bit
|
||||||
|
/// width.
|
||||||
|
bool emulateUnsupportedFloatTypes{true};
|
||||||
|
|
||||||
/// How sub-byte values are storaged in memory.
|
/// How sub-byte values are storaged in memory.
|
||||||
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
|
SPIRVSubByteTypeStorage subByteTypeStorage{SPIRVSubByteTypeStorage::Packed};
|
||||||
|
|
||||||
|
@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
|
|||||||
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
return builder.getF32FloatAttr(dstVal.convertToFloat());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get in IntegerAttr from FloatAttr while preserving the bits.
|
||||||
|
// Useful for converting float constants to integer constants while preserving
|
||||||
|
// the bits.
|
||||||
|
static IntegerAttr
|
||||||
|
getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
|
||||||
|
ConversionPatternRewriter &rewriter) {
|
||||||
|
APFloat floatVal = floatAttr.getValue();
|
||||||
|
APInt intVal = floatVal.bitcastToAPInt();
|
||||||
|
return rewriter.getIntegerAttr(dstType, intVal);
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns true if the given `type` is a boolean scalar or vector type.
|
/// Returns true if the given `type` is a boolean scalar or vector type.
|
||||||
static bool isBoolScalarOrVector(Type type) {
|
static bool isBoolScalarOrVector(Type type) {
|
||||||
assert(type && "Not a valid type");
|
assert(type && "Not a valid type");
|
||||||
@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
|
|||||||
SmallVector<Attribute, 8> elements;
|
SmallVector<Attribute, 8> elements;
|
||||||
if (isa<FloatType>(srcElemType)) {
|
if (isa<FloatType>(srcElemType)) {
|
||||||
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
|
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
|
||||||
FloatAttr dstAttr =
|
Attribute dstAttr = nullptr;
|
||||||
convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
|
// Handle 8-bit float conversion to 8-bit integer.
|
||||||
|
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
|
||||||
|
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
|
||||||
|
srcElemType.getIntOrFloatBitWidth() == 8 &&
|
||||||
|
isa<IntegerType>(dstElemType)) {
|
||||||
|
dstAttr =
|
||||||
|
getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
|
||||||
|
} else {
|
||||||
|
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
|
||||||
|
rewriter);
|
||||||
|
}
|
||||||
if (!dstAttr)
|
if (!dstAttr)
|
||||||
return failure();
|
return failure();
|
||||||
elements.push_back(dstAttr);
|
elements.push_back(dstAttr);
|
||||||
@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
|
|||||||
// Floating-point types.
|
// Floating-point types.
|
||||||
if (isa<FloatType>(srcType)) {
|
if (isa<FloatType>(srcType)) {
|
||||||
auto srcAttr = cast<FloatAttr>(cstAttr);
|
auto srcAttr = cast<FloatAttr>(cstAttr);
|
||||||
auto dstAttr = srcAttr;
|
Attribute dstAttr = srcAttr;
|
||||||
|
|
||||||
// Floating-point types not supported in the target environment are all
|
// Floating-point types not supported in the target environment are all
|
||||||
// converted to float type.
|
// converted to float type.
|
||||||
if (srcType != dstType) {
|
auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
|
||||||
|
if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
|
||||||
|
srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
|
||||||
|
dstType.getIntOrFloatBitWidth() == 8) {
|
||||||
|
// If the source is an 8-bit float, convert it to a 8-bit integer.
|
||||||
|
dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
|
||||||
|
if (!dstAttr)
|
||||||
|
return failure();
|
||||||
|
} else if (srcType != dstType) {
|
||||||
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
|
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
|
||||||
if (!dstAttr)
|
if (!dstAttr)
|
||||||
return failure();
|
return failure();
|
||||||
@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
|
|||||||
|
|
||||||
SPIRVConversionOptions options;
|
SPIRVConversionOptions options;
|
||||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||||
|
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||||
|
|
||||||
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
|
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
|
||||||
|
@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
|
|||||||
|
|
||||||
SPIRVConversionOptions options;
|
SPIRVConversionOptions options;
|
||||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||||
|
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||||
|
|
||||||
// TODO: We should also take care of block argument type conversion.
|
// TODO: We should also take care of block argument type conversion.
|
||||||
|
@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
|
|||||||
|
|
||||||
SPIRVConversionOptions options;
|
SPIRVConversionOptions options;
|
||||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||||
|
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
|
|||||||
|
|
||||||
SPIRVConversionOptions options;
|
SPIRVConversionOptions options;
|
||||||
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
|
||||||
|
options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
|
||||||
SPIRVTypeConverter typeConverter(targetAttr, options);
|
SPIRVTypeConverter typeConverter(targetAttr, options);
|
||||||
|
|
||||||
RewritePatternSet patterns(context);
|
RewritePatternSet patterns(context);
|
||||||
|
@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
|
|||||||
return bitWidth / 8;
|
return bitWidth / 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle 8-bit floats.
|
||||||
|
if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
|
||||||
|
auto bitWidth = type.getIntOrFloatBitWidth();
|
||||||
|
if (bitWidth == 8)
|
||||||
|
return bitWidth / 8;
|
||||||
|
return std::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
if (auto complexType = dyn_cast<ComplexType>(type)) {
|
if (auto complexType = dyn_cast<ComplexType>(type)) {
|
||||||
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
|
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
|
||||||
if (!elementSize)
|
if (!elementSize)
|
||||||
@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
|
|||||||
type.getSignedness());
|
type.getSignedness());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts 8-bit float types to integer types with the same bit width.
|
||||||
|
/// Returns a nullptr for unsupported 8-bit float types.
|
||||||
|
static Type convert8BitFloatType(const SPIRVConversionOptions &options,
|
||||||
|
FloatType type) {
|
||||||
|
if (!options.emulateUnsupportedFloatTypes)
|
||||||
|
return nullptr;
|
||||||
|
// F8 types are converted to integer types with the same bit width.
|
||||||
|
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||||
|
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
|
||||||
|
Float8E8M0FNUType>(type))
|
||||||
|
return IntegerType::get(type.getContext(), type.getWidth());
|
||||||
|
LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a type with the same shape but with any 8-bit float element type
|
||||||
|
/// converted to the same bit width integer type. This is a noop when the
|
||||||
|
/// element type is not the 8-bit float type or emulation flag is set to false.
|
||||||
|
static ShapedType
|
||||||
|
convertShaped8BitFloatType(ShapedType type,
|
||||||
|
const SPIRVConversionOptions &options) {
|
||||||
|
if (!options.emulateUnsupportedFloatTypes)
|
||||||
|
return type;
|
||||||
|
Type srcElementType = type.getElementType();
|
||||||
|
Type convertedElementType = nullptr;
|
||||||
|
// F8 types are converted to integer types with the same bit width.
|
||||||
|
if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
|
||||||
|
Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
|
||||||
|
Float8E8M0FNUType>(srcElementType))
|
||||||
|
convertedElementType = IntegerType::get(
|
||||||
|
type.getContext(), srcElementType.getIntOrFloatBitWidth());
|
||||||
|
|
||||||
|
if (!convertedElementType)
|
||||||
|
return type;
|
||||||
|
|
||||||
|
return type.clone(convertedElementType);
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a type with the same shape but with any index element type converted
|
/// Returns a type with the same shape but with any index element type converted
|
||||||
/// to the matching integer type. This is a noop when the element type is not
|
/// to the matching integer type. This is a noop when the element type is not
|
||||||
/// the index type.
|
/// the index type.
|
||||||
@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
|
|||||||
const SPIRVConversionOptions &options, VectorType type,
|
const SPIRVConversionOptions &options, VectorType type,
|
||||||
std::optional<spirv::StorageClass> storageClass = {}) {
|
std::optional<spirv::StorageClass> storageClass = {}) {
|
||||||
type = cast<VectorType>(convertIndexElementType(type, options));
|
type = cast<VectorType>(convertIndexElementType(type, options));
|
||||||
|
type = cast<VectorType>(convertShaped8BitFloatType(type, options));
|
||||||
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
||||||
if (!scalarType) {
|
if (!scalarType) {
|
||||||
// If this is not a spec allowed scalar type, try to handle sub-byte integer
|
// If this is not a spec allowed scalar type, try to handle sub-byte integer
|
||||||
@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
|
|||||||
}
|
}
|
||||||
|
|
||||||
type = cast<TensorType>(convertIndexElementType(type, options));
|
type = cast<TensorType>(convertIndexElementType(type, options));
|
||||||
|
type = cast<TensorType>(convertShaped8BitFloatType(type, options));
|
||||||
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
|
||||||
if (!scalarType) {
|
if (!scalarType) {
|
||||||
LLVM_DEBUG(llvm::dbgs()
|
LLVM_DEBUG(llvm::dbgs()
|
||||||
@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
|
|||||||
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
|
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
|
||||||
type = cast<MemRefType>(convertIndexElementType(type, options));
|
type = cast<MemRefType>(convertIndexElementType(type, options));
|
||||||
arrayElemType = type.getElementType();
|
arrayElemType = type.getElementType();
|
||||||
|
} else if (auto floatType = dyn_cast<FloatType>(elementType)) {
|
||||||
|
// Hnadle 8 bit float types.
|
||||||
|
type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
|
||||||
|
arrayElemType = type.getElementType();
|
||||||
} else {
|
} else {
|
||||||
LLVM_DEBUG(
|
LLVM_DEBUG(
|
||||||
llvm::dbgs()
|
llvm::dbgs()
|
||||||
@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
|
|||||||
addConversion([this](FloatType floatType) -> std::optional<Type> {
|
addConversion([this](FloatType floatType) -> std::optional<Type> {
|
||||||
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
|
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
|
||||||
return convertScalarType(this->targetEnv, this->options, scalarType);
|
return convertScalarType(this->targetEnv, this->options, scalarType);
|
||||||
|
if (floatType.getWidth() == 8)
|
||||||
|
return convert8BitFloatType(this->options, floatType);
|
||||||
return Type();
|
return Type();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -559,6 +559,23 @@ func.func @constant() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @constant_8bit_float
|
||||||
|
func.func @constant_8bit_float() {
|
||||||
|
// CHECK: spirv.Constant 56 : i8
|
||||||
|
%cst = arith.constant 1.0 : f8E4M3
|
||||||
|
// CHECK: spirv.Constant 56 : i8
|
||||||
|
%cst_i8 = arith.bitcast %cst : f8E4M3 to i8
|
||||||
|
// CHECK: spirv.Constant dense<56> : vector<4xi8>
|
||||||
|
%cst_vector = arith.constant dense<1.0> : vector<4xf8E4M3>
|
||||||
|
// CHECK: spirv.Constant dense<56> : vector<4xi8>
|
||||||
|
%cst_vector_i8 = arith.bitcast %cst_vector : vector<4xf8E4M3> to vector<4xi8>
|
||||||
|
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
|
||||||
|
%cst_tensor = arith.constant dense<1.0> : tensor<4xf8E5M2>
|
||||||
|
// CHECK: spirv.Constant dense<60> : tensor<4xi8> : !spirv.array<4 x i8>
|
||||||
|
%cst_tensor_i8 = arith.bitcast %cst_tensor : tensor<4xf8E5M2> to tensor<4xi8>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: @constant_16bit
|
// CHECK-LABEL: @constant_16bit
|
||||||
func.func @constant_16bit() {
|
func.func @constant_16bit() {
|
||||||
// CHECK: spirv.Constant 4 : i16
|
// CHECK: spirv.Constant 4 : i16
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
|
// RUN: mlir-opt -split-input-file -convert-func-to-spirv %s -o - | FileCheck %s
|
||||||
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
|
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-lt-32-bit-scalar-types=false" %s | \
|
||||||
// RUN: FileCheck %s --check-prefix=NOEMU
|
// RUN: FileCheck %s --check-prefix=NOEMU
|
||||||
|
// RUN: mlir-opt -split-input-file -convert-func-to-spirv="emulate-unsupported-float-types=false" %s | \
|
||||||
|
// RUN: FileCheck %s --check-prefix=UNSUPPORTED_FLOAT
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Integer types
|
// Integer types
|
||||||
@ -944,3 +946,55 @@ func.func @unranked_tensor(%arg0: tensor<*xi32>) { return }
|
|||||||
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
|
func.func @dynamic_dim_tensor(%arg0: tensor<8x?xi32>) { return }
|
||||||
|
|
||||||
} // end module
|
} // end module
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Check that 8-bit float types are emulated as i8.
|
||||||
|
module attributes {
|
||||||
|
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8], []>, #spirv.resource_limits<>>
|
||||||
|
} {
|
||||||
|
|
||||||
|
// CHECK: spirv.func @float8_to_integer8
|
||||||
|
// CHECK-SAME: (%arg0: i8
|
||||||
|
// CHECK-SAME: %arg1: i8
|
||||||
|
// CHECK-SAME: %arg2: i8
|
||||||
|
// CHECK-SAME: %arg3: i8
|
||||||
|
// CHECK-SAME: %arg4: i8
|
||||||
|
// CHECK-SAME: %arg5: i8
|
||||||
|
// CHECK-SAME: %arg6: i8
|
||||||
|
// CHECK-SAME: %arg7: i8
|
||||||
|
// CHECK-SAME: %arg8: vector<4xi8>
|
||||||
|
// CHECK-SAME: %arg9: !spirv.ptr<!spirv.struct<(!spirv.array<8 x i8, stride=1> [0])>, StorageBuffer>
|
||||||
|
// CHECK-SAME: %arg10: !spirv.array<4 x i8>
|
||||||
|
// UNSUPPORTED_FLOAT-LABEL: func.func @float8_to_integer8
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: (%arg0: f8E5M2
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg1: f8E4M3
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg2: f8E4M3FN
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg3: f8E5M2FNUZ
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg4: f8E4M3FNUZ
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg5: f8E4M3B11FNUZ
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg6: f8E3M4
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg7: f8E8M0FNU
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg8: vector<4xf8E4M3B11FNUZ>
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: %arg10: tensor<4xf8E5M2>
|
||||||
|
// UNSUPPORTED_FLOAT-SAME: ) {
|
||||||
|
|
||||||
|
func.func @float8_to_integer8(
|
||||||
|
%arg0: f8E5M2, // CHECK-NOT: f8E5M2
|
||||||
|
%arg1: f8E4M3, // CHECK-NOT: f8E4M3
|
||||||
|
%arg2: f8E4M3FN, // CHECK-NOT: f8E4M3FN
|
||||||
|
%arg3: f8E5M2FNUZ, // CHECK-NOT: f8E5M2FNUZ
|
||||||
|
%arg4: f8E4M3FNUZ, // CHECK-NOT: f8E4M3FNUZ
|
||||||
|
%arg5: f8E4M3B11FNUZ, // CHECK-NOT: f8E4M3B11FNUZ
|
||||||
|
%arg6: f8E3M4, // CHECK-NOT: f8E3M4
|
||||||
|
%arg7: f8E8M0FNU, // CHECK-NOT: f8E8M0FNU
|
||||||
|
%arg8: vector<4xf8E4M3B11FNUZ>, // CHECK-NOT: vector<4xf8E4M3B11FNUZ>
|
||||||
|
%arg9: memref<8xf8E4M3, #spirv.storage_class<StorageBuffer>>, // CHECK-NOT: memref
|
||||||
|
%arg10: tensor<4xf8E5M2> // CHECK-NOT: tensor
|
||||||
|
) {
|
||||||
|
// CHECK: spirv.Return
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user