diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp index 31e88dcc12bf..57867fc8afb7 100644 --- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp +++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp @@ -189,13 +189,21 @@ public: // Check authenticated LR before tail calling. void emitPtrauthTailCallHardening(const MachineInstr *TC); + struct PtrAuthSchema { + PtrAuthSchema(AArch64PACKey::ID Key, uint64_t IntDisc, + const MachineOperand &AddrDiscOp); + + AArch64PACKey::ID Key; + uint64_t IntDisc; + Register AddrDisc; + bool AddrDiscIsKilled; + }; + // Emit the sequence for AUT or AUTPAC. - void emitPtrauthAuthResign(Register AUTVal, AArch64PACKey::ID AUTKey, - uint64_t AUTDisc, - const MachineOperand *AUTAddrDisc, - Register Scratch, - std::optional PACKey, - uint64_t PACDisc, Register PACAddrDisc, Value *DS); + void emitPtrauthAuthResign(Register Pointer, Register Scratch, + PtrAuthSchema AuthSchema, + std::optional SignSchema, + Value *DS); // Emit R_AARCH64_PATCHINST, the deactivation symbol relocation. Returns true // if no instruction should be emitted because the deactivation symbol is @@ -2222,12 +2230,15 @@ bool AArch64AsmPrinter::emitDeactivationSymbolRelocation(Value *DS) { return false; } +AArch64AsmPrinter::PtrAuthSchema::PtrAuthSchema( + AArch64PACKey::ID Key, uint64_t IntDisc, const MachineOperand &AddrDiscOp) + : Key(Key), IntDisc(IntDisc), AddrDisc(AddrDiscOp.getReg()), + AddrDiscIsKilled(AddrDiscOp.isKill()) {} + void AArch64AsmPrinter::emitPtrauthAuthResign( - Register AUTVal, AArch64PACKey::ID AUTKey, uint64_t AUTDisc, - const MachineOperand *AUTAddrDisc, Register Scratch, - std::optional PACKey, uint64_t PACDisc, - Register PACAddrDisc, Value *DS) { - const bool IsAUTPAC = PACKey.has_value(); + Register Pointer, Register Scratch, PtrAuthSchema AuthSchema, + std::optional SignSchema, Value *DS) { + const bool IsResign = SignSchema.has_value(); // We expand AUT/AUTPAC into a sequence of the form // @@ -2267,35 +2278,38 @@ void AArch64AsmPrinter::emitPtrauthAuthResign( } // Compute aut discriminator - Register AUTDiscReg = emitPtrauthDiscriminator( - AUTDisc, AUTAddrDisc->getReg(), Scratch, AUTAddrDisc->isKill()); + Register AUTDiscReg = + emitPtrauthDiscriminator(AuthSchema.IntDisc, AuthSchema.AddrDisc, Scratch, + AuthSchema.AddrDiscIsKilled); if (!emitDeactivationSymbolRelocation(DS)) - emitAUT(AUTKey, AUTVal, AUTDiscReg); + emitAUT(AuthSchema.Key, Pointer, AUTDiscReg); // Unchecked or checked-but-non-trapping AUT is just an "AUT": we're done. - if (!IsAUTPAC && (!ShouldCheck || !ShouldTrap)) + if (!IsResign && (!ShouldCheck || !ShouldTrap)) return; MCSymbol *EndSym = nullptr; if (ShouldCheck) { - if (IsAUTPAC && !ShouldTrap) + if (IsResign && !ShouldTrap) EndSym = createTempSymbol("resign_end_"); - emitPtrauthCheckAuthenticatedValue( - AUTVal, Scratch, AUTKey, AArch64PAuth::AuthCheckMethod::XPAC, EndSym); + emitPtrauthCheckAuthenticatedValue(Pointer, Scratch, AuthSchema.Key, + AArch64PAuth::AuthCheckMethod::XPAC, + EndSym); } // We already emitted unchecked and checked-but-non-trapping AUTs. // That left us with trapping AUTs, and AUTPACs. // Trapping AUTs don't need PAC: we're done. - if (!IsAUTPAC) + if (!IsResign) return; // Compute pac discriminator - Register PACDiscReg = emitPtrauthDiscriminator(PACDisc, PACAddrDisc, Scratch); - emitPAC(*PACKey, AUTVal, PACDiscReg); + Register PACDiscReg = emitPtrauthDiscriminator(SignSchema->IntDisc, + SignSchema->AddrDisc, Scratch); + emitPAC(SignSchema->Key, Pointer, PACDiscReg); // Lend: if (EndSym) @@ -3199,29 +3213,44 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) { return; } - case AArch64::AUTx16x17: - emitPtrauthAuthResign( - AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, - std::nullopt, 0, 0, MI->getDeactivationSymbol()); - return; + case AArch64::AUTx16x17: { + const Register Pointer = AArch64::X16; + const Register Scratch = AArch64::X17; - case AArch64::AUTxMxN: - emitPtrauthAuthResign(MI->getOperand(0).getReg(), - (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), &MI->getOperand(5), - MI->getOperand(1).getReg(), std::nullopt, 0, 0, + PtrAuthSchema AuthSchema((AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), MI->getOperand(2)); + + emitPtrauthAuthResign(Pointer, Scratch, AuthSchema, std::nullopt, MI->getDeactivationSymbol()); return; + } - case AArch64::AUTPAC: - emitPtrauthAuthResign( - AArch64::X16, (AArch64PACKey::ID)MI->getOperand(0).getImm(), - MI->getOperand(1).getImm(), &MI->getOperand(2), AArch64::X17, - (AArch64PACKey::ID)MI->getOperand(3).getImm(), - MI->getOperand(4).getImm(), MI->getOperand(5).getReg(), - MI->getDeactivationSymbol()); + case AArch64::AUTxMxN: { + const Register Pointer = MI->getOperand(0).getReg(); + const Register Scratch = MI->getOperand(1).getReg(); + + PtrAuthSchema AuthSchema((AArch64PACKey::ID)MI->getOperand(3).getImm(), + MI->getOperand(4).getImm(), MI->getOperand(5)); + + emitPtrauthAuthResign(Pointer, Scratch, AuthSchema, std::nullopt, + MI->getDeactivationSymbol()); return; + } + + case AArch64::AUTPAC: { + const Register Pointer = AArch64::X16; + const Register Scratch = AArch64::X17; + + PtrAuthSchema AuthSchema((AArch64PACKey::ID)MI->getOperand(0).getImm(), + MI->getOperand(1).getImm(), MI->getOperand(2)); + + PtrAuthSchema SignSchema((AArch64PACKey::ID)MI->getOperand(3).getImm(), + MI->getOperand(4).getImm(), MI->getOperand(5)); + + emitPtrauthAuthResign(Pointer, Scratch, AuthSchema, SignSchema, + MI->getDeactivationSymbol()); + return; + } case AArch64::PAC: emitPtrauthSign(MI);