[x86] Enable indirect tail calls with more arguments (#137643)

X86ISelDAGToDAG's `isCalleeLoad` / `moveBelowOrigChain` try to move
the load instruction next to the call so they can be folded, but they
would only allow a single CopyToReg node in between.

This patch makes it look through multiple CopyToReg, while being careful
to only perform the transformation when the load+call can be folded.

As part of that, it also replaces `X86tcret_1reg` and `X86tcret_6regs`,
which check that there are enough free registers to compute the call
address, with a more correct `X86tcret_enough_regs` that
`moveBelowOrigChain` is also gated on.

Fixes #136848
This commit is contained in:
Hans Wennborg 2026-02-16 09:24:58 +01:00 committed by GitHub
parent 820d15c8e6
commit 7bf2d5f14b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 229 additions and 42 deletions

View File

@ -602,6 +602,7 @@ namespace {
bool onlyUsesZeroFlag(SDValue Flags) const;
bool hasNoSignFlagUses(SDValue Flags) const;
bool hasNoCarryFlagUses(SDValue Flags) const;
bool checkTCRetEnoughRegs(SDNode *N) const;
};
class X86DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@ -874,6 +875,12 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
LD->getExtensionType() != ISD::NON_EXTLOAD)
return false;
// If the load's outgoing chain has more than one use, we can't (currently)
// move the load since we'd most likely create a loop. TODO: Maybe it could
// work if moveBelowOrigChain() updated *all* the chain users.
if (!Callee.getValue(1).hasOneUse())
return false;
// Now let's find the callseq_start.
while (HasCallSeq && Chain.getOpcode() != ISD::CALLSEQ_START) {
if (!Chain.hasOneUse())
@ -881,20 +888,39 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
Chain = Chain.getOperand(0);
}
if (!Chain.getNumOperands())
while (true) {
if (!Chain.getNumOperands())
return false;
// It's not safe to move the callee (a load) across e.g. a store.
// Conservatively abort if the chain contains a node other than the ones
// below.
switch (Chain.getNode()->getOpcode()) {
case ISD::CALLSEQ_START:
case ISD::CopyToReg:
case ISD::LOAD:
break;
default:
return false;
}
if (Chain.getOperand(0).getNode() == Callee.getNode())
return true;
if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
Chain.getOperand(0).getValue(0).hasOneUse() &&
Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
Callee.getValue(1).hasOneUse())
return true;
// Look past CopyToRegs. We only walk one path, so the chain mustn't branch.
if (Chain.getOperand(0).getOpcode() == ISD::CopyToReg &&
Chain.getOperand(0).getValue(0).hasOneUse()) {
Chain = Chain.getOperand(0);
continue;
}
return false;
// Since we are not checking for AA here, conservatively abort if the chain
// writes to memory. It's not safe to move the callee (a load) across a store.
if (isa<MemSDNode>(Chain.getNode()) &&
cast<MemSDNode>(Chain.getNode())->writeMem())
return false;
if (Chain.getOperand(0).getNode() == Callee.getNode())
return true;
if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
Callee.getValue(1).hasOneUse())
return true;
return false;
}
}
static bool isEndbrImm64(uint64_t Imm) {
@ -1363,6 +1389,8 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
SDValue Load = N->getOperand(1);
if (!isCalleeLoad(Load, Chain, HasCallSeq))
continue;
if (N->getOpcode() == X86ISD::TC_RETURN && !checkTCRetEnoughRegs(N))
continue;
moveBelowOrigChain(CurDAG, Load, SDValue(N, 0), Chain);
++NumLoadMoved;
MadeChange = true;
@ -3479,6 +3507,65 @@ static bool mayUseCarryFlag(X86::CondCode CC) {
return true;
}
bool X86DAGToDAGISel::checkTCRetEnoughRegs(SDNode *N) const {
  // Returns true when, after counting the general purpose registers that
  // carry arguments of the tail call N, enough registers of the tail-call
  // register class remain to form the address of the callee load.
  const X86RegisterInfo *RegInfo = Subtarget->getRegisterInfo();
  const auto CC = MF->getFunction().getCallingConv();

  // Pick the register class and count its usable members. The classes chosen
  // here must stay in sync with what's used for TCRETURNri,
  // TCRETURN_HIPE32ri, TCRETURN_WIN64ri, etc.
  unsigned UsableGPRs;
  if (!Subtarget->is64Bit()) {
    const TargetRegisterClass *RC = &X86::GR32_TCRegClass;
    if (CC == CallingConv::HiPE)
      RC = &X86::GR32RegClass;
    // ESP is a member of the class, but in general it can't be used for the
    // address, so discount it.
    assert(RC->contains(X86::ESP));
    UsableGPRs = RC->getNumRegs() - 1;
  } else {
    const TargetRegisterClass *RC = &X86::GR64_TCRegClass;
    if (Subtarget->isCallingConvWin64(CC))
      RC = &X86::GR64_TCW64RegClass;
    // RSP and RIP are members of the class, but in general they can't be used
    // for the load, so discount both.
    assert(RC->contains(X86::RSP));
    assert(RC->contains(X86::RIP));
    UsableGPRs = RC->getNumRegs() - 2;
  }

  // Worst case, the load's address needs a base and an index register.
  unsigned AddrGPRs = 2;

  assert(N->getOpcode() == X86ISD::TC_RETURN);
  // Operand layout of X86tcret: (*chain, ptr, imm, regs..., glue)

  if (Subtarget->is32Bit()) {
    // FIXME: This was carried from X86tcret_1reg which was used for 32-bit,
    // but it could apply to 64-bit too.
    const SDValue &Base = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
    if (isa<FrameIndexSDNode>(Base)) {
      // The base is a fixed index off ESP; no registers are needed.
      AddrGPRs = 0;
    } else if (Base.getOpcode() == X86ISD::Wrapper &&
               isa<GlobalAddressSDNode>(Base->getOperand(0))) {
      // The base is a global, which is an immediate since this is non-PIC, so
      // it doesn't need a register.
      assert(!getTargetMachine().isPositionIndependent());
      AddrGPRs = 1;
    }
  }

  // Walk the register operands and bail out as soon as the GPR arguments plus
  // the address registers exceed what is usable.
  unsigned ArgGPRs = 0;
  for (unsigned Idx = 3, End = N->getNumOperands(); Idx != End; ++Idx) {
    const auto *Reg = dyn_cast<RegisterSDNode>(N->getOperand(Idx));
    if (!Reg || !RegInfo->isGeneralPurposeRegister(*MF, Reg->getReg()))
      continue;
    ++ArgGPRs;
    if (ArgGPRs + AddrGPRs > UsableGPRs)
      return false;
  }

  return true;
}
/// Check whether or not the chain ending in StoreNode is suitable for doing
/// the {load; op; store} to modify transformation.
static bool isFusableLoadOpStorePattern(StoreSDNode *StoreNode,

View File

@ -1349,7 +1349,8 @@ def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
// Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
// can never use callee-saved registers. That is the purpose of the GR64_TC
// register classes.
// register classes. These (GR32_TC, GR64_TC, ..) need to stay in sync with
// checkTCRetEnoughRegs.
//
// The only volatile register that is never used by the calling convention is
// %r11. This happens when calling a vararg function with 6 arguments.
@ -1366,8 +1367,7 @@ def : Pat<(X86tcret GR32:$dst, timm:$off),
// FIXME: This is disabled for 32-bit PIC mode because the global base
// register which is part of the address mode may be assigned a
// callee-saved register.
// Similar to X86tcret_6regs, here we only have 1 register left
def : Pat<(X86tcret_1reg (load addr:$dst), timm:$off),
def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
(TCRETURNmi addr:$dst, timm:$off)>,
Requires<[Not64BitMode, IsNotPIC, NotUseIndirectThunkCalls]>;
@ -1391,13 +1391,13 @@ def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off),
(TCRETURNri64_ImpCall ptr_rc_tailcall:$dst, timm:$off)>,
Requires<[In64BitMode, NotUseIndirectThunkCalls, ImportCallOptimizationEnabled]>;
// Don't fold loads into X86tcret requiring more than 6 regs.
// Don't fold loads into X86tcret requiring too many regs.
// There wouldn't be enough scratch registers for base+index.
def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off),
def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
(TCRETURNmi64 addr:$dst, timm:$off)>,
Requires<[In64BitMode, IsNotWin64CCFunc, NotUseIndirectThunkCalls]>;
def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off),
def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
(TCRETURN_WINmi64 addr:$dst, timm:$off)>,
Requires<[IsWin64CCFunc, NotUseIndirectThunkCalls]>;

