diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 6ec97e17c5dc..b38978272c5b 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -3249,6 +3249,19 @@ def SPIRV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> { ]; } +def SPIRV_FPRM_RTE : I32EnumAttrCase<"RTE", 0>; +def SPIRV_FPRM_RTZ : I32EnumAttrCase<"RTZ", 1>; +def SPIRV_FPRM_RTP : I32EnumAttrCase<"RTP", 2>; +def SPIRV_FPRM_RTN : I32EnumAttrCase<"RTN", 3>; + +// TODO: Enforce SPIR-V spec validation rule for Shader capability: only permit +// FPRoundingMode on a value stored to certain storage classes? +// (The OpenCL environment also has FPRoundingMode rules, but different.) +def SPIRV_FPRoundingModeAttr : + SPIRV_I32EnumAttr<"FPRoundingMode", "valid SPIR-V FPRoundingMode", "fp_rounding_mode", [ + SPIRV_FPRM_RTE, SPIRV_FPRM_RTZ, SPIRV_FPRM_RTP, SPIRV_FPRM_RTN + ]>; + def SPIRV_FunctionControlAttr : SPIRV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [ SPIRV_FC_None, SPIRV_FC_Inline, SPIRV_FC_DontInline, SPIRV_FC_Pure, SPIRV_FC_Const, diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index d7a308548cf4..12980879b20a 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -250,6 +250,15 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { symbol, FPFastMathModeAttr::get(opBuilder.getContext(), static_cast(words[2]))); break; + case spirv::Decoration::FPRoundingMode: + if (words.size() != 3) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " needs a single integer literal"; + } + decorations[words[0]].set( + symbol, FPRoundingModeAttr::get(opBuilder.getContext(), + static_cast(words[2]))); + break; case spirv::Decoration::DescriptorSet: case spirv::Decoration::Binding: if (words.size() != 3) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 4c4fef177317..714a3edfb565 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -214,6 +214,9 @@ static std::string getDecorationName(StringRef attrName) { // expected FPFastMathMode. if (attrName == "fp_fast_math_mode") return "FPFastMathMode"; + // similar here + if (attrName == "fp_rounding_mode") + return "FPRoundingMode"; return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); } @@ -242,6 +245,13 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, } return emitError(loc, "expected FPFastMathModeAttr attribute for ") << stringifyDecoration(decoration); + case spirv::Decoration::FPRoundingMode: + if (auto intAttr = dyn_cast(attr)) { + args.push_back(static_cast(intAttr.getValue())); + break; + } + return emitError(loc, "expected FPRoundingModeAttr attribute for ") + << stringifyDecoration(decoration); case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case spirv::Decoration::Location: diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir index 195773735431..0a29290b6a6f 100644 --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -97,3 +97,13 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" { spirv.ReturnValue %0 : f32 } } + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { +spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" { + // CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode} : f32 to f16 + %0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode} : f32 to f16 + spirv.ReturnValue %0 : f16 +} +}