[X86][MemFold] Allow masked load folding if masks are equal (#161074)
Inspired by #160920#issuecomment-3341816198
This commit is contained in:
parent
cac0635ee9
commit
abffc542ff
@ -3238,6 +3238,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
|
||||
(_.VT _.RC:$src1),
|
||||
(_.VT _.RC:$src0))))], _.ExeDomain>,
|
||||
EVEX, EVEX_K, Sched<[Sched.RR]>;
|
||||
let mayLoad = 1, canFoldAsLoad = 1 in
|
||||
def rmk : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
|
||||
(ins _.RC:$src0, _.KRCWM:$mask, _.MemOp:$src1),
|
||||
!strconcat(OpcodeStr, "\t{$src1, ${dst} {${mask}}|",
|
||||
@ -3248,6 +3249,7 @@ multiclass avx512_load<bits<8> opc, string OpcodeStr, string Name,
|
||||
(_.VT _.RC:$src0))))], _.ExeDomain>,
|
||||
EVEX, EVEX_K, Sched<[Sched.RM]>;
|
||||
}
|
||||
let mayLoad = 1, canFoldAsLoad = 1 in
|
||||
def rmkz : AVX512PI<opc, MRMSrcMem, (outs _.RC:$dst),
|
||||
(ins _.KRCWM:$mask, _.MemOp:$src),
|
||||
OpcodeStr #"\t{$src, ${dst} {${mask}} {z}|"#
|
||||
|
||||
@ -8113,6 +8113,39 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
|
||||
MachineBasicBlock::iterator InsertPt, MachineInstr &LoadMI,
|
||||
LiveIntervals *LIS) const {
|
||||
|
||||
// If LoadMI is a masked load, check MI having the same mask.
|
||||
const MCInstrDesc &MCID = get(LoadMI.getOpcode());
|
||||
unsigned NumOps = MCID.getNumOperands();
|
||||
if (NumOps >= 3) {
|
||||
Register MaskReg;
|
||||
const MachineOperand &Op1 = LoadMI.getOperand(1);
|
||||
const MachineOperand &Op2 = LoadMI.getOperand(2);
|
||||
|
||||
auto IsVKWMClass = [](const TargetRegisterClass *RC) {
|
||||
return RC == &X86::VK2WMRegClass || RC == &X86::VK4WMRegClass ||
|
||||
RC == &X86::VK8WMRegClass || RC == &X86::VK16WMRegClass ||
|
||||
RC == &X86::VK32WMRegClass || RC == &X86::VK64WMRegClass;
|
||||
};
|
||||
|
||||
if (Op1.isReg() && IsVKWMClass(getRegClass(MCID, 1, &RI)))
|
||||
MaskReg = Op1.getReg();
|
||||
else if (Op2.isReg() && IsVKWMClass(getRegClass(MCID, 2, &RI)))
|
||||
MaskReg = Op2.getReg();
|
||||
|
||||
if (MaskReg) {
|
||||
bool HasSameMask = false;
|
||||
for (unsigned I = 1, E = MI.getDesc().getNumOperands(); I < E; ++I) {
|
||||
const MachineOperand &Op = MI.getOperand(I);
|
||||
if (Op.isReg() && Op.getReg() == MaskReg) {
|
||||
HasSameMask = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!HasSameMask)
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Support the case where LoadMI loads a wide register, but MI
|
||||
// only uses a subreg.
|
||||
for (auto Op : Ops) {
|
||||
@ -8121,7 +8154,6 @@ MachineInstr *X86InstrInfo::foldMemoryOperandImpl(
|
||||
}
|
||||
|
||||
// If loading from a FrameIndex, fold directly from the FrameIndex.
|
||||
unsigned NumOps = LoadMI.getDesc().getNumOperands();
|
||||
int FrameIndex;
|
||||
if (isLoadFromStackSlot(LoadMI, FrameIndex)) {
|
||||
if (isNonFoldablePartialRegisterLoad(LoadMI, MI, MF))
|
||||
|
||||
@ -2119,8 +2119,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
|
||||
; KNL-LABEL: ktest_1:
|
||||
; KNL: ## %bb.0:
|
||||
; KNL-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
|
||||
; KNL-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
|
||||
; KNL-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
|
||||
; KNL-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
|
||||
; KNL-NEXT: kmovw %k0, %eax
|
||||
; KNL-NEXT: testb %al, %al
|
||||
; KNL-NEXT: je LBB44_2
|
||||
@ -2152,8 +2151,7 @@ define void @ktest_1(<8 x double> %in, ptr %base) {
|
||||
; AVX512BW-LABEL: ktest_1:
|
||||
; AVX512BW: ## %bb.0:
|
||||
; AVX512BW-NEXT: vcmpgtpd (%rdi), %zmm0, %k1
|
||||
; AVX512BW-NEXT: vmovupd 8(%rdi), %zmm1 {%k1} {z}
|
||||
; AVX512BW-NEXT: vcmpltpd %zmm1, %zmm0, %k0 {%k1}
|
||||
; AVX512BW-NEXT: vcmpltpd 8(%rdi), %zmm0, %k0 {%k1}
|
||||
; AVX512BW-NEXT: kmovd %k0, %eax
|
||||
; AVX512BW-NEXT: testb %al, %al
|
||||
; AVX512BW-NEXT: je LBB44_2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user