From 81c0f3023fc38e3ea720045407a17f47653ea2ac Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Tue, 21 Jan 2025 19:02:35 +1100 Subject: [PATCH] [ORC] Add ExecutorSymbolDef toPtr / fromPtr convenience functions. This will simplify conversion of a number of APIs from ExecutorAddr to ExecutorSymbolDef. --- .../BuildingAJIT/Chapter1/toy.cpp | 2 +- .../BuildingAJIT/Chapter2/toy.cpp | 2 +- .../BuildingAJIT/Chapter3/toy.cpp | 2 +- .../BuildingAJIT/Chapter4/toy.cpp | 2 +- llvm/examples/Kaleidoscope/Chapter4/toy.cpp | 2 +- llvm/examples/Kaleidoscope/Chapter5/toy.cpp | 2 +- llvm/examples/Kaleidoscope/Chapter6/toy.cpp | 2 +- llvm/examples/Kaleidoscope/Chapter7/toy.cpp | 2 +- .../Orc/Shared/ExecutorSymbolDef.h | 31 ++++++++++++++++++ .../Orc/ExecutorAddressTest.cpp | 32 +++++++++++++++++++ 10 files changed, 71 insertions(+), 8 deletions(-) diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter1/toy.cpp b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter1/toy.cpp index 1b35ba404d29..426886c72e54 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter1/toy.cpp +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter1/toy.cpp @@ -1159,7 +1159,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - auto *FP = Sym.getAddress().toPtr(); + auto *FP = Sym.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter2/toy.cpp b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter2/toy.cpp index 1b35ba404d29..426886c72e54 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter2/toy.cpp +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter2/toy.cpp @@ -1159,7 +1159,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - auto *FP = Sym.getAddress().toPtr(); + auto *FP = Sym.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter3/toy.cpp b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter3/toy.cpp index 1b35ba404d29..426886c72e54 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter3/toy.cpp +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter3/toy.cpp @@ -1159,7 +1159,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - auto *FP = Sym.getAddress().toPtr(); + auto *FP = Sym.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter4/toy.cpp b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter4/toy.cpp index 2c8d4941291e..1891635dbfd3 100644 --- a/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter4/toy.cpp +++ b/llvm/examples/Kaleidoscope/BuildingAJIT/Chapter4/toy.cpp @@ -1157,7 +1157,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - auto *FP = Sym.getAddress().toPtr(); + auto *FP = Sym.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/Chapter4/toy.cpp b/llvm/examples/Kaleidoscope/Chapter4/toy.cpp index 1bbc294bf352..0f58391c5066 100644 --- a/llvm/examples/Kaleidoscope/Chapter4/toy.cpp +++ b/llvm/examples/Kaleidoscope/Chapter4/toy.cpp @@ -643,7 +643,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - double (*FP)() = ExprSymbol.getAddress().toPtr(); + double (*FP)() = ExprSymbol.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/Chapter5/toy.cpp b/llvm/examples/Kaleidoscope/Chapter5/toy.cpp index 48936bddb1d4..7117eaf4982b 100644 --- a/llvm/examples/Kaleidoscope/Chapter5/toy.cpp +++ b/llvm/examples/Kaleidoscope/Chapter5/toy.cpp @@ -917,7 +917,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - double (*FP)() = ExprSymbol.getAddress().toPtr(); + double (*FP)() = ExprSymbol.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/Chapter6/toy.cpp b/llvm/examples/Kaleidoscope/Chapter6/toy.cpp index ebe4322287b2..cb7b6cc8651c 100644 --- a/llvm/examples/Kaleidoscope/Chapter6/toy.cpp +++ b/llvm/examples/Kaleidoscope/Chapter6/toy.cpp @@ -1036,7 +1036,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - double (*FP)() = ExprSymbol.getAddress().toPtr(); + double (*FP)() = ExprSymbol.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/examples/Kaleidoscope/Chapter7/toy.cpp b/llvm/examples/Kaleidoscope/Chapter7/toy.cpp index 374f2c03b48e..91b7191a07c6 100644 --- a/llvm/examples/Kaleidoscope/Chapter7/toy.cpp +++ b/llvm/examples/Kaleidoscope/Chapter7/toy.cpp @@ -1207,7 +1207,7 @@ static void HandleTopLevelExpression() { // Get the symbol's address and cast it to the right type (takes no // arguments, returns a double) so we can call it as a native function. - double (*FP)() = ExprSymbol.getAddress().toPtr(); + double (*FP)() = ExprSymbol.toPtr(); fprintf(stderr, "Evaluated to %f\n", FP()); // Delete the anonymous expression module from the JIT. diff --git a/llvm/include/llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h b/llvm/include/llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h index 68ccdf83bd12..0756ab5ea988 100644 --- a/llvm/include/llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h +++ b/llvm/include/llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h @@ -23,6 +23,37 @@ namespace orc { /// Represents a defining location for a JIT symbol. class ExecutorSymbolDef { public: + /// Create an ExecutorSymbolDef from the given pointer. + /// Warning: This should only be used when JITing in-process. + template > + static ExecutorSymbolDef fromPtr(T *Ptr, + JITSymbolFlags BaseFlags = JITSymbolFlags(), + UnwrapFn &&Unwrap = UnwrapFn()) { + auto *UP = Unwrap(Ptr); + JITSymbolFlags Flags = BaseFlags; + if (std::is_function_v) + Flags |= JITSymbolFlags::Callable; + return ExecutorSymbolDef( + ExecutorAddr::fromPtr(UP, ExecutorAddr::rawPtr()), Flags); + } + + /// Cast this ExecutorSymbolDef to a pointer of the given type. + /// Warning: This should only be used when JITing in-process. + template >> + std::enable_if_t::value, T> + toPtr(WrapFn &&Wrap = WrapFn()) const { + return Addr.toPtr(std::forward(Wrap)); + } + + /// Cast this ExecutorSymbolDef to a pointer of the given function type. + /// Warning: This should only be used when JITing in-process. + template > + std::enable_if_t::value, T *> + toPtr(WrapFn &&Wrap = WrapFn()) const { + return Addr.toPtr(std::forward(Wrap)); + } + ExecutorSymbolDef() = default; ExecutorSymbolDef(ExecutorAddr Addr, JITSymbolFlags Flags) : Addr(Addr), Flags(Flags) {} diff --git a/llvm/unittests/ExecutionEngine/Orc/ExecutorAddressTest.cpp b/llvm/unittests/ExecutionEngine/Orc/ExecutorAddressTest.cpp index e8b22b3d4bbb..3de77031291c 100644 --- a/llvm/unittests/ExecutionEngine/Orc/ExecutorAddressTest.cpp +++ b/llvm/unittests/ExecutionEngine/Orc/ExecutorAddressTest.cpp @@ -8,6 +8,7 @@ #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h" #include "OrcTestCommon.h" +#include "llvm/ExecutionEngine/Orc/Shared/ExecutorSymbolDef.h" using namespace llvm; using namespace llvm::orc; @@ -107,4 +108,35 @@ TEST(ExecutorAddrTest, AddrRanges) { EXPECT_GT(R1, R0); } +TEST(ExecutorSymbolDef, PointerConversion) { + int X = 0; + + auto XHiddenSym = ExecutorSymbolDef::fromPtr(&X); + int *XHiddenPtr = XHiddenSym.toPtr(); + + auto XExportedSym = ExecutorSymbolDef::fromPtr(&X, JITSymbolFlags::Exported); + int *XExportedPtr = XExportedSym.toPtr(); + + EXPECT_EQ(XHiddenPtr, &X); + EXPECT_EQ(XExportedPtr, &X); + + EXPECT_EQ(XHiddenSym.getFlags(), JITSymbolFlags()); + EXPECT_EQ(XExportedSym.getFlags(), JITSymbolFlags::Exported); +} + +TEST(ExecutorSymbolDef, FunctionPointerConversion) { + auto FHiddenSym = ExecutorSymbolDef::fromPtr(&F); + void (*FHiddenPtr)() = FHiddenSym.toPtr(); + + auto FExportedSym = ExecutorSymbolDef::fromPtr(&F, JITSymbolFlags::Exported); + void (*FExportedPtr)() = FExportedSym.toPtr(); + + EXPECT_EQ(FHiddenPtr, &F); + EXPECT_EQ(FExportedPtr, &F); + + EXPECT_EQ(FHiddenSym.getFlags(), JITSymbolFlags::Callable); + EXPECT_EQ(FExportedSym.getFlags(), + JITSymbolFlags::Exported | JITSymbolFlags::Callable); +} + } // namespace