From 7bf2d5f14b2cd0a26f4779673dd8c51b46dea095 Mon Sep 17 00:00:00 2001
From: Hans Wennborg
Date: Mon, 16 Feb 2026 09:24:58 +0100
Subject: [PATCH] [x86] Enable indirect tail calls with more arguments (#137643)

X86ISelDAGToDAG's `isCalleeLoad` / `moveBelowOrigChain` tries to move the
load instruction next to the call so they can be folded, but it would only
allow a single CopyToReg node in between.

This patch makes it look through multiple CopyToReg, while being careful to
only perform the transformation when the load+call can be folded.

As part of that, it also replaces the `X86tcret_1reg` and `X86tcret_6regs`,
which check that there are enough free registers to compute the call
address, with a more correct `X86tcret_enough_regs` that the
`moveBelowOrigChain` is also gated on.

Fixes #136848
---
 llvm/lib/Target/X86/X86ISelDAGToDAG.cpp       | 113 ++++++++++++++++--
 llvm/lib/Target/X86/X86InstrCompiler.td       |  12 +-
 llvm/lib/Target/X86/X86InstrFragments.td      |  26 +---
 llvm/test/CodeGen/X86/fold-call-4.ll          |  26 ++++
 llvm/test/CodeGen/X86/fold-call.ll            |  12 ++
 .../CodeGen/X86/tailcall-mem_enoughregs.ll    |  82 +++++++++++++
 6 files changed, 229 insertions(+), 42 deletions(-)
 create mode 100644 llvm/test/CodeGen/X86/fold-call-4.ll
 create mode 100644 llvm/test/CodeGen/X86/tailcall-mem_enoughregs.ll

diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
index 7607fad150db..101ea3e231a5 100644
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -602,6 +602,7 @@ namespace {
     bool onlyUsesZeroFlag(SDValue Flags) const;
     bool hasNoSignFlagUses(SDValue Flags) const;
     bool hasNoCarryFlagUses(SDValue Flags) const;
+    bool checkTCRetEnoughRegs(SDNode *N) const;
   };
 
   class X86DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -874,6 +875,12 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
       LD->getExtensionType() != ISD::NON_EXTLOAD)
     return false;
 
