diff --git a/orc-rt/include/orc-rt/Session.h b/orc-rt/include/orc-rt/Session.h index e6963bfb50b2..0cd3d85708ea 100644 --- a/orc-rt/include/orc-rt/Session.h +++ b/orc-rt/include/orc-rt/Session.h @@ -14,6 +14,7 @@ #define ORC_RT_SESSION_H #include "orc-rt/Error.h" +#include "orc-rt/LockedAccess.h" #include "orc-rt/Service.h" #include "orc-rt/TaskDispatcher.h" #include "orc-rt/WrapperFunction.h" @@ -50,6 +51,8 @@ public: using OnCallHandlerCompleteFn = move_only_function; + using SymbolMap = std::unordered_map; + /// Provides access to the controller. class ControllerAccess { friend class Session; @@ -116,9 +119,7 @@ public: /// Note that entry into the reporter is not synchronized: it may be /// called from multiple threads concurrently. Session(std::unique_ptr Dispatcher, - ErrorReporterFn ReportError) - : Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) { - } + ErrorReporterFn ReportError); // Sessions are not copyable or moveable. Session(const Session &) = delete; @@ -134,6 +135,12 @@ public: /// Report an error via the ErrorReporter function. void reportError(Error Err) { ReportError(std::move(Err)); } + /// Controller interface symbols map. + auto controllerInterface() { return LockedAccess(ControllerInterface, M); } + auto controllerInterface() const { + return LockedAccess(ControllerInterface, M); + } + /// Initiate session shutdown. /// /// Runs shutdown on registered resources in reverse order. @@ -204,8 +211,9 @@ private: std::shared_ptr CA; ErrorReporterFn ReportError; - std::mutex M; + mutable std::mutex M; std::vector> Services; + SymbolMap ControllerInterface; std::unique_ptr SI; }; diff --git a/orc-rt/lib/executor/Session.cpp b/orc-rt/lib/executor/Session.cpp index 379bcee0cb37..8f6545ee2d4b 100644 --- a/orc-rt/lib/executor/Session.cpp +++ b/orc-rt/lib/executor/Session.cpp @@ -16,6 +16,12 @@ namespace orc_rt { Session::ControllerAccess::~ControllerAccess() = default; +Session::Session(std::unique_ptr Dispatcher, + ErrorReporterFn ReportError) + : Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) { + ControllerInterface["orc_rt_SessionInstance"] = static_cast(this); +} + Session::~Session() { waitForShutdown(); } void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) { diff --git a/orc-rt/unittests/SessionTest.cpp b/orc-rt/unittests/SessionTest.cpp index d4ff1a96f5c0..76875b4dcb93 100644 --- a/orc-rt/unittests/SessionTest.cpp +++ b/orc-rt/unittests/SessionTest.cpp @@ -387,6 +387,35 @@ TEST(SessionTest, CreateServiceAndUseRef) { CS.doMoreConfig(1); } +TEST(SessionTest, ControllerInterfaceContainsSessionByDefault) { + Session S(std::make_unique(), noErrors); + ASSERT_TRUE(S.controllerInterface()->count("orc_rt_SessionInstance")); + EXPECT_EQ(S.controllerInterface()->at("orc_rt_SessionInstance"), + static_cast(&S)); +} + +TEST(SessionTest, ControllerInterfaceWithRef) { + Session S(std::make_unique(), noErrors); + int X = 0, Y = 0; + S.controllerInterface().with_ref([&](Session::SymbolMap &Syms) { + Syms["orc_rt_A"] = &X; + Syms["orc_rt_B"] = &Y; + }); + + EXPECT_EQ(S.controllerInterface()->at("orc_rt_A"), &X); + EXPECT_EQ(S.controllerInterface()->at("orc_rt_B"), &Y); +} + +TEST(SessionTest, ControllerInterfaceConstAccess) { + Session S(std::make_unique(), noErrors); + int X = 0; + S.controllerInterface()->emplace("orc_rt_X", &X); + + const Session &CS = S; + ASSERT_TRUE(CS.controllerInterface()->count("orc_rt_X")); + EXPECT_EQ(CS.controllerInterface()->at("orc_rt_X"), &X); +} + TEST(ControllerAccessTest, Basics) { // Test that we can set the ControllerAccess implementation and still shut // down as expected.