[HLSL][Matrix] Make HLSLElementwiseCast respect matrix memory layout (#184429)

Fixes #184379

Changes the implementation of HLSLElementwiseCast to respect matrix
memory layout.
The new implementation reads from the `LoadList` array in row-major
order as opposed to column-major in the old implementation, which makes
more sense because `LoadList` is always interpreted in row-major order
when read as a matrix.
The writes to the allocation `V` for the destination matrix now respects
the default matrix memory layout.

Assisted-by: claude-opus-4.6
This commit is contained in:
Deric C. 2026-03-09 11:10:38 -07:00 committed by GitHub
parent d3a22eaab8
commit f1a2fd2abb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 57 additions and 35 deletions

View File

@ -2539,7 +2539,7 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
QualType DestTy, SourceLocation Loc) {
SmallVector<LValue, 16> LoadList;
CGF.FlattenAccessAndTypeLValue(SrcVal, LoadList);
// Dest is either a vector or a builtin?
// Dest is either a vector, constant matrix, or a builtin
// if its a vector create a temp alloca to store into and return that
if (auto *VecTy = DestTy->getAs<VectorType>()) {
assert(LoadList.size() >= VecTy->getNumElements() &&
@ -2564,20 +2564,26 @@ static Value *EmitHLSLElementwiseCast(CodeGenFunction &CGF, LValue SrcVal,
"Flattened type on RHS must have the same number or more elements "
"than vector on LHS.");
bool IsRowMajor = CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
LangOptions::MatrixMemoryLayout::MatrixRowMajor;
llvm::Value *V = CGF.Builder.CreateLoad(
CGF.CreateIRTempWithoutCast(DestTy, "flatcast.tmp"));
// V is an allocated temporary to build the truncated matrix into.
for (unsigned I = 0, E = MatTy->getNumElementsFlattened(); I < E; I++) {
unsigned ColMajorIndex =
(I % MatTy->getNumRows()) * MatTy->getNumColumns() +
(I / MatTy->getNumRows());
RValue RVal = CGF.EmitLoadOfLValue(LoadList[ColMajorIndex], Loc);
assert(RVal.isScalar() &&
"All flattened source values should be scalars.");
llvm::Value *Cast = CGF.EmitScalarConversion(
RVal.getScalarVal(), LoadList[ColMajorIndex].getType(),
MatTy->getElementType(), Loc);
V = CGF.Builder.CreateInsertElement(V, Cast, I);
// V is an allocated temporary for constructing the matrix.
for (unsigned Row = 0, RE = MatTy->getNumRows(); Row < RE; Row++) {
for (unsigned Col = 0, CE = MatTy->getNumColumns(); Col < CE; Col++) {
// When interpreted as a matrix, \p LoadList is *always* row-major order
// regardless of the default matrix memory layout.
unsigned LoadIdx = MatTy->getRowMajorFlattenedIndex(Row, Col);
RValue RVal = CGF.EmitLoadOfLValue(LoadList[LoadIdx], Loc);
assert(RVal.isScalar() &&
"All flattened source values should be scalars.");
llvm::Value *Cast = CGF.EmitScalarConversion(
RVal.getScalarVal(), LoadList[LoadIdx].getType(),
MatTy->getElementType(), Loc);
unsigned MatrixIdx = MatTy->getFlattenedIndex(Row, Col, IsRowMajor);
V = CGF.Builder.CreateInsertElement(V, Cast, MatrixIdx);
}
}
return V;
}

View File

@ -1,12 +1,14 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -o - %s | FileCheck %s
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -fmatrix-memory-layout=row-major -o - %s | FileCheck %s --check-prefixes=CHECK,ROW-CHECK
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -fnative-half-type -fnative-int16-type -fmatrix-memory-layout=column-major -o - %s | FileCheck %s --check-prefixes=CHECK,COL-CHECK
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast0u11matrix_typeILm3ELm2EfE(
// CHECK-SAME: <6 x float> noundef nofpclass(nan inf) [[F32:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [2 x <3 x float>], align 4
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// ROW-CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [3 x <2 x float>], align 4
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
// COL-CHECK-NEXT: [[F32_ADDR:%.*]] = alloca [2 x <3 x float>], align 4
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// CHECK-NEXT: store <6 x float> [[F32]], ptr [[F32_ADDR]], align 4
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x float>, ptr [[F32_ADDR]], align 4
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x float> [[TMP0]] to <6 x i32>
@ -22,8 +24,10 @@ int3x2 elementwise_type_cast0(float3x2 f32) {
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast1u11matrix_typeILm3ELm2EsE(
// CHECK-SAME: <6 x i16> noundef [[I16_32:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [2 x <3 x i16>], align 2
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// ROW-CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [3 x <2 x i16>], align 2
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
// COL-CHECK-NEXT: [[I16_32_ADDR:%.*]] = alloca [2 x <3 x i16>], align 2
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// CHECK-NEXT: store <6 x i16> [[I16_32]], ptr [[I16_32_ADDR]], align 2
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x i16>, ptr [[I16_32_ADDR]], align 2
// CHECK-NEXT: [[CONV:%.*]] = sext <6 x i16> [[TMP0]] to <6 x i32>
@ -39,8 +43,10 @@ int3x2 elementwise_type_cast1(int16_t3x2 i16_32) {
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast2u11matrix_typeILm3ELm2ElE(
// CHECK-SAME: <6 x i64> noundef [[I64_32:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [2 x <3 x i64>], align 8
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// ROW-CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [3 x <2 x i64>], align 8
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
// COL-CHECK-NEXT: [[I64_32_ADDR:%.*]] = alloca [2 x <3 x i64>], align 8
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// CHECK-NEXT: store <6 x i64> [[I64_32]], ptr [[I64_32_ADDR]], align 8
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x i64>, ptr [[I64_32_ADDR]], align 8
// CHECK-NEXT: [[CONV:%.*]] = trunc <6 x i64> [[TMP0]] to <6 x i32>
@ -56,8 +62,10 @@ int3x2 elementwise_type_cast2(int64_t3x2 i64_32) {
// CHECK-LABEL: define hidden noundef <6 x i16> @_Z22elementwise_type_cast3u11matrix_typeILm2ELm3EDhE(
// CHECK-SAME: <6 x half> noundef nofpclass(nan inf) [[H23:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [3 x <2 x half>], align 2
// CHECK-NEXT: [[I23:%.*]] = alloca [3 x <2 x i16>], align 2
// ROW-CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [2 x <3 x half>], align 2
// ROW-CHECK-NEXT: [[I23:%.*]] = alloca [2 x <3 x i16>], align 2
// COL-CHECK-NEXT: [[H23_ADDR:%.*]] = alloca [3 x <2 x half>], align 2
// COL-CHECK-NEXT: [[I23:%.*]] = alloca [3 x <2 x i16>], align 2
// CHECK-NEXT: store <6 x half> [[H23]], ptr [[H23_ADDR]], align 2
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x half>, ptr [[H23_ADDR]], align 2
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x half> [[TMP0]] to <6 x i16>
@ -73,8 +81,10 @@ int16_t2x3 elementwise_type_cast3(half2x3 h23) {
// CHECK-LABEL: define hidden noundef <6 x i32> @_Z22elementwise_type_cast4u11matrix_typeILm3ELm2EdE(
// CHECK-SAME: <6 x double> noundef nofpclass(nan inf) [[D32:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [2 x <3 x double>], align 8
// CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// ROW-CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [3 x <2 x double>], align 8
// ROW-CHECK-NEXT: [[I32:%.*]] = alloca [3 x <2 x i32>], align 4
// COL-CHECK-NEXT: [[D32_ADDR:%.*]] = alloca [2 x <3 x double>], align 8
// COL-CHECK-NEXT: [[I32:%.*]] = alloca [2 x <3 x i32>], align 4
// CHECK-NEXT: store <6 x double> [[D32]], ptr [[D32_ADDR]], align 8
// CHECK-NEXT: [[TMP0:%.*]] = load <6 x double>, ptr [[D32_ADDR]], align 8
// CHECK-NEXT: [[CONV:%.*]] = fptosi <6 x double> [[TMP0]] to <6 x i32>
@ -91,7 +101,8 @@ int3x2 elementwise_type_cast4(double3x2 d32) {
// CHECK-SAME: ) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[A:%.*]] = alloca [2 x [1 x i32]], align 4
// CHECK-NEXT: [[B:%.*]] = alloca [1 x <2 x i32>], align 4
// ROW-CHECK-NEXT: [[B:%.*]] = alloca [2 x <1 x i32>], align 4
// COL-CHECK-NEXT: [[B:%.*]] = alloca [1 x <2 x i32>], align 4
// CHECK-NEXT: [[AGG_TEMP:%.*]] = alloca [2 x [1 x i32]], align 4
// CHECK-NEXT: [[FLATCAST_TMP:%.*]] = alloca <2 x i32>, align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 4 [[A]], ptr align 4 @__const._Z5call2v.A, i32 8, i1 false)
@ -120,7 +131,8 @@ struct S {
// CHECK-SAME: ) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
// CHECK-NEXT: [[A:%.*]] = alloca [1 x <2 x i32>], align 4
// ROW-CHECK-NEXT: [[A:%.*]] = alloca [2 x <1 x i32>], align 4
// COL-CHECK-NEXT: [[A:%.*]] = alloca [1 x <2 x i32>], align 4
// CHECK-NEXT: [[AGG_TEMP:%.*]] = alloca [[STRUCT_S]], align 1
// CHECK-NEXT: [[FLATCAST_TMP:%.*]] = alloca <2 x i32>, align 4
// CHECK-NEXT: call void @llvm.memcpy.p0.p0.i32(ptr align 1 [[S]], ptr align 1 @__const._Z5call3v.s, i32 8, i1 false)
@ -168,14 +180,16 @@ struct Derived : BFields {
// CHECK-NEXT: [[TMP1:%.*]] = load double, ptr [[GEP1]], align 8
// CHECK-NEXT: [[CONV:%.*]] = fptosi double [[TMP1]] to i32
// CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x i32> [[TMP0]], i32 [[CONV]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[GEP2]], align 4
// CHECK-NEXT: [[CONV4:%.*]] = fptosi float [[TMP3]] to i32
// CHECK-NEXT: [[TMP4:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[CONV4]], i64 1
// CHECK-NEXT: [[BF_LOAD:%.*]] = load i24, ptr [[E]], align 1
// CHECK-NEXT: [[BF_SHL:%.*]] = shl i24 [[BF_LOAD]], 9
// CHECK-NEXT: [[BF_ASHR:%.*]] = ashr i24 [[BF_SHL]], 9
// CHECK-NEXT: [[BF_CAST:%.*]] = sext i24 [[BF_ASHR]] to i32
// CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP4]], i32 [[BF_CAST]], i64 2
// ROW-CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[BF_CAST]], i64 1
// COL-CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x i32> [[TMP2]], i32 [[BF_CAST]], i64 2
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[GEP2]], align 4
// CHECK-NEXT: [[CONV4:%.*]] = fptosi float [[TMP4]] to i32
// ROW-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[CONV4]], i64 2
// COL-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x i32> [[TMP3]], i32 [[CONV4]], i64 1
// CHECK-NEXT: [[TMP6:%.*]] = load i32, ptr [[GEP3]], align 4
// CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x i32> [[TMP5]], i32 [[TMP6]], i64 3
// CHECK-NEXT: store <4 x i32> [[TMP7]], ptr [[A]], align 4
@ -201,11 +215,13 @@ void call4(Derived D) {
// CHECK-NEXT: [[VECEXT:%.*]] = extractelement <4 x float> [[TMP2]], i32 0
// CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x float> [[TMP1]], float [[VECEXT]], i64 0
// CHECK-NEXT: [[TMP4:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x float> [[TMP4]], i32 2
// CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 1
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x float> [[TMP4]], i32 1
// ROW-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 1
// COL-CHECK-NEXT: [[TMP5:%.*]] = insertelement <4 x float> [[TMP3]], float [[VECEXT1]], i64 2
// CHECK-NEXT: [[TMP6:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
// CHECK-NEXT: [[VECEXT2:%.*]] = extractelement <4 x float> [[TMP6]], i32 1
// CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 2
// CHECK-NEXT: [[VECEXT2:%.*]] = extractelement <4 x float> [[TMP6]], i32 2
// ROW-CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 2
// COL-CHECK-NEXT: [[TMP7:%.*]] = insertelement <4 x float> [[TMP5]], float [[VECEXT2]], i64 1
// CHECK-NEXT: [[TMP8:%.*]] = load <4 x float>, ptr [[VECTOR_GEP]], align 16
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <4 x float> [[TMP8]], i32 3
// CHECK-NEXT: [[TMP9:%.*]] = insertelement <4 x float> [[TMP7]], float [[VECEXT3]], i64 3