+  // If the load's outgoing chain has more than one use, we can't (currently)
+  // move the load since we'd most likely create a loop. TODO: Maybe it could
+  // work if moveBelowOrigChain() updated *all* the chain users.
+  if (!Callee.getValue(1).hasOneUse())
+    return false;
+
   // Now let's find the callseq_start.
   while (HasCallSeq && Chain.getOpcode() != ISD::CALLSEQ_START) {
     if (!Chain.hasOneUse())
@@ -881,20 +888,39 @@ static bool isCalleeLoad(SDValue Callee, SDValue &Chain, bool HasCallSeq) {
     Chain = Chain.getOperand(0);
   }
 
-  if (!Chain.getNumOperands())
+  while (true) {
+    if (!Chain.getNumOperands())
+      return false;
+
+    // It's not safe to move the callee (a load) across e.g. a store.
+    // Conservatively abort if the chain contains a node other than the ones
+    // below.
+    switch (Chain.getNode()->getOpcode()) {
+    case ISD::CALLSEQ_START:
+    case ISD::CopyToReg:
+    case ISD::LOAD:
+      break;
+    default:
+      return false;
+    }
+
+    if (Chain.getOperand(0).getNode() == Callee.getNode())
+      return true;
+    if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
+        Chain.getOperand(0).getValue(0).hasOneUse() &&
+        Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
+        Callee.getValue(1).hasOneUse())
+      return true;
+
+    // Look past CopyToRegs. We only walk one path, so the chain mustn't branch.
+    if (Chain.getOperand(0).getOpcode() == ISD::CopyToReg &&
+        Chain.getOperand(0).getValue(0).hasOneUse()) {
+      Chain = Chain.getOperand(0);
+      continue;
+    }
+
     return false;
-  // Since we are not checking for AA here, conservatively abort if the chain
-  // writes to memory. It's not safe to move the callee (a load) across a store.
-  if (isa<MemSDNode>(Chain.getNode()) &&
-      cast<MemSDNode>(Chain.getNode())->writeMem())
-    return false;
-  if (Chain.getOperand(0).getNode() == Callee.getNode())
-    return true;
-  if (Chain.getOperand(0).getOpcode() == ISD::TokenFactor &&
-      Callee.getValue(1).isOperandOf(Chain.getOperand(0).getNode()) &&
-      Callee.getValue(1).hasOneUse())
-    return true;
-  return false;
+  }
 }
 
 static bool isEndbrImm64(uint64_t Imm) {
@@ -1363,6 +1389,8 @@ void X86DAGToDAGISel::PreprocessISelDAG() {
       SDValue Load = N->getOperand(1);
       if (!isCalleeLoad(Load, Chain, HasCallSeq))
         continue;
+      if (N->getOpcode() == X86ISD::TC_RETURN && !checkTCRetEnoughRegs(N))
+        continue;
       moveBelowOrigChain(CurDAG, Load, SDValue(N, 0), Chain);
       ++NumLoadMoved;
       MadeChange = true;
@@ -3479,6 +3507,65 @@ static bool mayUseCarryFlag(X86::CondCode CC) {
   return true;
 }
 
+bool X86DAGToDAGISel::checkTCRetEnoughRegs(SDNode *N) const {
+  // Check that there are enough volatile registers to load the callee address.
+
+  const X86RegisterInfo *RI = Subtarget->getRegisterInfo();
+  unsigned AvailGPRs;
+  // The register classes below must stay in sync with what's used for
+  // TCRETURNri, TCRETURN_HIPE32ri, TCRETURN_WIN64ri, etc.
+  if (Subtarget->is64Bit()) {
+    const TargetRegisterClass *TCGPRs =
+        Subtarget->isCallingConvWin64(MF->getFunction().getCallingConv())
+            ? &X86::GR64_TCW64RegClass
+            : &X86::GR64_TCRegClass;
+    // Can't use RSP or RIP for the load in general.
+    assert(TCGPRs->contains(X86::RSP));
+    assert(TCGPRs->contains(X86::RIP));
+    AvailGPRs = TCGPRs->getNumRegs() - 2;
+  } else {
+    const TargetRegisterClass *TCGPRs =
+        MF->getFunction().getCallingConv() == CallingConv::HiPE
+            ? &X86::GR32RegClass
+            : &X86::GR32_TCRegClass;
+    // Can't use ESP for the address in general.
+    assert(TCGPRs->contains(X86::ESP));
+    AvailGPRs = TCGPRs->getNumRegs() - 1;
+  }
+
+  // The load's base and index need up to two registers.
+  unsigned LoadGPRs = 2;
+
+  assert(N->getOpcode() == X86ISD::TC_RETURN);
+  // X86tcret args: (*chain, ptr, imm, regs..., glue)
+
+  if (Subtarget->is32Bit()) {
+    // FIXME: This was carried from X86tcret_1reg which was used for 32-bit,
+    // but it could apply to 64-bit too.
+    const SDValue &BasePtr = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
+    if (isa<FrameIndexSDNode>(BasePtr)) {
+      LoadGPRs -= 2; // Base is fixed index off ESP; no regs needed.
+    } else if (BasePtr.getOpcode() == X86ISD::Wrapper &&
+               isa<GlobalAddressSDNode>(BasePtr->getOperand(0))) {
+      assert(!getTargetMachine().isPositionIndependent());
+      LoadGPRs -= 1; // Base is a global (immediate since this is non-PIC), no
+                     // reg needed.
+    }
+  }
+
+  unsigned ArgGPRs = 0;
+  for (unsigned I = 3, E = N->getNumOperands(); I != E; ++I) {
+    if (const auto *RN = dyn_cast<RegisterSDNode>(N->getOperand(I))) {
+      if (!RI->isGeneralPurposeRegister(*MF, RN->getReg()))
+        continue;
+      if (++ArgGPRs + LoadGPRs > AvailGPRs)
+        return false;
+    }
+  }
+
+  return true;
+}
+
 /// Check whether or not the chain ending in StoreNode is suitable for doing
 /// the {load; op; store} to modify transformation.
 static bool isFusableLoadOpStorePattern(StoreSDNode *StoreNode,
diff --git a/llvm/lib/Target/X86/X86InstrCompiler.td b/llvm/lib/Target/X86/X86InstrCompiler.td
index 6c8a7d7c83f0..f6fdc1cf5934 100644
--- a/llvm/lib/Target/X86/X86InstrCompiler.td
+++ b/llvm/lib/Target/X86/X86InstrCompiler.td
@@ -1349,7 +1349,8 @@ def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
 
 // Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
 // can never use callee-saved registers. That is the purpose of the GR64_TC
-// register classes.
+// register classes. These (GR32_TC, GR64_TC, ..) need to stay in sync with
+// checkTCRetEnoughRegs.
 //
 // The only volatile register that is never used by the calling convention is
 // %r11. This happens when calling a vararg function with 6 arguments.
@@ -1366,8 +1367,7 @@ def : Pat<(X86tcret GR32:$dst, timm:$off),
 // FIXME: This is disabled for 32-bit PIC mode because the global base
 // register which is part of the address mode may be assigned a
 // callee-saved register.
-// Similar to X86tcret_6regs, here we only have 1 register left
-def : Pat<(X86tcret_1reg (load addr:$dst), timm:$off),
+def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
           (TCRETURNmi addr:$dst, timm:$off)>,
           Requires<[Not64BitMode, IsNotPIC, NotUseIndirectThunkCalls]>;
 
@@ -1391,13 +1391,13 @@ def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off),
           (TCRETURNri64_ImpCall ptr_rc_tailcall:$dst, timm:$off)>,
           Requires<[In64BitMode, NotUseIndirectThunkCalls, ImportCallOptimizationEnabled]>;
 
-// Don't fold loads into X86tcret requiring more than 6 regs.
+// Don't fold loads into X86tcret requiring too many regs.
 // There wouldn't be enough scratch registers for base+index.
-def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off),
+def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
           (TCRETURNmi64 addr:$dst, timm:$off)>,
           Requires<[In64BitMode, IsNotWin64CCFunc, NotUseIndirectThunkCalls]>;
 
-def : Pat<(X86tcret_6regs (load addr:$dst), timm:$off),
+def : Pat<(X86tcret_enough_regs (load addr:$dst), timm:$off),
           (TCRETURN_WINmi64 addr:$dst, timm:$off)>,
           Requires<[IsWin64CCFunc, NotUseIndirectThunkCalls]>;
 
diff --git a/llvm/lib/Target/X86/X86InstrFragments.td b/llvm/lib/Target/X86/X86InstrFragments.td
index 0d6443d002d0..adbb8b821700 100644
--- a/llvm/lib/Target/X86/X86InstrFragments.td
+++ b/llvm/lib/Target/X86/X86InstrFragments.td
@@ -688,29 +688,9 @@ def X86lock_sub_nocf : PatFrag<(ops node:$lhs, node:$rhs),
   return hasNoCarryFlagUses(SDValue(N, 0));
 }]>;
 
