[HLSL][Matrix] Make HLSLElementwiseCast respect matrix memory layout (#184429)
Fixes #184379 Changes the implementation of HLSLElementwiseCast to respect matrix memory layout. The new implementation reads from the `LoadList` array in row-major order as opposed to column-major in the old implementation, which makes more sense because `LoadList` is always interpreted in row-major order when read as a matrix. The writes to the allocation `V` for the destination matrix now respects the default matrix memory layout. Assisted-by: claude-opus-4.6
This commit is contained in:
parent
d3a22eaab8
commit
f1a2fd2abb
@ -2539,7 +2539,7 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
|
||||
QualType DestTy, SourceLocation Loc) {
|
||||
SmallVector<LValue, 16> LoadList;
|
||||
CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
|
||||
// Dest is either a vector or a builtin?
|
||||
// Dest is either a vector, constant matrix, or a builtin
|
||||
// if its a vector create a temp alloca to store into and return that
|
||||
if (auto *VecTy = DestTy->getAs<VectorType>()) {
|
||||
assert(LoadList.size() >= VecTy->getNumElements() &&
|
||||
@ -2564,20 +2564,26 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
|
||||
"Flattened type on RHS must have the same number or more elements "
|
||||
"than vector on LHS.");
|
||||
|
||||
bool IsRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
|
||||
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
|
||||
|
||||
llvm::Value *V = CGF.Builder.CreateLoad(
|
||||
CGF.CreateIRTempWithoutCast(DestTy, "flatcast.tmp"));
|
||||
// V is an allocated temporary to build the truncated matrix into.
|
||||
for (unsigned I = 0, E = MatTy->getNumElementsFlattened(); I < E; I++) {
|
||||
unsigned ColMajorIndex =
|
||||
(I % MatTy->getNumRows()) * MatTy->getNumColumns() +
|
||||
(I / MatTy->getNumRows());
|
||||
RValue RVal = CGF.EmitLoadOfLValue(LoadList[ColMajorIndex], Loc);
|
||||
assert(RVal.isScalar() &&
|
||||
"All flattened source values should be scalars.");
|
||||
llvm::Value *Cast = CGF.EmitScalarConversion(
|
||||
RVal.getScalarVal(), LoadList[ColMajorIndex].getType(),
|
||||
MatTy->getElementType(), Loc);
|
||||
V = CGF.Builder.CreateInsertElement(V, Cast, I);
|
||||
// V is an allocated temporary for constructing the matrix.
|
||||
for (unsigned Row = 0, RE = MatTy->getNumRows(); Row < RE; Row++) {
|
||||
for (unsigned Col = 0, CE = MatTy->getNumColumns(); Col < CE; Col++) {
|
||||
// When interpreted as a matrix, \p LoadList is *always* row-major order
|
||||
// regardless of the default matrix memory layout.
|
||||
unsigned LoadIdx = MatTy->getRowMajorFlattenedIndex(Row, Col);
|
||||
RValue RVal = CGF.EmitLoadOfLValue(LoadList[LoadIdx], Loc);
|
||||
assert(RVal.isScalar() &&
|
||||
"All flattened source values should be scalars.");
|
||||
llvm::Value *Cast = CGF.EmitScalarConversion(
|
||||
RVal.getScalarVal(), LoadList[LoadIdx].getType(),
|
||||
MatTy->getElementType(), Loc);
|
||||
unsigned MatrixIdx = MatTy->getFlattenedIndex(Row, Col, IsRowMajor);
|
||||
V = CGF.Builder.CreateInsertElement(V, Cast, MatrixIdx);
|
||||
}
|
||||
}
|
||||
return V;
|
||||
}
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -o - %s | FileCheck %s
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -fmatrix-memory-layout=row-major -o - %s | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
|
||||
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -fmatrix-memory-layout=column-major -o - %s | FileCheck %s --check-prefixes=CHECK,COL-CHECK
|
||||
|
||||
|
||||
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast0u11matrix_typeILm3ELm2EfE(
|
||||
// CHECK-SAME: <6 x float> noundef nofpclass(nan inf) [[F32:%.*]]) #[[ATTR0:[0-9]+]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [2 x <3 x float>], align 4
|
||||
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [3 x <2 x float>], align 4
|
||||
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [2 x <3 x float>], align 4
|
||||
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// CHECK-NEXT: store <6 x float> [[F32]], ptr [[F32_ADDR]], align 4
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x float>, ptr [[F32_ADDR]], align 4
|
||||
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x float> [[TMP0]] to <6 x i32>
|
||||
@ -22,8 +24,10 @@ int3x2 elementwise_type_cast0(float3x2 f32) {
|
||||
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast1u11matrix_typeILm3ELm2EsE(
|
||||
// CHECK-SAME: <6 x i16> noundef [[I16_32:%.*]]) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [2 x <3 x i16>], align 2
|
||||
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [3 x <2 x i16>], align 2
|
||||
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [2 x <3 x i16>], align 2
|
||||
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// CHECK-NEXT: store <6 x i16> [[I16_32]], ptr [[I16_32_ADDR]], align 2
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x i16>, ptr [[I16_32_ADDR]], align 2
|
||||
// CHECK-NEXT: [[CONV:%.*]] = sext <6 x i16> [[TMP0]] to <6 x i32>
|
||||
@ -39,8 +43,10 @@ int3x2 elementwise_type_cast1(int16_t3x2 i16_32) {
|
||||
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast2u11matrix_typeILm3ELm2ElE(
|
||||
// CHECK-SAME: <6 x i64> noundef [[I64_32:%.*]]) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [2 x <3 x i64>], align 8
|
||||
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [3 x <2 x i64>], align 8
|
||||
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [2 x <3 x i64>], align 8
|
||||
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// CHECK-NEXT: store <6 x i64> [[I64_32]], ptr [[I64_32_ADDR]], align 8
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x i64>, ptr [[I64_32_ADDR]], align 8
|
||||
// CHECK-NEXT: [[CONV:%.*]] = trunc <6 x i64> [[TMP0]] to <6 x i32>
|
||||
@ -56,8 +62,10 @@ int3x2 elementwise_type_cast2(int64_t3x2 i64_32) {
|
||||
// CHECK-LABEL: define hidden noundef <6 x i16> @_Z22elementwise_type_cast3u11matrix_typeILm2ELm3EDhE(
|
||||
// CHECK-SAME: <6 x half> noundef nofpclass(nan inf) [[H23:%.*]]) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [3 x <2 x half>], align 2
|
||||
// CHECK-NEXT: [[I23:%.*]] = alloca [3 x <2 x i16>], align 2
|
||||
// ROW-CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [2 x <3 x half>], align 2
|
||||
// ROW-CHECK-NEXT: [[I23:%.*]] = alloca [2 x <3 x i16>], align 2
|
||||
// COL-CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [3 x <2 x half>], align 2
|
||||
// COL-CHECK-NEXT: [[I23:%.*]] = alloca [3 x <2 x i16>], align 2
|
||||
// CHECK-NEXT: store <6 x half> [[H23]], ptr [[H23_ADDR]], align 2
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x half>, ptr [[H23_ADDR]], align 2
|
||||
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x half> [[TMP0]] to <6 x i16>
|
||||
@ -73,8 +81,10 @@ int16_t2x3 elementwise_type_cast3(half2x3 h23) {
|
||||
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast4u11matrix_typeILm3ELm2EdE(
|
||||
// CHECK-SAME: <6 x double> noundef nofpclass(nan inf) [[D32:%.*]]) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [2 x <3 x double>], align 8
|
||||
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [3 x <2 x double>], align 8
|
||||
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [2 x <3 x double>], align 8
|
||||
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
|
||||
// CHECK-NEXT: store <6 x double> [[D32]], ptr [[D32_ADDR]], align 8
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x double>, ptr [[D32_ADDR]], align 8
|
||||
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x double> [[TMP0]] to <6 x i32>
|
||||
@ -91,7 +101,8 @@ int3x2 elementwise_type_cast4(double3x2 d32) {
|
||||
// CHECK-SAME: ) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[A:%.*]] = alloca [2 x [1 x i32]], align 4
|
||||
// CHECK-NEXT: [[B:%.*]] = alloca [1 x <2 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[B:%.*]] = alloca [2 x <1 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[B:%.*]] = alloca [1 x <2 x i32>], align 4
|
||||
// CHECK-NEXT: [[AGG_TEMP:%.*]] = alloca [2 x [1 x i32]], align 4
|
||||
// CHECK-NEXT: [[FLATCAST_TMP:%.*]] = alloca <2 x i32>, align 4
|
||||
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 @__const._Z5call2v.A, i32 8, i1 false)
|
||||
@ -120,7 +131,8 @@ struct S {
|
||||
// CHECK-SAME: ) #[[ATTR0]] {
|
||||
// CHECK-NEXT: [[ENTRY:.*:]]
|
||||
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
|
||||
// CHECK-NEXT: [[A:%.*]] = alloca [1 x <2 x i32>], align 4
|
||||
// ROW-CHECK-NEXT: [[A:%.*]] = alloca [2 x <1 x i32>], align 4
|
||||
// COL-CHECK-NEXT: [[A:%.*]] = alloca [1 x <2 x i32>], align 4
|
||||
// CHECK-NEXT: [[AGG_TEMP:%.*]] = alloca [[STRUCT_S]], align 1
|
||||
// CHECK-NEXT: [[FLATCAST_TMP:%.*]] = alloca <2 x i32>, align 4
|
||||
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z5call3v.s, i32 8, i1 false)
|
||||
@ -168,14 +180,16 @@ struct Derived : BFields {
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = load double, ptr [[GEP1]], align 8
|
||||
// CHECK-NEXT: [[CONV:%.*]] = fptosi double [[TMP1]] to i32
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP0]], i32 [[CONV]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[GEP2]], align 4
|
||||
// CHECK-NEXT: [[CONV4:%.*]] = fptosi float [[TMP3]] to i32
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[CONV4]], i64 1
|
||||
// CHECK-NEXT: [[BF_LOAD:%.*]] = load i24, ptr [[E]], align 1
|
||||
// CHECK-NEXT: [[BF_SHL:%.*]] = shl i24 [[BF_LOAD]], 9
|
||||
// CHECK-NEXT: [[BF_ASHR:%.*]] = ashr i24 [[BF_SHL]], 9
|
||||
// CHECK-NEXT: [[BF_CAST:%.*]] = sext i24 [[BF_ASHR]] to i32
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP4]], i32 [[BF_CAST]], i64 2
|
||||
// ROW-CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[BF_CAST]], i64 1
|
||||
// COL-CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[BF_CAST]], i64 2
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[GEP2]], align 4
|
||||
// CHECK-NEXT: [[CONV4:%.*]] = fptosi float [[TMP4]] to i32
|
||||
// ROW-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[CONV4]], i64 2
|
||||
// COL-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[CONV4]], i64 1
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr [[GEP3]], align 4
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[TMP6]], i64 3
|
||||
// CHECK-NEXT: store <4 x i32> [[TMP7]], ptr [[A]], align 4
|
||||
@ -201,11 +215,13 @@ void call4(Derived D) {
|
||||
// CHECK-NEXT: [[VECEXT:%.*]] = extractelement <4 x float> [[TMP2]], i32 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP1]], float [[VECEXT]], i64 0
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
|
||||
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x float> [[TMP4]], i32 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 1
|
||||
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x float> [[TMP4]], i32 1
|
||||
// ROW-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 1
|
||||
// COL-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
|
||||
// CHECK-NEXT: [[VECEXT2:%.*]] = extractelement <4 x float> [[TMP6]], i32 1
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 2
|
||||
// CHECK-NEXT: [[VECEXT2:%.*]] = extractelement <4 x float> [[TMP6]], i32 2
|
||||
// ROW-CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 2
|
||||
// COL-CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 1
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
|
||||
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <4 x float> [[TMP8]], i32 3
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x float> [[TMP7]], float [[VECEXT3]], i64 3
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user