View File

@ -688,29 +688,9 @@ def X86lock_sub_nocf : PatFrag<(ops node:$lhs, node:$rhs),
return hasNoCarryFlagUses(SDValue(N, 0));
}]>;
def X86tcret_6regs : PatFrag<(ops node:$ptr, node:$off),
(X86tcret node:$ptr, node:$off), [{
// X86tcret args: (*chain, ptr, imm, regs..., glue)
unsigned NumRegs = 0;
for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
if (isa<RegisterSDNode>(N->getOperand(i)) && ++NumRegs > 6)
return false;
return true;
}]>;
def X86tcret_1reg : PatFrag<(ops node:$ptr, node:$off),
(X86tcret node:$ptr, node:$off), [{
// X86tcret args: (*chain, ptr, imm, regs..., glue)
unsigned NumRegs = 1;
const SDValue& BasePtr = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
if (isa<FrameIndexSDNode>(BasePtr))
NumRegs = 3;
else if (BasePtr->getNumOperands() && isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
NumRegs = 3;
for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
if (isa<RegisterSDNode>(N->getOperand(i)) && ( NumRegs-- == 0))
return false;
return true;
// X86tcret fragment whose predicate defers to checkTCRetEnoughRegs(N); the
// fragment only matches when that helper returns true.
def X86tcret_enough_regs : PatFrag<(ops node:$ptr, node:$off),
(X86tcret node:$ptr, node:$off), [{
return checkTCRetEnoughRegs(N);
}]>;
// If this is an anyext of the remainder of an 8-bit sdivrem, use a MOVSX

View File

@ -0,0 +1,26 @@
; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu | FileCheck %s --check-prefix=CHECK --check-prefix=LIN
; RUN: llc < %s -mtriple=x86_64-pc-windows-msvc | FileCheck %s --check-prefix=CHECK --check-prefix=WIN
; The callee address computation should get folded into the call.
; Both RUN lines enable the shared prefix in addition to their target-specific
; one; without it the label and the "no mov" directives below would be
; silently ignored.
; CHECK-LABEL: f:
; CHECK-NOT: mov
; LIN: jmpq *(%rdi,%rsi,8)
; WIN: rex64 jmpq *(%rcx,%rdx,8)
define void @f(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3) {
entry:
%arrayidx = getelementptr inbounds ptr, ptr %table, i64 %idx
%funcptr = load ptr, ptr %arrayidx, align 8
tail call void %funcptr(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3)
ret void
}
; Check that we don't assert here. On Win64 this has a TokenFactor with
; multiple uses, which we can't currently fold.
; There are no FileCheck directives for this function: it only has to make it
; through llc.
define void @thunk(ptr %this, ...) {
entry:
%vtable = load ptr, ptr %this, align 8
%vfn = getelementptr inbounds nuw i8, ptr %vtable, i64 8
%0 = load ptr, ptr %vfn, align 8
musttail call void (ptr, ...) %0(ptr %this, ...)
ret void
}

View File

@ -24,3 +24,15 @@ entry:
tail call void %0()
ret void
}
; Don't fold the load+call if there's inline asm in between.
; The callee load is expected to stay a separate mov, followed by a jmp,
; rather than being folded into the jump as a memory operand.
; CHECK: test3
; CHECK: mov{{.*}}
; CHECK: jmp{{.*}}
define void @test3(ptr nocapture %x) {
entry:
%0 = load ptr, ptr %x
call void asm sideeffect "", ""() ; It could do anything.
tail call void %0()
ret void
}

View File

@ -0,0 +1,82 @@
; RUN: llc < %s -mtriple=i686-linux-gnu | FileCheck %s --check-prefix=CHECK --check-prefix=LIN32
; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefix=CHECK --check-prefix=LIN64
; RUN: llc < %s -mtriple=x86_64-pc-win32 | FileCheck %s --check-prefix=CHECK --check-prefix=WIN64
; Check that we only fold the address computation (load) into a tail call
; when we're sure there are enough volatile registers available.
@globl = global ptr null
; Baseline: the tail call passes no arguments at all.
; CHECK-LABEL: test0:
define i32 @test0(ptr %a, ptr %b) {
entry:
%func = load ptr, ptr %a
%call = tail call i32 %func()
ret i32 %call
; Call address load gets folded into the tail call.
; LIN32: jmpl *(%
; LIN64: jmpq *(%
; WIN64: jmpq *(%
}
; One inreg argument: folding is still expected for all three run
; configurations.
; CHECK-LABEL: test1:
define i32 @test1(ptr %a, ptr %b) {
entry:
%func = load ptr, ptr %a
%call = tail call i32 %func(i32 inreg 1)
ret i32 %call
; Call address load gets folded into the tail call.
; LIN32: jmpl *(%
; LIN64: jmpq *(%
; WIN64: jmpq *(%
}
; Two inreg arguments: the i686 run expects a jump through a register (load
; not folded), while both 64-bit runs still expect the folded memory form.
; CHECK-LABEL: test2:
define i32 @test2(ptr %a, ptr %b) {
entry:
%func = load ptr, ptr %a
%call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
ret i32 %call
; On 32-bit we're not sure there is enough register to fold the load.
; LIN32: jmpl *%
; LIN64: jmpq *(%
; WIN64: jmpq *(%
}
; Same two inreg arguments as test2, but the callee pointer is loaded from the
; global @globl instead of from a pointer argument.
; CHECK-LABEL: test2_globl:
define i32 @test2_globl(ptr %a, ptr %b) {
entry:
%func = load ptr, ptr @globl
%call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
ret i32 %call
; .. but if the load is from a global, we can fold it.
; LIN32: jmpl *globl
; LIN64: jmpq *(%
; WIN64: jmpq *globl(%rip)
}
; Two inreg arguments again, with the callee pointer taken directly from the
; first parameter. Only the i686 run has an expectation for this function.
; CHECK-LABEL: test2_stack:
define i32 @test2_stack(ptr %func, ptr %b) {
entry:
%call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
ret i32 %call
; and if the load is from the stack (on 32-bit, %func is passed on the stack):
; LIN32: jmpl *4(%esp)
}
; Six inreg arguments. The label directive anchors the expectations below to
; this function's output, like every other test in this file; there is no
; expectation for the i686 run here.
; CHECK-LABEL: test6:
define i32 @test6(ptr %a, ptr %b) {
entry:
%func = load ptr, ptr %a
%call = tail call i32 %func(i32 inreg 1, i32 inreg 2, i32 inreg 3, i32 inreg 4, i32 inreg 5, i32 inreg 6)
ret i32 %call
; LIN64: jmpq *(%
; I wasn't able to pass more than 4 arguments in registers on Win64.
; WIN64: callq *(
}