-def X86tcret_6regs : PatFrag<(ops node:$ptr, node:$off),
-                             (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 0;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ++NumRegs > 6)
-      return false;
-  return true;
-}]>;
-
-def X86tcret_1reg : PatFrag<(ops node:$ptr, node:$off),
-                             (X86tcret node:$ptr, node:$off), [{
-  // X86tcret args: (*chain, ptr, imm, regs..., glue)
-  unsigned NumRegs = 1;
-  const SDValue& BasePtr = cast<LoadSDNode>(N->getOperand(1))->getBasePtr();
-  if (isa<FrameIndexSDNode>(BasePtr))
-    NumRegs = 3;
-  else if (BasePtr->getNumOperands() && isa<GlobalAddressSDNode>(BasePtr->getOperand(0)))
-    NumRegs = 3;
-  for (unsigned i = 3, e = N->getNumOperands(); i != e; ++i)
-    if (isa<RegisterSDNode>(N->getOperand(i)) && ( NumRegs-- == 0))
-      return false;
-  return true;
+def X86tcret_enough_regs : PatFrag<(ops node:$ptr, node:$off),
+                                   (X86tcret node:$ptr, node:$off), [{
+  return checkTCRetEnoughRegs(N);
 }]>;
 
 // If this is an anyext of the remainder of an 8-bit sdivrem, use a MOVSX
diff --git a/llvm/test/CodeGen/X86/fold-call-4.ll b/llvm/test/CodeGen/X86/fold-call-4.ll
new file mode 100644
index 000000000000..2c99f2cb6264
--- /dev/null
+++ b/llvm/test/CodeGen/X86/fold-call-4.ll
@@ -0,0 +1,26 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-linux-gnu | FileCheck %s --check-prefixes=CHECK,LIN
+; RUN: llc < %s -mtriple=x86_64-pc-windows-msvc | FileCheck %s --check-prefixes=CHECK,WIN
+
+; The callee address computation should get folded into the call.
+; CHECK-LABEL: f:
+; CHECK-NOT: mov
+; LIN: jmpq *(%rdi,%rsi,8)
+; WIN: rex64 jmpq *(%rcx,%rdx,8)
+define void @f(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3) {
+entry:
+  %arrayidx = getelementptr inbounds ptr, ptr %table, i64 %idx
+  %funcptr = load ptr, ptr %arrayidx, align 8
+  tail call void %funcptr(ptr %table, i64 %idx, i64 %aux1, i64 %aux2, i64 %aux3)
+  ret void
+}
+
+; Check that we don't assert here. On Win64 this has a TokenFactor with
+; multiple uses, which we can't currently fold.
+define void @thunk(ptr %this, ...) {
+entry:
+  %vtable = load ptr, ptr %this, align 8
+  %vfn = getelementptr inbounds nuw i8, ptr %vtable, i64 8
+  %0 = load ptr, ptr %vfn, align 8
+  musttail call void (ptr, ...) %0(ptr %this, ...)
+  ret void
+}
diff --git a/llvm/test/CodeGen/X86/fold-call.ll b/llvm/test/CodeGen/X86/fold-call.ll
index 8be817618cd9..25b4df778768 100644
--- a/llvm/test/CodeGen/X86/fold-call.ll
+++ b/llvm/test/CodeGen/X86/fold-call.ll
@@ -24,3 +24,15 @@ entry:
   tail call void %0()
   ret void
 }
+
+; Don't fold the load+call if there's inline asm in between.
+; CHECK: test3
+; CHECK: mov{{.*}}
+; CHECK: jmp{{.*}}
+define void @test3(ptr nocapture %x) {
+entry:
+  %0 = load ptr, ptr %x
+  call void asm sideeffect "", ""() ; It could do anything.
+  tail call void %0()
+  ret void
+}
diff --git a/llvm/test/CodeGen/X86/tailcall-mem_enoughregs.ll b/llvm/test/CodeGen/X86/tailcall-mem_enoughregs.ll
new file mode 100644
index 000000000000..f242053ecb15
--- /dev/null
+++ b/llvm/test/CodeGen/X86/tailcall-mem_enoughregs.ll
@@ -0,0 +1,82 @@
+; RUN: llc < %s -mtriple=i686-linux-gnu | FileCheck %s --check-prefix=CHECK --check-prefix=LIN32
+; RUN: llc < %s -mtriple=x86_64-linux-gnu | FileCheck %s --check-prefix=CHECK --check-prefix=LIN64
+; RUN: llc < %s -mtriple=x86_64-pc-win32 | FileCheck %s --check-prefix=CHECK --check-prefix=WIN64
+
+; Check that we only fold the address computation (load) into a tail call
+; when we're sure there are enough volatile registers available.
+
+@globl = global ptr null
+
+; CHECK-LABEL: test0:
+define i32 @test0(ptr %a, ptr %b) {
+entry:
+  %func = load ptr, ptr %a
+  %call = tail call i32 %func()
+  ret i32 %call
+
+; Call address load gets folded into the tail call.
+; LIN32: jmpl *(%
+; LIN64: jmpq *(%
+; WIN64: jmpq *(%
+}
+
+; CHECK-LABEL: test1:
+define i32 @test1(ptr %a, ptr %b) {
+entry:
+  %func = load ptr, ptr %a
+  %call = tail call i32 %func(i32 inreg 1)
+  ret i32 %call
+
+; Call address load gets folded into the tail call.
+; LIN32: jmpl *(%
+; LIN64: jmpq *(%
+; WIN64: jmpq *(%
+}
+
+; CHECK-LABEL: test2:
+define i32 @test2(ptr %a, ptr %b) {
+entry:
+  %func = load ptr, ptr %a
+  %call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
+  ret i32 %call
+
+; On 32-bit we're not sure there are enough registers to fold the load.
+; LIN32: jmpl *%
+; LIN64: jmpq *(%
+; WIN64: jmpq *(%
+}
+
+; CHECK-LABEL: test2_globl:
+define i32 @test2_globl(ptr %a, ptr %b) {
+entry:
+  %func = load ptr, ptr @globl
+  %call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
+  ret i32 %call
+
+; .. but if the load is from a global, we can fold it.
+; LIN32: jmpl *globl
+; LIN64: jmpq *(%
+; WIN64: jmpq *globl(%rip)
+}
+
+; CHECK-LABEL: test2_stack:
+define i32 @test2_stack(ptr %func, ptr %b) {
+entry:
+  %call = tail call i32 %func(i32 inreg 1, i32 inreg 2)
+  ret i32 %call
+
+; and if the load is from the stack (on 32-bit, %func is passed on the stack):
+; LIN32: jmpl *4(%esp)
+}
+
+define i32 @test6(ptr %a, ptr %b) {
+entry:
+  %func = load ptr, ptr %a
+  %call = tail call i32 %func(i32 inreg 1, i32 inreg 2, i32 inreg 3, i32 inreg 4, i32 inreg 5, i32 inreg 6)
+  ret i32 %call
+
+; LIN64: jmpq *(%
+
+; I wasn't able to pass more than 4 arguments in registers on Win64.
+; WIN64: callq *(
+}