[InstCombine] Combine ptrauth constant callee into bundle. (#94706)

Try to optimize a call to a ptrauth constant, into its ptrauth bundle:
    call(ptrauth(f)), ["ptrauth"()] ->  call f
as long as the key/discriminator are the same in constant and bundle.
This commit is contained in:
Ahmed Bougacha 2025-07-15 13:37:07 -07:00 committed by GitHub
parent 1a940bfff9
commit 42d2ae1034
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 134 additions and 0 deletions

View File

@ -4050,6 +4050,34 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
return nullptr;
}
Instruction *InstCombinerImpl::foldPtrAuthConstantCallee(CallBase &Call) {
auto *CPA = dyn_cast<ConstantPtrAuth>(Call.getCalledOperand());
if (!CPA)
return nullptr;
auto *CalleeF = dyn_cast<Function>(CPA->getPointer());
// If the ptrauth constant isn't based on a function pointer, bail out.
if (!CalleeF)
return nullptr;
// Inspect the call ptrauth bundle to check it matches the ptrauth constant.
auto PAB = Call.getOperandBundle(LLVMContext::OB_ptrauth);
if (!PAB)
return nullptr;
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
Value *Discriminator = PAB->Inputs[1];
// If the bundle doesn't match, this is probably going to fail to auth.
if (!CPA->isKnownCompatibleWith(Key, Discriminator, DL))
return nullptr;
// If the bundle matches the constant, proceed in making this a direct call.
auto *NewCall = CallBase::removeOperandBundle(&Call, LLVMContext::OB_ptrauth);
NewCall->setCalledOperand(CalleeF);
return NewCall;
}
bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
const TargetLibraryInfo *TLI) {
// Note: We only handle cases which can't be driven from generic attributes
@ -4210,6 +4238,10 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
if (IntrinsicInst *II = findInitTrampoline(Callee))
return transformCallThroughTrampoline(Call, *II);
// Combine calls to ptrauth constants.
if (Instruction *NewCall = foldPtrAuthConstantCallee(Call))
return NewCall;
if (isa<InlineAsm>(Callee) && !Call.doesNotThrow()) {
InlineAsm *IA = cast<InlineAsm>(Callee);
if (!IA->canThrow()) {

View File

@ -283,6 +283,11 @@ private:
Instruction *transformCallThroughTrampoline(CallBase &Call,
IntrinsicInst &Tramp);
/// Try to optimize a call to a ptrauth constant, into its ptrauth bundle:
/// call(ptrauth(f)), ["ptrauth"()] -> call f
/// as long as the key/discriminator are the same in constant and bundle.
Instruction *foldPtrAuthConstantCallee(CallBase &Call);
// Return (a, b) if (LHS, RHS) is known to be (a, b) or (b, a).
// Otherwise, return std::nullopt
// Currently it matches:

View File

@ -0,0 +1,97 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
declare i64 @f(i32)
declare ptr @f2(i32)
define i32 @test_ptrauth_call(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call(
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
; CHECK-NEXT: ret i32 [[V0]]
;
%v0 = call i32 ptrauth(ptr @f, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
ret i32 %v0
}
define i32 @test_ptrauth_call_disc(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_disc(
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
; CHECK-NEXT: ret i32 [[V0]]
;
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 5678) ]
ret i32 %v0
}
@f_addr_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)
define i32 @test_ptrauth_call_addr_disc(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_addr_disc(
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
; CHECK-NEXT: ret i32 [[V0]]
;
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 0, ptr @f_addr_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 ptrtoint (ptr @f_addr_disc.ref to i64)) ]
ret i32 %v0
}
@f_both_disc.ref = constant ptr ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)
define i32 @test_ptrauth_call_blend(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_blend(
; CHECK-NEXT: [[V0:%.*]] = call i32 @f(i32 [[A0:%.*]])
; CHECK-NEXT: ret i32 [[V0]]
;
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 1234)
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
ret i32 %v0
}
define i64 @test_ptrauth_call_cast(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_cast(
; CHECK-NEXT: [[V0:%.*]] = call i64 @f2(i32 [[A0:%.*]])
; CHECK-NEXT: ret i64 [[V0]]
;
%v0 = call i64 ptrauth(ptr @f2, i32 0)(i32 %a0) [ "ptrauth"(i32 0, i64 0) ]
ret i64 %v0
}
define i32 @test_ptrauth_call_mismatch_key(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_mismatch_key(
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 0, i64 5678) ]
; CHECK-NEXT: ret i32 [[V0]]
;
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 0, i64 5678) ]
ret i32 %v0
}
define i32 @test_ptrauth_call_mismatch_disc(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_mismatch_disc(
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 5678)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 0) ]
; CHECK-NEXT: ret i32 [[V0]]
;
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 5678)(i32 %a0) [ "ptrauth"(i32 1, i64 0) ]
ret i32 %v0
}
define i32 @test_ptrauth_call_mismatch_blend(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_mismatch_blend(
; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
; CHECK-NEXT: ret i32 [[V0]]
;
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_both_disc.ref to i64), i64 0)
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
ret i32 %v0
}
define i32 @test_ptrauth_call_mismatch_blend_addr(i32 %a0) {
; CHECK-LABEL: @test_ptrauth_call_mismatch_blend_addr(
; CHECK-NEXT: [[V:%.*]] = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_addr_disc.ref to i64), i64 1234)
; CHECK-NEXT: [[V0:%.*]] = call i32 ptrauth (ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 [[A0:%.*]]) [ "ptrauth"(i32 1, i64 [[V]]) ]
; CHECK-NEXT: ret i32 [[V0]]
;
%v = call i64 @llvm.ptrauth.blend(i64 ptrtoint (ptr @f_addr_disc.ref to i64), i64 1234)
%v0 = call i32 ptrauth(ptr @f, i32 1, i64 1234, ptr @f_both_disc.ref)(i32 %a0) [ "ptrauth"(i32 1, i64 %v) ]
ret i32 %v0
}
declare i64 @llvm.ptrauth.blend(i64, i64)