[mlir][spirv] Add definitions and (de)serialization for FPRoundingMode (#101546)
This commit is contained in:
parent
e3d9b01a36
commit
ca26ea28ed
@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
|
||||
];
|
||||
}
|
||||
|
||||
def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>;
|
||||
def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>;
|
||||
def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>;
|
||||
def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>;
|
||||
|
||||
// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit
|
||||
// FPRoundingMode on a value stored to certain storage classes?
|
||||
// (The OpenCL environment also has FPRoundingMode rules, but different.)
|
||||
def SPIRV_FPRoundingModeAttr :
|
||||
SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [
|
||||
SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN
|
||||
]>;
|
||||
|
||||
def SPIRV_FunctionControlAttr :
|
||||
SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
|
||||
SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const,
|
||||
|
@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
|
||||
symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
|
||||
static_cast<FPFastMathMode>(words[2])));
|
||||
break;
|
||||
case spirv::Decoration::FPRoundingMode:
|
||||
if (words.size() != 3) {
|
||||
return emitError(unknownLoc, "OpDecorate with ")
|
||||
<< decorationName << " needs a single integer literal";
|
||||
}
|
||||
decorations[words[0]].set(
|
||||
symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
|
||||
static_cast<FPRoundingMode>(words[2])));
|
||||
break;
|
||||
case spirv::Decoration::DescriptorSet:
|
||||
case spirv::Decoration::Binding:
|
||||
if (words.size() != 3) {
|
||||
|
@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) {
|
||||
// expected FPFastMathMode.
|
||||
if (attrName == "fp_fast_math_mode")
|
||||
return "FPFastMathMode";
|
||||
// similar here
|
||||
if (attrName == "fp_rounding_mode")
|
||||
return "FPRoundingMode";
|
||||
|
||||
return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
|
||||
}
|
||||
@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
|
||||
}
|
||||
return emitError(loc, "expected FPFastMathModeAttr attribute for ")
|
||||
<< stringifyDecoration(decoration);
|
||||
case spirv::Decoration::FPRoundingMode:
|
||||
if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
|
||||
args.push_back(static_cast<uint32_t>(intAttr.getValue()));
|
||||
break;
|
||||
}
|
||||
return emitError(loc, "expected FPRoundingModeAttr attribute for ")
|
||||
<< stringifyDecoration(decoration);
|
||||
case spirv::Decoration::Binding:
|
||||
case spirv::Decoration::DescriptorSet:
|
||||
case spirv::Decoration::Location:
|
||||
|
@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
|
||||
spirv.ReturnValue %0 : f32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
|
||||
spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
|
||||
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
|
||||
%0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
|
||||
spirv.ReturnValue %0 : f16
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user