[AMDGPU][True16] Generate correct reg size for reg_sequence16 in wmma src mod select (#187629)

When a f16 from a true16 insts is passed to a wmma, the src mod try to
pack it to a v4f16 using v_perm_b32. In true16 mode this is causing an
issue since v_perm_b32 takes vgpr32. Create a vgpr_32 for 16-bit src
before passing to v_perm_b32 in true16 mode so that the reg size
matched.

Ideailly we should use reg_sequence to replace v_perm_b32 in true16
mode. However, it currently hit a problem with bad code quality. With
current optimization it only shows better code quality when .hi16 is
selected in vector shuffle. Will fix it when reg allocator and coalescer
can reduce the extra mov
This commit is contained in:
Guo Chen 2026-03-25 14:59:05 -04:00 committed by GitHub
parent f0604af7ce
commit c1251ad58b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 32 deletions

View File

@ -129,6 +129,22 @@ static SDValue stripExtractLoElt(SDValue In) {
return In;
}
static SDValue emitRegSequence(llvm::SelectionDAG &CurDAG, unsigned DstRegClass,
EVT DstTy, ArrayRef<SDValue> Elts,
ArrayRef<unsigned> SubRegClass,
const SDLoc &DL) {
assert(Elts.size() == SubRegClass.size() && "array size mismatch");
unsigned NumElts = Elts.size();
SmallVector<SDValue, 17> Ops(2 * NumElts + 1);
Ops[0] = (CurDAG.getTargetConstant(DstRegClass, DL, MVT::i32));
for (unsigned i = 0; i < NumElts; ++i) {
Ops[2 * i + 1] = Elts[i];
Ops[2 * i + 2] = CurDAG.getTargetConstant(SubRegClass[i], DL, MVT::i32);
}
return SDValue(
CurDAG.getMachineNode(TargetOpcode::REG_SEQUENCE, DL, DstTy, Ops), 0);
}
} // end anonymous namespace
INITIALIZE_PASS_BEGIN(AMDGPUDAGToDAGISelLegacy, "amdgpu-isel",
@ -3736,9 +3752,9 @@ bool AMDGPUDAGToDAGISel::SelectWMMAOpSelVOP3PMods(SDValue In,
return true;
}
static MachineSDNode *buildRegSequence32(SmallVectorImpl<SDValue> &Elts,
llvm::SelectionDAG *CurDAG,
const SDLoc &DL) {
MachineSDNode *
AMDGPUDAGToDAGISel::buildRegSequence32(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL) const {
unsigned DstRegClass;
EVT DstTy;
switch (Elts.size()) {
@ -3768,9 +3784,9 @@ static MachineSDNode *buildRegSequence32(SmallVectorImpl<SDValue> &Elts,
return CurDAG->getMachineNode(TargetOpcode::REG_SEQUENCE, DL, DstTy, Ops);
}
static MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
llvm::SelectionDAG *CurDAG,
const SDLoc &DL) {
MachineSDNode *
AMDGPUDAGToDAGISel::buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL) const {
SmallVector<SDValue, 8> PackedElts;
assert("unhandled Reg sequence size" &&
(Elts.size() == 8 || Elts.size() == 16));
@ -3783,6 +3799,20 @@ static MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
if (isExtractHiElt(Elts[i + 1], HiSrc) && LoSrc == HiSrc) {
PackedElts.push_back(HiSrc);
} else {
if (Subtarget->useRealTrue16Insts()) {
// FIXME-TRUE16. For now pack VGPR_32 for 16-bit source before
// passing to v_perm_b32. Eventually we should use replace v_perm_b32
// by reg_sequence.
SDValue Undef = SDValue(
CurDAG->getMachineNode(TargetOpcode::IMPLICIT_DEF, DL, MVT::i16),
0);
Elts[i] =
emitRegSequence(*CurDAG, AMDGPU::VGPR_32RegClassID, MVT::i32,
{Elts[i], Undef}, {AMDGPU::lo16, AMDGPU::hi16}, DL);
Elts[i + 1] = emitRegSequence(*CurDAG, AMDGPU::VGPR_32RegClassID,
MVT::i32, {Elts[i + 1], Undef},
{AMDGPU::lo16, AMDGPU::hi16}, DL);
}
SDValue PackLoLo = CurDAG->getTargetConstant(0x05040100, DL, MVT::i32);
MachineSDNode *Packed =
CurDAG->getMachineNode(AMDGPU::V_PERM_B32_e64, DL, MVT::i32,
@ -3790,24 +3820,25 @@ static MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
PackedElts.push_back(SDValue(Packed, 0));
}
}
return buildRegSequence32(PackedElts, CurDAG, DL);
return buildRegSequence32(PackedElts, DL);
}
static MachineSDNode *buildRegSequence(SmallVectorImpl<SDValue> &Elts,
llvm::SelectionDAG *CurDAG,
const SDLoc &DL, unsigned ElementSize) {
MachineSDNode *
AMDGPUDAGToDAGISel::buildRegSequence(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL,
unsigned ElementSize) const {
if (ElementSize == 16)
return buildRegSequence16(Elts, CurDAG, DL);
return buildRegSequence16(Elts, DL);
if (ElementSize == 32)
return buildRegSequence32(Elts, CurDAG, DL);
return buildRegSequence32(Elts, DL);
llvm_unreachable("Unhandled element size");
}
static void selectWMMAModsNegAbs(unsigned ModOpcode, unsigned &Mods,
SmallVectorImpl<SDValue> &Elts, SDValue &Src,
llvm::SelectionDAG *CurDAG, const SDLoc &DL,
unsigned ElementSize) {
void AMDGPUDAGToDAGISel::selectWMMAModsNegAbs(unsigned ModOpcode,
unsigned &Mods,
SmallVectorImpl<SDValue> &Elts,
SDValue &Src, const SDLoc &DL,
unsigned ElementSize) const {
if (ModOpcode == ISD::FNEG) {
Mods |= SISrcMods::NEG;
// Check if all elements also have abs modifier
@ -3819,17 +3850,17 @@ static void selectWMMAModsNegAbs(unsigned ModOpcode, unsigned &Mods,
}
if (Elts.size() != NegAbsElts.size()) {
// Neg
Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0);
Src = SDValue(buildRegSequence(Elts, DL, ElementSize), 0);
} else {
// Neg and Abs
Mods |= SISrcMods::NEG_HI;
Src = SDValue(buildRegSequence(NegAbsElts, CurDAG, DL, ElementSize), 0);
Src = SDValue(buildRegSequence(NegAbsElts, DL, ElementSize), 0);
}
} else {
assert(ModOpcode == ISD::FABS);
// Abs
Mods |= SISrcMods::NEG_HI;
Src = SDValue(buildRegSequence(Elts, CurDAG, DL, ElementSize), 0);
Src = SDValue(buildRegSequence(Elts, DL, ElementSize), 0);
}
}
@ -3868,7 +3899,7 @@ bool AMDGPUDAGToDAGISel::SelectWMMAModsF16Neg(SDValue In, SDValue &Src,
// All elements have neg modifier
if (BV->getNumOperands() * 2 == EltsF16.size()) {
Src = SDValue(buildRegSequence16(EltsF16, CurDAG, SDLoc(In)), 0);
Src = SDValue(buildRegSequence16(EltsF16, SDLoc(In)), 0);
Mods |= SISrcMods::NEG;
Mods |= SISrcMods::NEG_HI;
}
@ -3887,7 +3918,7 @@ bool AMDGPUDAGToDAGISel::SelectWMMAModsF16Neg(SDValue In, SDValue &Src,
// All pairs of elements have neg modifier
if (BV->getNumOperands() == EltsV2F16.size()) {
Src = SDValue(buildRegSequence32(EltsV2F16, CurDAG, SDLoc(In)), 0);
Src = SDValue(buildRegSequence32(EltsV2F16, SDLoc(In)), 0);
Mods |= SISrcMods::NEG;
Mods |= SISrcMods::NEG_HI;
}
@ -3918,8 +3949,7 @@ bool AMDGPUDAGToDAGISel::SelectWMMAModsF16NegAbs(SDValue In, SDValue &Src,
// All elements have ModOpcode modifier
if (BV->getNumOperands() * 2 == EltsF16.size())
selectWMMAModsNegAbs(ModOpcode, Mods, EltsF16, Src, CurDAG, SDLoc(In),
16);
selectWMMAModsNegAbs(ModOpcode, Mods, EltsF16, Src, SDLoc(In), 16);
}
// mods are on v2f16 elements
@ -3938,8 +3968,7 @@ bool AMDGPUDAGToDAGISel::SelectWMMAModsF16NegAbs(SDValue In, SDValue &Src,
// All elements have ModOpcode modifier
if (BV->getNumOperands() == EltsV2F16.size())
selectWMMAModsNegAbs(ModOpcode, Mods, EltsV2F16, Src, CurDAG, SDLoc(In),
32);
selectWMMAModsNegAbs(ModOpcode, Mods, EltsV2F16, Src, SDLoc(In), 32);
}
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);
@ -3967,8 +3996,7 @@ bool AMDGPUDAGToDAGISel::SelectWMMAModsF32NegAbs(SDValue In, SDValue &Src,
// All elements had ModOpcode modifier
if (BV->getNumOperands() == EltsF32.size())
selectWMMAModsNegAbs(ModOpcode, Mods, EltsF32, Src, CurDAG, SDLoc(In),
32);
selectWMMAModsNegAbs(ModOpcode, Mods, EltsF32, Src, SDLoc(In), 32);
}
SrcMods = CurDAG->getTargetConstant(Mods, SDLoc(In), MVT::i32);

View File

@ -92,6 +92,17 @@ private:
bool isUniformLoad(const SDNode *N) const;
bool isUniformBr(const SDNode *N) const;
MachineSDNode *buildRegSequence16(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL) const;
MachineSDNode *buildRegSequence32(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL) const;
MachineSDNode *buildRegSequence(SmallVectorImpl<SDValue> &Elts,
const SDLoc &DL, unsigned ElementSize) const;
void selectWMMAModsNegAbs(unsigned ModOpcode, unsigned &Mods,
SmallVectorImpl<SDValue> &Elts, SDValue &Src,
const SDLoc &DL, unsigned ElementSize) const;
// Returns true if ISD::AND SDNode `N`'s masking of the shift amount operand's
// `ShAmtBits` bits is unneeded.
bool isUnneededShiftMask(const SDNode *N, unsigned ShAmtBits) const;

View File

@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
; RUN: llc -mtriple=amdgcn -mcpu=gfx1170 < %s | FileCheck %s --check-prefixes=GCN,GFX1170
; RUN: llc -mtriple=amdgcn -mcpu=gfx1200 < %s | FileCheck %s --check-prefixes=GCN,GFX12
; RUN: llc -mtriple=amdgcn -mcpu=gfx1200 -mattr=+real-true16 < %s | FileCheck %s --check-prefixes=GCN,GFX12
; RUN: llc -mtriple=amdgcn -mcpu=gfx1200 -mattr=-real-true16 < %s | FileCheck %s --check-prefixes=GCN,GFX12
define amdgpu_ps void @test_wmma_f32_16x16x16_f16_negA(<8 x half> %A, <8 x half> %B, <8 x float> %C, ptr addrspace(1) %out) {
; GCN-LABEL: test_wmma_f32_16x16x16_f16_negA:
@ -411,13 +412,11 @@ define amdgpu_ps void @test_wmma_f16_16x16x16_f16_negC_pack(<8 x half> %A, <8 x
; GFX1170-NEXT: flat_load_b128 v[12:15], v[8:9] offset:16
; GFX1170-NEXT: flat_load_b128 v[16:19], v[8:9]
; GFX1170-NEXT: s_waitcnt vmcnt(1) lgkmcnt(1)
; GFX1170-NEXT: v_mov_b16_e32 v8.l, v15.l
; GFX1170-NEXT: v_mov_b16_e32 v9.l, v14.l
; GFX1170-NEXT: v_perm_b32 v15, v15, v14, 0x5040100
; GFX1170-NEXT: v_perm_b32 v14, v13, v12, 0x5040100
; GFX1170-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; GFX1170-NEXT: v_perm_b32 v13, v19, v18, 0x5040100
; GFX1170-NEXT: v_perm_b32 v12, v17, v16, 0x5040100
; GFX1170-NEXT: v_perm_b32 v15, v8, v9, 0x5040100
; GFX1170-NEXT: s_delay_alu instid0(VALU_DEP_1)
; GFX1170-NEXT: v_wmma_f16_16x16x16_f16 v[12:15], v[0:3], v[4:7], v[12:15] neg_lo:[0,0,1]
; GFX1170-NEXT: global_store_b128 v[10:11], v[12:15], off