diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index db1e3630435d..4d352f1def04 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -13285,6 +13285,9 @@ def err_builtin_invalid_arg_type: Error< "%plural{0:|: }3" "%plural{[0,3]:type|:types}1 (was %4)">; +def err_builtin_requires_double_type: Error< + "%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">; + def err_bswapg_invalid_bit_width : Error< "_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">; diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h index ba9646db5de1..80c415ef6664 100644 --- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h @@ -1235,6 +1235,60 @@ float3 floor(float3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor) float4 floor(float4); +//===----------------------------------------------------------------------===// +// fused multiply-add builtins +//===----------------------------------------------------------------------===// + +/// \fn double fma(double a, double b, double c) +/// \brief Returns the double-precision fused multiply-addition of a * b + c. +/// \param a The first value in the fused multiply-addition. +/// \param b The second value in the fused multiply-addition. +/// \param c The third value in the fused multiply-addition. + +// double scalars and vectors +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double fma(double, double, double); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2 fma(double2, double2, double2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3 fma(double3, double3, double3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4 fma(double4, double4, double4); + +// double matrices +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x1 fma(double1x1, double1x1, double1x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x2 fma(double1x2, double1x2, double1x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x3 fma(double1x3, double1x3, double1x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double1x4 fma(double1x4, double1x4, double1x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x1 fma(double2x1, double2x1, double2x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x2 fma(double2x2, double2x2, double2x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x3 fma(double2x3, double2x3, double2x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double2x4 fma(double2x4, double2x4, double2x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x1 fma(double3x1, double3x1, double3x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x2 fma(double3x2, double3x2, double3x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x3 fma(double3x3, double3x3, double3x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double3x4 fma(double3x4, double3x4, double3x4); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x1 fma(double4x1, double4x1, double4x1); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x2 fma(double4x2, double4x2, double4x2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x3 fma(double4x3, double4x3, double4x3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma) +double4x4 fma(double4x4, double4x4, double4x4); + //===----------------------------------------------------------------------===// // frac builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 00a664dfedbb..de8b96514497 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2178,9 +2178,10 @@ static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy, Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr, int ArgOrdinal) { - QualType EltTy = ArgTy; - if (auto *VecTy = EltTy->getAs()) - EltTy = VecTy->getElementType(); + clang::QualType EltTy = + ArgTy->isVectorType() ? ArgTy->getAs()->getElementType() + : ArgTy->isMatrixType() ? ArgTy->getAs()->getElementType() + : ArgTy; switch (ArgTyRestr) { case Sema::EltwiseBuiltinArgTyRestriction::None: @@ -2192,6 +2193,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy, break; case Sema::EltwiseBuiltinArgTyRestriction::FloatTy: if (!EltTy->isRealFloatingType()) { + // FIXME: make diagnostic's wording correct for matrices return S.Diag(Loc, diag::err_builtin_invalid_arg_type) << ArgOrdinal << /* scalar or vector */ 5 << /* no int */ 0 << /* floating-point */ 1 << ArgTy; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 3b7b12a884f4..2b977b2793ef 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -3149,6 +3149,25 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc, return false; } +static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc, + int ArgOrdinal, + clang::QualType PassedType) { + clang::QualType BaseType = + PassedType->isVectorType() + ? PassedType->castAs()->getElementType() + : PassedType->isMatrixType() + ? PassedType->castAs()->getElementType() + : PassedType; + if (!BaseType->isDoubleType()) { + // FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using + // this custom error. + return S->Diag(Loc, diag::err_builtin_requires_double_type) + << ArgOrdinal << PassedType; + } + + return false; +} + static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, unsigned ArgIndex) { auto *Arg = TheCall->getArg(ArgIndex); @@ -4120,6 +4139,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { TheCall->setType(ArgTyA); break; } + case Builtin::BI__builtin_elementwise_fma: { + if (SemaRef.checkArgCount(TheCall, 3) || + CheckAllArgsHaveSameType(&SemaRef, TheCall)) { + return true; + } + + if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall, + CheckAnyDoubleRepresentation)) + return true; + + ExprResult A = TheCall->getArg(0); + QualType ArgTyA = A.get()->getType(); + // return type is the same as input type + TheCall->setType(ArgTyA); + break; + } case Builtin::BI__builtin_hlsl_transpose: { if (SemaRef.checkArgCount(TheCall, 1)) return true; diff --git a/clang/test/CodeGenHLSL/builtins/fma.hlsl b/clang/test/CodeGenHLSL/builtins/fma.hlsl new file mode 100644 index 000000000000..3d9549197035 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/fma.hlsl @@ -0,0 +1,138 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm \ +// RUN: -disable-llvm-passes -o - | FileCheck %s + +// CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.fma.f64(double +// CHECK: ret double +double fma_double(double a, double b, double c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x double> +double2 fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double3 fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double4 fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); } + +// CHECK-LABEL: define {{.*}} <1 x double> @{{.*}}fma_double1x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <1 x double> @llvm.fma.v1f64(<1 x double> +// CHECK: ret <1 x double> +double1x1 fma_double1x1(double1x1 a, double1x1 b, double1x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double1x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x double> +double1x2 fma_double1x2(double1x2 a, double1x2 b, double1x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double1x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double1x3 fma_double1x3(double1x3 a, double1x3 b, double1x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double1x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double1x4 fma_double1x4(double1x4 a, double1x4 b, double1x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double> +// CHECK: ret <2 x double> +double2x1 fma_double2x1(double2x1 a, double2x1 b, double2x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double2x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double2x2 fma_double2x2(double2x2 a, double2x2 b, double2x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double2x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double> +// CHECK: ret <6 x double> +double2x3 fma_double2x3(double2x3 a, double2x3 b, double2x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double2x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double> +// CHECK: ret <8 x double> +double2x4 fma_double2x4(double2x4 a, double2x4 b, double2x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double> +// CHECK: ret <3 x double> +double3x1 fma_double3x1(double3x1 a, double3x1 b, double3x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double3x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double> +// CHECK: ret <6 x double> +double3x2 fma_double3x2(double3x2 a, double3x2 b, double3x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}fma_double3x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.fma.v9f64(<9 x double> +// CHECK: ret <9 x double> +double3x3 fma_double3x3(double3x3 a, double3x3 b, double3x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double3x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double> +// CHECK: ret <12 x double> +double3x4 fma_double3x4(double3x4 a, double3x4 b, double3x4 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4x1{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double> +// CHECK: ret <4 x double> +double4x1 fma_double4x1(double4x1 a, double4x1 b, double4x1 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double4x2{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double> +// CHECK: ret <8 x double> +double4x2 fma_double4x2(double4x2 a, double4x2 b, double4x2 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double4x3{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double> +// CHECK: ret <12 x double> +double4x3 fma_double4x3(double4x3 a, double4x3 b, double4x3 c) { + return fma(a, b, c); +} + +// CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}fma_double4x4{{.*}}( +// CHECK: call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.fma.v16f64(<16 x double> +// CHECK: ret <16 x double> +double4x4 fma_double4x4(double4x4 a, double4x4 b, double4x4 c) { + return fma(a, b, c); +} diff --git a/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl new file mode 100644 index 000000000000..a454f2df4d50 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl @@ -0,0 +1,113 @@ +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ +// RUN: -triple dxil-pc-shadermodel6.6-library %s \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify \ +// RUN: -verify-ignore-unexpected=note +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \ +// RUN: -triple spirv-unknown-vulkan-compute %s \ +// RUN: -emit-llvm-only -disable-llvm-passes -verify \ +// RUN: -verify-ignore-unexpected=note + +float bad_float(float a, float b, float c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}} +} + +float2 bad_float2(float2 a, float2 b, float2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector'))}} +} + +float2x2 bad_float2x2(float2x2 a, float2x2 b, float2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix'))}} +} + +half bad_half(half a, half b, half c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}} +} + +half2 bad_half2(half2 a, half2 b, half2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector'))}} +} + +half2x2 bad_half2x2(half2x2 a, half2x2 b, half2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2x2' (aka 'matrix'))}} +} + +double mixed_bad_second(double a, float b, double c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} +} + +double mixed_bad_third(double a, double b, half c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('double' vs 'half')}} +} + +double2 mixed_bad_second_vec(double2 a, float2 b, double2 c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('vector' vs 'vector')}} +} + +double2 mixed_bad_third_vec(double2 a, double2 b, float2 c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('vector' vs 'vector')}} +} + +double2x2 mixed_bad_second_mat(double2x2 a, float2x2 b, double2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('matrix' vs 'matrix')}} +} + +double2x2 mixed_bad_third_mat(double2x2 a, double2x2 b, half2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('matrix' vs 'matrix')}} +} + +double shape_mismatch_second(double a, double2 b, double c) { + return fma(a, b, c); + // expected-error@-1 {{call to 'fma' is ambiguous}} +} + +double2 shape_mismatch_third(double2 a, double2 b, double c) { + return fma(a, b, c); + // expected-error@-1 {{call to 'fma' is ambiguous}} +} + +double2x2 shape_mismatch_scalar_mat(double2x2 a, double b, double2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{call to 'fma' is ambiguous}} +} + +double2x2 shape_mismatch_vec_mat(double2x2 a, double2 b, double2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{arguments are of different types ('double2x2' (aka 'matrix') vs 'double2' (aka 'vector'))}} +} + +int bad_int(int a, int b, int c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int')}} +} + +int2 bad_int2(int2 a, int2 b, int2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int2' (aka 'vector'))}} +} + +bool bad_bool(bool a, bool b, bool c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool')}} +} + +bool2 bad_bool2(bool2 a, bool2 b, bool2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2' (aka 'vector'))}} +} + +bool2x2 bad_bool2x2(bool2x2 a, bool2x2 b, bool2x2 c) { + return fma(a, b, c); + // expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2x2' (aka 'matrix'))}} +} diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index e7763e915d8c..0a1e0114aa3b 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -790,6 +790,16 @@ def FMad : DXILOp<46, tertiary> { let attributes = [Attributes]; } +def Fma : DXILOp<47, tertiary> { + let Doc = "Double-precision fused multiply-add. fma(a,b,c) = a * b + c."; + let intrinsics = [IntrinSelect]; + let arguments = [OverloadTy, OverloadTy, OverloadTy]; + let result = OverloadTy; + let overloads = [Overloads]; + let stages = [Stages]; + let attributes = [Attributes]; +} + def IMad : DXILOp<48, tertiary> { let Doc = "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m " "* a + b."; diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp index 01f301e027c8..997d44112197 100644 --- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp +++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp @@ -113,6 +113,15 @@ static bool checkWaveOps(Intrinsic::ID IID) { } } +static bool checkDoubleExtensionOps(Intrinsic::ID IID) { + switch (IID) { + default: + return false; + case Intrinsic::fma: + return true; + } +} + static bool isOptimizationDisabled(const Module &M) { const StringRef Key = "dx.disable_optimizations"; if (auto *Flag = mdconst::extract_or_null(M.getModuleFlag(Key))) @@ -250,9 +259,8 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, if (FunctionFlags.contains(CF)) CSF.merge(FunctionFlags[CF]); - // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic - // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 - + CSF.DX11_1_DoubleExtensions |= + checkDoubleExtensionOps(CI->getIntrinsicID()); CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID()); } } diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll index dd8ea5f5b1ae..d56d8ff5bf5e 100644 --- a/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll +++ b/llvm/test/CodeGen/DirectX/ShaderFlags/double-extensions.ll @@ -26,6 +26,12 @@ define double @test_fdiv_double(double %a, double %b) #0 { ret double %res } +; CHECK: ; Function test_fma_double : 0x00000044 +define double @test_fma_double(double %a, double %b, double %c) #0 { + %r = call double @llvm.fma.f64(double %a, double %b, double %c) + ret double %r +} + ; CHECK: ; Function test_uitofp_i64 : 0x00100044 define double @test_uitofp_i64(i64 %a) #0 { %r = uitofp i64 %a to double @@ -50,4 +56,6 @@ define i64 @test_fptosi_i64(double %a) #0 { ret i64 %r } +declare double @llvm.fma.f64(double, double, double) + attributes #0 = { convergent norecurse nounwind "hlsl.export"} diff --git a/llvm/test/CodeGen/DirectX/fma.ll b/llvm/test/CodeGen/DirectX/fma.ll new file mode 100644 index 000000000000..ed0eafcfc328 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/fma.ll @@ -0,0 +1,53 @@ +; RUN: opt -S -scalarizer -dxil-op-lower < %s | FileCheck %s + +target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64" +target triple = "dxil-pc-shadermodel6.7-library" + +; CHECK-LABEL: define double @fma_double( +; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR:]] +define double @fma_double(double %a, double %b, double %c) { + %r = call double @llvm.fma.f64(double %a, double %b, double %c) + ret double %r +} + +; CHECK-LABEL: define <2 x double> @fma_v2f64( +; CHECK: extractelement <2 x double> %a, i64 0 +; CHECK: extractelement <2 x double> %b, i64 0 +; CHECK: extractelement <2 x double> %c, i64 0 +; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]] +; CHECK: extractelement <2 x double> %a, i64 1 +; CHECK: extractelement <2 x double> %b, i64 1 +; CHECK: extractelement <2 x double> %c, i64 1 +; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]] +; CHECK: insertelement <2 x double> poison, double %{{.*}}, i64 0 +; CHECK: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 1 +define <2 x double> @fma_v2f64(<2 x double> %a, <2 x double> %b, + <2 x double> %c) { + %r = call <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> %b, + <2 x double> %c) + ret <2 x double> %r +} + +; CHECK-LABEL: define <16 x double> @fma_v16f64( +; CHECK: extractelement <16 x double> %a, i64 0 +; CHECK: extractelement <16 x double> %b, i64 0 +; CHECK: extractelement <16 x double> %c, i64 0 +; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]] +; CHECK: extractelement <16 x double> %a, i64 15 +; CHECK: extractelement <16 x double> %b, i64 15 +; CHECK: extractelement <16 x double> %c, i64 15 +; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]] +; CHECK: insertelement <16 x double> poison, double %{{.*}}, i64 0 +; CHECK: insertelement <16 x double> %{{.*}}, double %{{.*}}, i64 15 +define <16 x double> @fma_v16f64(<16 x double> %a, <16 x double> %b, + <16 x double> %c) { + %r = call <16 x double> @llvm.fma.v16f64(<16 x double> %a, <16 x double> %b, + <16 x double> %c) + ret <16 x double> %r +} + +declare double @llvm.fma.f64(double, double, double) +declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>) +declare <16 x double> @llvm.fma.v16f64(<16 x double>, <16 x double>, <16 x double>) + +; CHECK: attributes #[[#ATTR]] = { memory(none) }