[HLSL] Codegen column-major matrix initializer lists without a vector shuffle (#186228)

Fixes #185518

The SPIR-V backend does not handle the lowering of `shufflevector`
instructions on vectors with more than 4 elements.
This PR changes the codegen of matrix init lists to insert each element
directly at its column-major position when the default matrix memory
layout is column-major, instead of building the vector in
linear/row-major order and then applying a vector shuffle.

While an alternative fix could be to change the default depth of
[`canEvaluateShuffled`](https://github.com/llvm/llvm-project/blob/main/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp#L1865-L1866)
to 16 in `InstCombineVectorOps.cpp` to eliminate the vector shuffle for
vectors of up to 16 elements in size (to handle 4x4 matrices), this
change would have broader impacts than just HLSL, which does not seem
necessary for the scope of this issue (which regards only matrix
initializer list codegen).

Another alternative fix would be to extend the `shufflevector` lowering
in the SPIR-V backend to support vectors of more than 4 elements.
However, again, this goes beyond the scope of matrix initializer list
codegen, which is so far the only case where a shuffle of a vector with
more than 4 elements has appeared.

Assisted-by: claude-opus-4.6
This commit is contained in:
Deric C. 2026-03-13 09:10:26 -07:00 committed by GitHub
parent 6ab7160073
commit fc4fed4d98
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 46 additions and 51 deletions

View File

@ -2322,6 +2322,14 @@ Value *ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
unsigned ResElts = cast<llvm::FixedVectorType>(VType)->getNumElements();
// For column-major matrix types, we insert elements directly at their
// column-major positions rather than inserting sequentially and shuffling.
const ConstantMatrixType *ColMajorMT = nullptr;
if (const auto *MT = E->getType()->getAs<ConstantMatrixType>();
MT && CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
LangOptions::MatrixMemoryLayout::MatrixColMajor)
ColMajorMT = MT;
// Loop over initializers collecting the Value for each, and remembering
// whether the source was swizzle (ExtVectorElementExpr). This will allow
// us to fold the shuffle for the swizzle into the shuffle for the vector
@ -2376,7 +2384,11 @@ Value *ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
}
}
}
V = Builder.CreateInsertElement(V, Init, Builder.getInt32(CurIdx),
unsigned InsertIdx =
ColMajorMT
? ColMajorMT->mapRowMajorToColumnMajorFlattenedIndex(CurIdx)
: CurIdx;
V = Builder.CreateInsertElement(V, Init, Builder.getInt32(InsertIdx),
"vecinit");
VIsPoisonShuffle = false;
++CurIdx;
@ -2446,24 +2458,14 @@ Value *ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
// Emit remaining default initializers
for (/* Do not initialize i*/; CurIdx < ResElts; ++CurIdx) {
Value *Idx = Builder.getInt32(CurIdx);
unsigned InsertIdx =
ColMajorMT ? ColMajorMT->mapRowMajorToColumnMajorFlattenedIndex(CurIdx)
: CurIdx;
Value *Idx = Builder.getInt32(InsertIdx);
llvm::Value *Init = llvm::Constant::getNullValue(EltTy);
V = Builder.CreateInsertElement(V, Init, Idx, "vecinit");
}
// Matrix initializer lists are in row-major order but the memory layout for
// codegen is determined by the -fmatrix-memory-layout flag (default:
// column-major). When the memory layout is column-major, we need to shuffle
// the elements from row-major to column-major order.
if (const auto *MT = E->getType()->getAs<ConstantMatrixType>();
MT && CGF.getLangOpts().getDefaultMatrixMemoryLayout() ==
LangOptions::MatrixMemoryLayout::MatrixColMajor) {
SmallVector<int, 16> Mask;
for (unsigned I = 0, N = MT->getNumElementsFlattened(); I < N; ++I)
Mask.push_back(MT->mapColumnMajorToRowMajorFlattenedIndex(I));
V = Builder.CreateShuffleVector(V, Mask, "matrix.rowmajor2colmajor");
}
return V;
}

View File

