[HLSL][DirectX][SPIRV] Implement the fma API (#185304)
This PR adds `fma` HLSL intrinsic (with support for matrices) It follows all of the steps from #99117. Closes #99117.
This commit is contained in:
parent
3d5a2552c5
commit
c703ea52be
@ -13285,6 +13285,9 @@ def err_builtin_invalid_arg_type: Error<
|
||||
"%plural{0:|: }3"
|
||||
"%plural{[0,3]:type|:types}1 (was %4)">;
|
||||
|
||||
def err_builtin_requires_double_type: Error<
|
||||
"%ordinal0 argument must be a scalar, vector, or matrix of double type (was %1)">;
|
||||
|
||||
def err_bswapg_invalid_bit_width : Error<
|
||||
"_BitInt type %0 (%1 bits) must be a multiple of 16 bits for byte swapping">;
|
||||
|
||||
|
||||
@ -1235,6 +1235,60 @@ float3 floor(float3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_floor)
|
||||
float4 floor(float4);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// fused multiply-add builtins
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// \fn double fma(double a, double b, double c)
|
||||
/// \brief Returns the double-precision fused multiply-addition of a * b + c.
|
||||
/// \param a The first value in the fused multiply-addition.
|
||||
/// \param b The second value in the fused multiply-addition.
|
||||
/// \param c The third value in the fused multiply-addition.
|
||||
|
||||
// double scalars and vectors
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double fma(double, double, double);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double2 fma(double2, double2, double2);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double3 fma(double3, double3, double3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double4 fma(double4, double4, double4);
|
||||
|
||||
// double matrices
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double1x1 fma(double1x1, double1x1, double1x1);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double1x2 fma(double1x2, double1x2, double1x2);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double1x3 fma(double1x3, double1x3, double1x3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double1x4 fma(double1x4, double1x4, double1x4);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double2x1 fma(double2x1, double2x1, double2x1);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double2x2 fma(double2x2, double2x2, double2x2);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double2x3 fma(double2x3, double2x3, double2x3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double2x4 fma(double2x4, double2x4, double2x4);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double3x1 fma(double3x1, double3x1, double3x1);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double3x2 fma(double3x2, double3x2, double3x2);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double3x3 fma(double3x3, double3x3, double3x3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double3x4 fma(double3x4, double3x4, double3x4);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double4x1 fma(double4x1, double4x1, double4x1);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double4x2 fma(double4x2, double4x2, double4x2);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double4x3 fma(double4x3, double4x3, double4x3);
|
||||
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_fma)
|
||||
double4x4 fma(double4x4, double4x4, double4x4);
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// frac builtins
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@ -2178,9 +2178,10 @@ static bool
|
||||
checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
|
||||
Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
|
||||
int ArgOrdinal) {
|
||||
QualType EltTy = ArgTy;
|
||||
if (auto *VecTy = EltTy->getAs<VectorType>())
|
||||
EltTy = VecTy->getElementType();
|
||||
clang::QualType EltTy =
|
||||
ArgTy->isVectorType() ? ArgTy->getAs<VectorType>()->getElementType()
|
||||
: ArgTy->isMatrixType() ? ArgTy->getAs<MatrixType>()->getElementType()
|
||||
: ArgTy;
|
||||
|
||||
switch (ArgTyRestr) {
|
||||
case Sema::EltwiseBuiltinArgTyRestriction::None:
|
||||
@ -2192,6 +2193,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
|
||||
break;
|
||||
case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
|
||||
if (!EltTy->isRealFloatingType()) {
|
||||
// FIXME: make diagnostic's wording correct for matrices
|
||||
return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
|
||||
<< ArgOrdinal << /* scalar or vector */ 5 << /* no int */ 0
|
||||
<< /* floating-point */ 1 << ArgTy;
|
||||
|
||||
@ -3149,6 +3149,25 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool CheckAnyDoubleRepresentation(Sema *S, SourceLocation Loc,
|
||||
int ArgOrdinal,
|
||||
clang::QualType PassedType) {
|
||||
clang::QualType BaseType =
|
||||
PassedType->isVectorType()
|
||||
? PassedType->castAs<clang::VectorType>()->getElementType()
|
||||
: PassedType->isMatrixType()
|
||||
? PassedType->castAs<clang::MatrixType>()->getElementType()
|
||||
: PassedType;
|
||||
if (!BaseType->isDoubleType()) {
|
||||
// FIXME: adopt standard `err_builtin_invalid_arg_type` instead of using
|
||||
// this custom error.
|
||||
return S->Diag(Loc, diag::err_builtin_requires_double_type)
|
||||
<< ArgOrdinal << PassedType;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
|
||||
unsigned ArgIndex) {
|
||||
auto *Arg = TheCall->getArg(ArgIndex);
|
||||
@ -4120,6 +4139,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
|
||||
TheCall->setType(ArgTyA);
|
||||
break;
|
||||
}
|
||||
case Builtin::BI__builtin_elementwise_fma: {
|
||||
if (SemaRef.checkArgCount(TheCall, 3) ||
|
||||
CheckAllArgsHaveSameType(&SemaRef, TheCall)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (CheckAllArgTypesAreCorrect(&SemaRef, TheCall,
|
||||
CheckAnyDoubleRepresentation))
|
||||
return true;
|
||||
|
||||
ExprResult A = TheCall->getArg(0);
|
||||
QualType ArgTyA = A.get()->getType();
|
||||
// return type is the same as input type
|
||||
TheCall->setType(ArgTyA);
|
||||
break;
|
||||
}
|
||||
case Builtin::BI__builtin_hlsl_transpose: {
|
||||
if (SemaRef.checkArgCount(TheCall, 1))
|
||||
return true;
|
||||
|
||||
138
clang/test/CodeGenHLSL/builtins/fma.hlsl
Normal file
138
clang/test/CodeGenHLSL/builtins/fma.hlsl
Normal file
@ -0,0 +1,138 @@
|
||||
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
|
||||
// RUN: dxil-pc-shadermodel6.3-library %s -emit-llvm \
|
||||
// RUN: -disable-llvm-passes -o - | FileCheck %s
|
||||
// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \
|
||||
// RUN: spirv-unknown-vulkan-compute %s -emit-llvm \
|
||||
// RUN: -disable-llvm-passes -o - | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define {{.*}} double @{{.*}}fma_double{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn double @llvm.fma.f64(double
|
||||
// CHECK: ret double
|
||||
double fma_double(double a, double b, double c) { return fma(a, b, c); }
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
|
||||
// CHECK: ret <2 x double>
|
||||
double2 fma_double2(double2 a, double2 b, double2 c) { return fma(a, b, c); }
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
|
||||
// CHECK: ret <3 x double>
|
||||
double3 fma_double3(double3 a, double3 b, double3 c) { return fma(a, b, c); }
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
|
||||
// CHECK: ret <4 x double>
|
||||
double4 fma_double4(double4 a, double4 b, double4 c) { return fma(a, b, c); }
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <1 x double> @{{.*}}fma_double1x1{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <1 x double> @llvm.fma.v1f64(<1 x double>
|
||||
// CHECK: ret <1 x double>
|
||||
double1x1 fma_double1x1(double1x1 a, double1x1 b, double1x1 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double1x2{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
|
||||
// CHECK: ret <2 x double>
|
||||
double1x2 fma_double1x2(double1x2 a, double1x2 b, double1x2 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double1x3{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
|
||||
// CHECK: ret <3 x double>
|
||||
double1x3 fma_double1x3(double1x3 a, double1x3 b, double1x3 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double1x4{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
|
||||
// CHECK: ret <4 x double>
|
||||
double1x4 fma_double1x4(double1x4 a, double1x4 b, double1x4 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <2 x double> @{{.*}}fma_double2x1{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <2 x double> @llvm.fma.v2f64(<2 x double>
|
||||
// CHECK: ret <2 x double>
|
||||
double2x1 fma_double2x1(double2x1 a, double2x1 b, double2x1 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double2x2{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
|
||||
// CHECK: ret <4 x double>
|
||||
double2x2 fma_double2x2(double2x2 a, double2x2 b, double2x2 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double2x3{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
|
||||
// CHECK: ret <6 x double>
|
||||
double2x3 fma_double2x3(double2x3 a, double2x3 b, double2x3 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double2x4{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
|
||||
// CHECK: ret <8 x double>
|
||||
double2x4 fma_double2x4(double2x4 a, double2x4 b, double2x4 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <3 x double> @{{.*}}fma_double3x1{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <3 x double> @llvm.fma.v3f64(<3 x double>
|
||||
// CHECK: ret <3 x double>
|
||||
double3x1 fma_double3x1(double3x1 a, double3x1 b, double3x1 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <6 x double> @{{.*}}fma_double3x2{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <6 x double> @llvm.fma.v6f64(<6 x double>
|
||||
// CHECK: ret <6 x double>
|
||||
double3x2 fma_double3x2(double3x2 a, double3x2 b, double3x2 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <9 x double> @{{.*}}fma_double3x3{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <9 x double> @llvm.fma.v9f64(<9 x double>
|
||||
// CHECK: ret <9 x double>
|
||||
double3x3 fma_double3x3(double3x3 a, double3x3 b, double3x3 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double3x4{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
|
||||
// CHECK: ret <12 x double>
|
||||
double3x4 fma_double3x4(double3x4 a, double3x4 b, double3x4 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <4 x double> @{{.*}}fma_double4x1{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <4 x double> @llvm.fma.v4f64(<4 x double>
|
||||
// CHECK: ret <4 x double>
|
||||
double4x1 fma_double4x1(double4x1 a, double4x1 b, double4x1 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <8 x double> @{{.*}}fma_double4x2{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <8 x double> @llvm.fma.v8f64(<8 x double>
|
||||
// CHECK: ret <8 x double>
|
||||
double4x2 fma_double4x2(double4x2 a, double4x2 b, double4x2 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <12 x double> @{{.*}}fma_double4x3{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <12 x double> @llvm.fma.v12f64(<12 x double>
|
||||
// CHECK: ret <12 x double>
|
||||
double4x3 fma_double4x3(double4x3 a, double4x3 b, double4x3 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define {{.*}} <16 x double> @{{.*}}fma_double4x4{{.*}}(
|
||||
// CHECK: call reassoc nnan ninf nsz arcp afn <16 x double> @llvm.fma.v16f64(<16 x double>
|
||||
// CHECK: ret <16 x double>
|
||||
double4x4 fma_double4x4(double4x4 a, double4x4 b, double4x4 c) {
|
||||
return fma(a, b, c);
|
||||
}
|
||||
113
clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl
Normal file
113
clang/test/SemaHLSL/BuiltIns/fma-errors.hlsl
Normal file
@ -0,0 +1,113 @@
|
||||
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
|
||||
// RUN: -triple dxil-pc-shadermodel6.6-library %s \
|
||||
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
|
||||
// RUN: -verify-ignore-unexpected=note
|
||||
// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -x hlsl \
|
||||
// RUN: -triple spirv-unknown-vulkan-compute %s \
|
||||
// RUN: -emit-llvm-only -disable-llvm-passes -verify \
|
||||
// RUN: -verify-ignore-unexpected=note
|
||||
|
||||
float bad_float(float a, float b, float c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float')}}
|
||||
}
|
||||
|
||||
float2 bad_float2(float2 a, float2 b, float2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2' (aka 'vector<float, 2>'))}}
|
||||
}
|
||||
|
||||
float2x2 bad_float2x2(float2x2 a, float2x2 b, float2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'float2x2' (aka 'matrix<float, 2, 2>'))}}
|
||||
}
|
||||
|
||||
half bad_half(half a, half b, half c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half')}}
|
||||
}
|
||||
|
||||
half2 bad_half2(half2 a, half2 b, half2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2' (aka 'vector<half, 2>'))}}
|
||||
}
|
||||
|
||||
half2x2 bad_half2x2(half2x2 a, half2x2 b, half2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar, vector, or matrix of double type (was 'half2x2' (aka 'matrix<half, 2, 2>'))}}
|
||||
}
|
||||
|
||||
double mixed_bad_second(double a, float b, double c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('double' vs 'float')}}
|
||||
}
|
||||
|
||||
double mixed_bad_third(double a, double b, half c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('double' vs 'half')}}
|
||||
}
|
||||
|
||||
double2 mixed_bad_second_vec(double2 a, float2 b, double2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
|
||||
}
|
||||
|
||||
double2 mixed_bad_third_vec(double2 a, double2 b, float2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('vector<double, [...]>' vs 'vector<float, [...]>')}}
|
||||
}
|
||||
|
||||
double2x2 mixed_bad_second_mat(double2x2 a, float2x2 b, double2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<float, [2 * ...]>')}}
|
||||
}
|
||||
|
||||
double2x2 mixed_bad_third_mat(double2x2 a, double2x2 b, half2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('matrix<double, [2 * ...]>' vs 'matrix<half, [2 * ...]>')}}
|
||||
}
|
||||
|
||||
double shape_mismatch_second(double a, double2 b, double c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{call to 'fma' is ambiguous}}
|
||||
}
|
||||
|
||||
double2 shape_mismatch_third(double2 a, double2 b, double c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{call to 'fma' is ambiguous}}
|
||||
}
|
||||
|
||||
double2x2 shape_mismatch_scalar_mat(double2x2 a, double b, double2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{call to 'fma' is ambiguous}}
|
||||
}
|
||||
|
||||
double2x2 shape_mismatch_vec_mat(double2x2 a, double2 b, double2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{arguments are of different types ('double2x2' (aka 'matrix<double, 2, 2>') vs 'double2' (aka 'vector<double, 2>'))}}
|
||||
}
|
||||
|
||||
int bad_int(int a, int b, int c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int')}}
|
||||
}
|
||||
|
||||
int2 bad_int2(int2 a, int2 b, int2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'int2' (aka 'vector<int, 2>'))}}
|
||||
}
|
||||
|
||||
bool bad_bool(bool a, bool b, bool c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool')}}
|
||||
}
|
||||
|
||||
bool2 bad_bool2(bool2 a, bool2 b, bool2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2' (aka 'vector<bool, 2>'))}}
|
||||
}
|
||||
|
||||
bool2x2 bad_bool2x2(bool2x2 a, bool2x2 b, bool2x2 c) {
|
||||
return fma(a, b, c);
|
||||
// expected-error@-1 {{1st argument must be a scalar or vector of floating-point types (was 'bool2x2' (aka 'matrix<bool, 2, 2>'))}}
|
||||
}
|
||||
@ -790,6 +790,16 @@ def FMad : DXILOp<46, tertiary> {
|
||||
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
|
||||
}
|
||||
|
||||
def Fma : DXILOp<47, tertiary> {
|
||||
let Doc = "Double-precision fused multiply-add. fma(a,b,c) = a * b + c.";
|
||||
let intrinsics = [IntrinSelect<int_fma>];
|
||||
let arguments = [OverloadTy, OverloadTy, OverloadTy];
|
||||
let result = OverloadTy;
|
||||
let overloads = [Overloads<DXIL1_0, [DoubleTy]>];
|
||||
let stages = [Stages<DXIL1_0, [all_stages]>];
|
||||
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
|
||||
}
|
||||
|
||||
def IMad : DXILOp<48, tertiary> {
|
||||
let Doc = "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m "
|
||||
"* a + b.";
|
||||
|
||||
@ -113,6 +113,15 @@ static bool checkWaveOps(Intrinsic::ID IID) {
|
||||
}
|
||||
}
|
||||
|
||||
static bool checkDoubleExtensionOps(Intrinsic::ID IID) {
|
||||
switch (IID) {
|
||||
default:
|
||||
return false;
|
||||
case Intrinsic::fma:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
static bool isOptimizationDisabled(const Module &M) {
|
||||
const StringRef Key = "dx.disable_optimizations";
|
||||
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(M.getModuleFlag(Key)))
|
||||
@ -250,9 +259,8 @@ void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
|
||||
if (FunctionFlags.contains(CF))
|
||||
CSF.merge(FunctionFlags[CF]);
|
||||
|
||||
// TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
|
||||
// DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
|
||||
|
||||
CSF.DX11_1_DoubleExtensions |=
|
||||
checkDoubleExtensionOps(CI->getIntrinsicID());
|
||||
CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID());
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,12 @@ define double @test_fdiv_double(double %a, double %b) #0 {
|
||||
ret double %res
|
||||
}
|
||||
|
||||
; CHECK: ; Function test_fma_double : 0x00000044
|
||||
define double @test_fma_double(double %a, double %b, double %c) #0 {
|
||||
%r = call double @llvm.fma.f64(double %a, double %b, double %c)
|
||||
ret double %r
|
||||
}
|
||||
|
||||
; CHECK: ; Function test_uitofp_i64 : 0x00100044
|
||||
define double @test_uitofp_i64(i64 %a) #0 {
|
||||
%r = uitofp i64 %a to double
|
||||
@ -50,4 +56,6 @@ define i64 @test_fptosi_i64(double %a) #0 {
|
||||
ret i64 %r
|
||||
}
|
||||
|
||||
declare double @llvm.fma.f64(double, double, double)
|
||||
|
||||
attributes #0 = { convergent norecurse nounwind "hlsl.export"}
|
||||
|
||||
53
llvm/test/CodeGen/DirectX/fma.ll
Normal file
53
llvm/test/CodeGen/DirectX/fma.ll
Normal file
@ -0,0 +1,53 @@
|
||||
; RUN: opt -S -scalarizer -dxil-op-lower < %s | FileCheck %s
|
||||
|
||||
target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64"
|
||||
target triple = "dxil-pc-shadermodel6.7-library"
|
||||
|
||||
; CHECK-LABEL: define double @fma_double(
|
||||
; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR:]]
|
||||
define double @fma_double(double %a, double %b, double %c) {
|
||||
%r = call double @llvm.fma.f64(double %a, double %b, double %c)
|
||||
ret double %r
|
||||
}
|
||||
|
||||
; CHECK-LABEL: define <2 x double> @fma_v2f64(
|
||||
; CHECK: extractelement <2 x double> %a, i64 0
|
||||
; CHECK: extractelement <2 x double> %b, i64 0
|
||||
; CHECK: extractelement <2 x double> %c, i64 0
|
||||
; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]]
|
||||
; CHECK: extractelement <2 x double> %a, i64 1
|
||||
; CHECK: extractelement <2 x double> %b, i64 1
|
||||
; CHECK: extractelement <2 x double> %c, i64 1
|
||||
; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]]
|
||||
; CHECK: insertelement <2 x double> poison, double %{{.*}}, i64 0
|
||||
; CHECK: insertelement <2 x double> %{{.*}}, double %{{.*}}, i64 1
|
||||
define <2 x double> @fma_v2f64(<2 x double> %a, <2 x double> %b,
|
||||
<2 x double> %c) {
|
||||
%r = call <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> %b,
|
||||
<2 x double> %c)
|
||||
ret <2 x double> %r
|
||||
}
|
||||
|
||||
; CHECK-LABEL: define <16 x double> @fma_v16f64(
|
||||
; CHECK: extractelement <16 x double> %a, i64 0
|
||||
; CHECK: extractelement <16 x double> %b, i64 0
|
||||
; CHECK: extractelement <16 x double> %c, i64 0
|
||||
; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]]
|
||||
; CHECK: extractelement <16 x double> %a, i64 15
|
||||
; CHECK: extractelement <16 x double> %b, i64 15
|
||||
; CHECK: extractelement <16 x double> %c, i64 15
|
||||
; CHECK: call double @dx.op.tertiary.f64(i32 47, double %{{.*}}, double %{{.*}}, double %{{.*}}) #[[#ATTR]]
|
||||
; CHECK: insertelement <16 x double> poison, double %{{.*}}, i64 0
|
||||
; CHECK: insertelement <16 x double> %{{.*}}, double %{{.*}}, i64 15
|
||||
define <16 x double> @fma_v16f64(<16 x double> %a, <16 x double> %b,
|
||||
<16 x double> %c) {
|
||||
%r = call <16 x double> @llvm.fma.v16f64(<16 x double> %a, <16 x double> %b,
|
||||
<16 x double> %c)
|
||||
ret <16 x double> %r
|
||||
}
|
||||
|
||||
declare double @llvm.fma.f64(double, double, double)
|
||||
declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
|
||||
declare <16 x double> @llvm.fma.v16f64(<16 x double>, <16 x double>, <16 x double>)
|
||||
|
||||
; CHECK: attributes #[[#ATTR]] = { memory(none) }
|
||||
Loading…
x
Reference in New Issue
Block a user