diff --git a/clang/include/clang/AST/Mangle.h b/clang/include/clang/AST/Mangle.h index ca72dcfd4483..13fa0d1c880b 100644 --- a/clang/include/clang/AST/Mangle.h +++ b/clang/include/clang/AST/Mangle.h @@ -317,6 +317,11 @@ private: class Implementation; std::unique_ptr Impl; }; + +/// Constants used by LLDB for mangling. +struct LLDBManglingABI { + static constexpr llvm::StringLiteral FunctionLabelPrefix = "$__lldb_func:"; +}; } // namespace clang #endif diff --git a/clang/lib/AST/Mangle.cpp b/clang/lib/AST/Mangle.cpp index 780b2c585c81..58216667116a 100644 --- a/clang/lib/AST/Mangle.cpp +++ b/clang/lib/AST/Mangle.cpp @@ -152,8 +152,6 @@ bool MangleContext::shouldMangleDeclName(const NamedDecl *D) { return shouldMangleCXXName(D); } -static llvm::StringRef g_lldb_func_call_label_prefix = "$__lldb_func:"; - /// Given an LLDB function call label, this function prints the label /// into \c Out, together with the structor type of \c GD (if the /// decl is a constructor/destructor). LLDB knows how to handle mangled @@ -167,9 +165,9 @@ static llvm::StringRef g_lldb_func_call_label_prefix = "$__lldb_func:"; /// static void emitLLDBAsmLabel(llvm::StringRef label, GlobalDecl GD, llvm::raw_ostream &Out) { - assert(label.starts_with(g_lldb_func_call_label_prefix)); + assert(label.starts_with(LLDBManglingABI::FunctionLabelPrefix)); - Out << g_lldb_func_call_label_prefix; + Out << LLDBManglingABI::FunctionLabelPrefix; if (auto *Ctor = llvm::dyn_cast(GD.getDecl())) { Out << "C"; @@ -180,7 +178,7 @@ static void emitLLDBAsmLabel(llvm::StringRef label, GlobalDecl GD, Out << "D" << GD.getDtorType(); } - Out << label.substr(g_lldb_func_call_label_prefix.size()); + Out << label.substr(LLDBManglingABI::FunctionLabelPrefix.size()); } void MangleContext::mangleName(GlobalDecl GD, raw_ostream &Out) { @@ -216,7 +214,7 @@ void MangleContext::mangleName(GlobalDecl GD, raw_ostream &Out) { if (!UserLabelPrefix.empty()) Out << '\01'; // LLVM IR Marker for __asm("foo") - if (ALA->getLabel().starts_with(g_lldb_func_call_label_prefix)) + if (ALA->getLabel().starts_with(LLDBManglingABI::FunctionLabelPrefix)) emitLLDBAsmLabel(ALA->getLabel(), GD, Out); else Out << ALA->getLabel(); diff --git a/clang/lib/CodeGen/CGPointerAuth.cpp b/clang/lib/CodeGen/CGPointerAuth.cpp index 84b5c86e69a5..28d3289dfe04 100644 --- a/clang/lib/CodeGen/CGPointerAuth.cpp +++ b/clang/lib/CodeGen/CGPointerAuth.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "CGCXXABI.h" #include "CodeGenFunction.h" #include "CodeGenModule.h" #include "clang/CodeGen/CodeGenABITypes.h" @@ -62,8 +63,22 @@ CodeGenModule::getPointerAuthDeclDiscriminator(GlobalDecl Declaration) { uint16_t &EntityHash = PtrAuthDiscriminatorHashes[Declaration]; if (EntityHash == 0) { - StringRef Name = getMangledName(Declaration); - EntityHash = llvm::getPointerAuthStableSipHash(Name); + const auto *ND = cast(Declaration.getDecl()); + if (ND->hasAttr() && + ND->getAttr()->getLabel().starts_with( + LLDBManglingABI::FunctionLabelPrefix)) { + // If the declaration comes from LLDB, the asm label has a prefix that + // would producing a different discriminator. Compute the real C++ mangled + // name instead so the discriminator matches what the original translation + // unit used. + SmallString<256> Buffer; + llvm::raw_svector_ostream Out(Buffer); + getCXXABI().getMangleContext().mangleCXXName(Declaration, Out); + EntityHash = llvm::getPointerAuthStableSipHash(Out.str()); + } else { + StringRef Name = getMangledName(Declaration); + EntityHash = llvm::getPointerAuthStableSipHash(Name); + } } return EntityHash; diff --git a/lldb/source/Plugins/ExpressionParser/Clang/ClangExpressionParser.cpp b/lldb/source/Plugins/ExpressionParser/Clang/ClangExpressionParser.cpp index 38e5298f9cc9..a69a4ec3d529 100644 --- a/lldb/source/Plugins/ExpressionParser/Clang/ClangExpressionParser.cpp +++ b/lldb/source/Plugins/ExpressionParser/Clang/ClangExpressionParser.cpp @@ -731,6 +731,8 @@ static void SetPointerAuthOptionsForArm64e(LangOptions &lang_opts) { lang_opts.PointerAuthReturns = true; lang_opts.PointerAuthAuthTraps = true; lang_opts.PointerAuthIndirectGotos = true; + lang_opts.PointerAuthVTPtrAddressDiscrimination = true; + lang_opts.PointerAuthVTPtrTypeDiscrimination = true; lang_opts.PointerAuthObjcIsa = true; lang_opts.PointerAuthObjcClassROPointers = true; lang_opts.PointerAuthObjcInterfaceSel = true; diff --git a/lldb/test/API/commands/expression/ptrauth-vtable/Makefile b/lldb/test/API/commands/expression/ptrauth-vtable/Makefile new file mode 100644 index 000000000000..3c6bc2dd007e --- /dev/null +++ b/lldb/test/API/commands/expression/ptrauth-vtable/Makefile @@ -0,0 +1,8 @@ +CXX_SOURCES := main.cpp + +override ARCH := arm64e + +# We need an arm64e stblib. +USE_SYSTEM_STDLIB := 1 + +include Makefile.rules diff --git a/lldb/test/API/commands/expression/ptrauth-vtable/TestPtrAuthVTableExpressions.py b/lldb/test/API/commands/expression/ptrauth-vtable/TestPtrAuthVTableExpressions.py new file mode 100644 index 000000000000..92a30b6e6548 --- /dev/null +++ b/lldb/test/API/commands/expression/ptrauth-vtable/TestPtrAuthVTableExpressions.py @@ -0,0 +1,44 @@ +""" +VTable pointers are signed with a discriminator that incorporates the object's +address (PointerAuthVTPtrAddressDiscrimination) and class type ( +PointerAuthVTPtrTypeDiscrimination). +""" + +import lldb +from lldbsuite.test.decorators import * +from lldbsuite.test.lldbtest import * +from lldbsuite.test import lldbutil + + +class TestPtrAuthVTableExpressions(TestBase): + NO_DEBUG_INFO_TESTCASE = True + + @skipUnlessArm64eSupported + def test_virtual_call_on_debuggee_object(self): + self.build() + lldbutil.run_to_source_breakpoint( + self, "// break here", lldb.SBFileSpec("main.cpp", False) + ) + + self.expect_expr("d.value()", result_type="int", result_value="20") + self.expect_expr("od.value()", result_type="int", result_value="30") + + @skipUnlessArm64eSupported + def test_virtual_call_through_base_pointer(self): + self.build() + lldbutil.run_to_source_breakpoint( + self, "// break here", lldb.SBFileSpec("main.cpp", False) + ) + + self.expect_expr("base_ptr->value()", result_type="int", result_value="20") + + @skipUnlessArm64eSupported + def test_virtual_call_via_helper(self): + self.build() + lldbutil.run_to_source_breakpoint( + self, "// break here", lldb.SBFileSpec("main.cpp", False) + ) + + self.expect_expr("call_value(&d)", result_type="int", result_value="20") + self.expect_expr("call_value(&od)", result_type="int", result_value="30") + self.expect_expr("call_value(base_ptr)", result_type="int", result_value="20") diff --git a/lldb/test/API/commands/expression/ptrauth-vtable/main.cpp b/lldb/test/API/commands/expression/ptrauth-vtable/main.cpp new file mode 100644 index 000000000000..d9dec9b9a6a4 --- /dev/null +++ b/lldb/test/API/commands/expression/ptrauth-vtable/main.cpp @@ -0,0 +1,27 @@ +#include + +class Base { +public: + virtual int value() { return 10; } + virtual ~Base() = default; +}; + +class Derived : public Base { +public: + int value() override { return 20; } +}; + +class OtherDerived : public Base { +public: + int value() override { return 30; } +}; + +int call_value(Base *obj) { return obj->value(); } + +int main() { + Derived d; + OtherDerived od; + Base *base_ptr = &d; + printf("%d %d %d\n", d.value(), od.value(), base_ptr->value()); + return 0; // break here +}