@ -33,17 +33,16 @@ RWStructuredBuffer<float> In;
// CHECK-NEXT: [[TMP0:%.*]] = load float, ptr [[CALL]], align 4
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <6 x float> poison, float [[TMP0]], i32 0
// CHECK-NEXT: [[TMP1:%.*]] = load float, ptr [[CALL1]], align 4
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT]], float [[TMP1]], i32 1
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT]], float [[TMP1]], i32 3
// CHECK-NEXT: [[TMP2:%.*]] = load float, ptr [[CALL2]], align 4
// CHECK-NEXT: [[VECINIT7:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[TMP2]], i32 2
// CHECK-NEXT: [[VECINIT7:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[TMP2]], i32 1
// CHECK-NEXT: [[TMP3:%.*]] = load float, ptr [[CALL3]], align 4
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT7]], float [[TMP3]], i32 3
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT7]], float [[TMP3]], i32 4
// CHECK-NEXT: [[TMP4:%.*]] = load float, ptr [[CALL4]], align 4
// CHECK-NEXT: [[VECINIT9:%.*]] = insertelement <6 x float> [[VECINIT8]], float [[TMP4]], i32 4
// CHECK-NEXT: [[VECINIT9:%.*]] = insertelement <6 x float> [[VECINIT8]], float [[TMP4]], i32 2
// CHECK-NEXT: [[TMP5:%.*]] = load float, ptr [[CALL5]], align 4
// CHECK-NEXT: [[VECINIT10:%.*]] = insertelement <6 x float> [[VECINIT9]], float [[TMP5]], i32 5
// CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <6 x float> [[VECINIT10]], <6 x float> poison, <6 x i32> <i32 0, i32 2, i32 4, i32 1, i32 3, i32 5>
// CHECK-NEXT: ret <6 x float> [[MATRIX_ROWMAJOR2COLMAJOR]]
// CHECK-NEXT: ret <6 x float> [[VECINIT10]]
//
float3x2 case2() {
// vec[0] = Call
@ -70,21 +69,20 @@ float3x2 case2() {
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <6 x float> poison, float [[VECEXT]], i32 0
// CHECK-NEXT: [[TMP1:%.*]] = load <3 x float>, ptr [[A_ADDR]], align 16
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <3 x float> [[TMP1]], i64 1
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT]], float [[VECEXT1]], i32 1
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT]], float [[VECEXT1]], i32 3
// CHECK-NEXT: [[TMP2:%.*]] = load <3 x float>, ptr [[A_ADDR]], align 16
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <3 x float> [[TMP2]], i64 2
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[VECEXT3]], i32 2
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[VECEXT3]], i32 1
// CHECK-NEXT: [[TMP3:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <3 x float> [[TMP3]], i64 0
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT4]], float [[VECEXT5]], i32 3
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <6 x float> [[VECINIT4]], float [[VECEXT5]], i32 4
// CHECK-NEXT: [[TMP4:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
// CHECK-NEXT: [[VECEXT7:%.*]] = extractelement <3 x float> [[TMP4]], i64 1
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[VECEXT7]], i32 4
// CHECK-NEXT: [[VECINIT8:%.*]] = insertelement <6 x float> [[VECINIT6]], float [[VECEXT7]], i32 2
// CHECK-NEXT: [[TMP5:%.*]] = load <3 x float>, ptr [[B_ADDR]], align 16
// CHECK-NEXT: [[VECEXT9:%.*]] = extractelement <3 x float> [[TMP5]], i64 2
// CHECK-NEXT: [[VECINIT10:%.*]] = insertelement <6 x float> [[VECINIT8]], float [[VECEXT9]], i32 5
// CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <6 x float> [[VECINIT10]], <6 x float> poison, <6 x i32> <i32 0, i32 2, i32 4, i32 1, i32 3, i32 5>
// CHECK-NEXT: ret <6 x float> [[MATRIX_ROWMAJOR2COLMAJOR]]
// CHECK-NEXT: ret <6 x float> [[VECINIT10]]
//
float3x2 case3(float3 a, float3 b) {
// vec[0] = A[0]

View File

@ -35,8 +35,7 @@ export float test_row1_col0() {
return M[1][0];
}
// Verify the shuffle is emitted for non-constant init lists when the memory
// layout is column-major, and not emitted when it is row-major.
// Verify that elements are inserted at the correct positions according to the default matrix memory layout.
export float2x3 test_dynamic(float a, float b, float c,
float d, float e, float f) {
@ -44,17 +43,18 @@ export float2x3 test_dynamic(float a, float b, float c,
// CHECK: [[A:%.*]] = load float, ptr %a.addr
// CHECK: [[VECINIT0:%.*]] = insertelement <6 x float> poison, float [[A]], i32 0
// CHECK: [[B:%.*]] = load float, ptr %b.addr
// CHECK: [[VECINIT1:%.*]] = insertelement <6 x float> [[VECINIT0]], float [[B]], i32 1
// COL-CHECK: [[VECINIT1:%.*]] = insertelement <6 x float> [[VECINIT0]], float [[B]], i32 2
// ROW-CHECK: [[VECINIT1:%.*]] = insertelement <6 x float> [[VECINIT0]], float [[B]], i32 1
// CHECK: [[C:%.*]] = load float, ptr %c.addr
// CHECK: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT1]], float [[C]], i32 2
// COL-CHECK: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT1]], float [[C]], i32 4
// ROW-CHECK: [[VECINIT2:%.*]] = insertelement <6 x float> [[VECINIT1]], float [[C]], i32 2
// CHECK: [[D:%.*]] = load float, ptr %d.addr
// CHECK: [[VECINIT3:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[D]], i32 3
// COL-CHECK: [[VECINIT3:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[D]], i32 1
// ROW-CHECK: [[VECINIT3:%.*]] = insertelement <6 x float> [[VECINIT2]], float [[D]], i32 3
// CHECK: [[E:%.*]] = load float, ptr %e.addr
// CHECK: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT3]], float [[E]], i32 4
// COL-CHECK: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT3]], float [[E]], i32 3
// ROW-CHECK: [[VECINIT4:%.*]] = insertelement <6 x float> [[VECINIT3]], float [[E]], i32 4
// CHECK: [[F:%.*]] = load float, ptr %f.addr
// CHECK: [[VECINIT5:%.*]] = insertelement <6 x float> [[VECINIT4]], float [[F]], i32 5
// COL-CHECK: shufflevector <6 x float> [[VECINIT5]], <6 x float> poison, <6 x i32> <i32 0, i32 3, i32 1, i32 4, i32 2, i32 5>
// ROW-CHECK-NOT: shufflevector
// ROW-CHECK: store <6 x float> [[VECINIT5]], ptr
return (float2x3){a, b, c, d, e, f};
}

View File

@ -41,16 +41,16 @@ float4 fn(float2x2 m) {
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <4 x i32> poison, i32 [[VECEXT]], i32 0
// CHECK-NEXT: [[TMP1:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <4 x i32> [[TMP1]], i64 1
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 1
// COL-CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 2
// ROW-CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 1
// CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
// CHECK-NEXT: [[VECEXT3:%.*]] = extractelement <4 x i32> [[TMP2]], i64 2
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i32> [[VECINIT2]], i32 [[VECEXT3]], i32 2
// COL-CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i32> [[VECINIT2]], i32 [[VECEXT3]], i32 1
// ROW-CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i32> [[VECINIT2]], i32 [[VECEXT3]], i32 2
// CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <4 x i32> [[TMP3]], i64 3
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <4 x i32> [[VECINIT4]], i32 [[VECEXT5]], i32 3
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <4 x i32> [[VECINIT6]], <4 x i32> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
// COL-CHECK-NEXT: store <4 x i32> [[MATRIX_ROWMAJOR2COLMAJOR]], ptr [[M]], align 4
// ROW-CHECK-NEXT: store <4 x i32> [[VECINIT6]], ptr [[M]], align 4
// CHECK-NEXT: store <4 x i32> [[VECINIT6]], ptr [[M]], align 4
// CHECK-NEXT: [[TMP4:%.*]] = load <4 x i32>, ptr [[M]], align 4
// CHECK-NEXT: ret <4 x i32> [[TMP4]]
//
@ -70,9 +70,7 @@ int2x2 fn(int4 v) {
// CHECK-NEXT: [[TMP1:%.*]] = load <2 x i32>, ptr [[V_ADDR]], align 8
// CHECK-NEXT: [[VECEXT1:%.*]] = extractelement <2 x i32> [[TMP1]], i64 1
// CHECK-NEXT: [[VECINIT2:%.*]] = insertelement <2 x i32> [[VECINIT]], i32 [[VECEXT1]], i32 1
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <2 x i32> [[VECINIT2]], <2 x i32> poison, <2 x i32> <i32 0, i32 1>
// COL-CHECK-NEXT: ret <2 x i32> [[MATRIX_ROWMAJOR2COLMAJOR]]
// ROW-CHECK-NEXT: ret <2 x i32> [[VECINIT2]]
// CHECK-NEXT: ret <2 x i32> [[VECINIT2]]
//
int1x2 fn1(int2 v) {
return v;
@ -96,9 +94,7 @@ int1x2 fn1(int2 v) {
// CHECK-NEXT: [[LOADEDV4:%.*]] = trunc <3 x i32> [[TMP3]] to <3 x i1>
// CHECK-NEXT: [[VECEXT5:%.*]] = extractelement <3 x i1> [[LOADEDV4]], i64 2
// CHECK-NEXT: [[VECINIT6:%.*]] = insertelement <3 x i1> [[VECINIT3]], i1 [[VECEXT5]], i32 2
// COL-CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <3 x i1> [[VECINIT6]], <3 x i1> poison, <3 x i32> <i32 0, i32 1, i32 2>
// COL-CHECK-NEXT: ret <3 x i1> [[MATRIX_ROWMAJOR2COLMAJOR]]
// ROW-CHECK-NEXT: ret <3 x i1> [[VECINIT6]]
// CHECK-NEXT: ret <3 x i1> [[VECINIT6]]
//
bool3x1 fn2(bool3 b) {
return b;

View File

@ -35,13 +35,12 @@ bool fn1() {
// CHECK-NEXT: [[TMP0:%.*]] = load i32, ptr [[V_ADDR]], align 4
// CHECK-NEXT: [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
// CHECK-NEXT: [[VECINIT:%.*]] = insertelement <4 x i1> poison, i1 [[LOADEDV]], i32 0
// CHECK-NEXT: [[VECINIT1:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 true, i32 1
// CHECK-NEXT: [[VECINIT1:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 true, i32 2
// CHECK-NEXT: [[TMP1:%.*]] = load i32, ptr [[V_ADDR]], align 4
// CHECK-NEXT: [[LOADEDV2:%.*]] = trunc i32 [[TMP1]] to i1
// CHECK-NEXT: [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT1]], i1 [[LOADEDV2]], i32 2
// CHECK-NEXT: [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT1]], i1 [[LOADEDV2]], i32 1
// CHECK-NEXT: [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 false, i32 3
// CHECK-NEXT: [[MATRIX_ROWMAJOR2COLMAJOR:%.*]] = shufflevector <4 x i1> [[VECINIT4]], <4 x i1> poison, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
// CHECK-NEXT: [[TMP2:%.*]] = zext <4 x i1> [[MATRIX_ROWMAJOR2COLMAJOR]] to <4 x i32>
// CHECK-NEXT: [[TMP2:%.*]] = zext <4 x i1> [[VECINIT4]] to <4 x i32>
// CHECK-NEXT: store <4 x i32> [[TMP2]], ptr [[A]], align 4
// CHECK-NEXT: [[TMP3:%.*]] = load <4 x i32>, ptr [[A]], align 4
// CHECK-NEXT: store <4 x i32> [[TMP3]], ptr [[RETVAL]], align 4