//===- SessionTest.cpp ----------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Tests for orc-rt's Session.h APIs. // //===----------------------------------------------------------------------===// #include "orc-rt/Session.h" #include "orc-rt/SPSWrapperFunction.h" #include "orc-rt/ThreadPoolTaskDispatcher.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include #include #include #include using namespace orc_rt; using ::testing::Eq; using ::testing::Optional; class MockService : public Service { public: enum class Op { Detach, Shutdown }; static void noop(Op) {} MockService(std::optional &DetachOpIdx, std::optional &ShutdownOpIdx, size_t &OpIdx, move_only_function GenResult = noop) : DetachOpIdx(DetachOpIdx), ShutdownOpIdx(ShutdownOpIdx), OpIdx(OpIdx), GenResult(std::move(GenResult)) {} void onDetach(OnCompleteFn OnComplete) override { DetachOpIdx = OpIdx++; GenResult(Op::Detach); OnComplete(); } void onShutdown(OnCompleteFn OnComplete) override { ShutdownOpIdx = OpIdx++; GenResult(Op::Shutdown); OnComplete(); } private: std::optional &DetachOpIdx; std::optional &ShutdownOpIdx; size_t &OpIdx; move_only_function GenResult; }; class ConfigurableService : public Service { public: ConfigurableService(int ConstructorOption) {} void onDetach(OnCompleteFn OnComplete) override { OnComplete(); } void onShutdown(OnCompleteFn OnComplete) override { OnComplete(); } void doMoreConfig(int) noexcept {} }; class NoDispatcher : public TaskDispatcher { public: void dispatch(std::unique_ptr T) override { assert(false && "strictly no dispatching!"); } void shutdown() override {} }; class EnqueueingDispatcher : public TaskDispatcher { public: using OnShutdownRunFn = move_only_function; EnqueueingDispatcher(std::deque> &Tasks, OnShutdownRunFn OnShutdownRun = {}) : Tasks(Tasks), OnShutdownRun(std::move(OnShutdownRun)) {} void dispatch(std::unique_ptr T) override { Tasks.push_back(std::move(T)); } void shutdown() override { if (OnShutdownRun) OnShutdownRun(); } /// Run up to NumTasks (arbitrarily many if NumTasks == std::nullopt) tasks /// from the front of the queue, returning the number actually run. static size_t runTasksFromFront(std::deque> &Tasks, std::optional NumTasks = std::nullopt) { size_t NumRun = 0; while (!Tasks.empty() && (!NumTasks || NumRun != *NumTasks)) { auto T = std::move(Tasks.front()); Tasks.pop_front(); T->run(); ++NumRun; } return NumRun; } private: std::deque> &Tasks; OnShutdownRunFn OnShutdownRun; }; class MockControllerAccess : public Session::ControllerAccess { public: MockControllerAccess(Session &SS) : Session::ControllerAccess(SS), SS(SS) {} void disconnect() override { std::unique_lock Lock(M); Shutdown = true; ShutdownCV.wait(Lock, [this]() { return Shutdown && Outstanding == 0; }); } void callController(OnCallHandlerCompleteFn OnComplete, HandlerTag T, WrapperFunctionBuffer ArgBytes) override { // Simulate a call to the controller by dispatching a task to run the // requested function. size_t CId; { std::scoped_lock Lock(M); if (Shutdown) return; CId = CallId++; Pending[CId] = std::move(OnComplete); ++Outstanding; } SS.dispatch(makeGenericTask([this, CId, OnComplete = std::move(OnComplete), T, ArgBytes = std::move(ArgBytes)]() mutable { auto Fn = reinterpret_cast(T); Fn(reinterpret_cast(this), CId, wfReturn, ArgBytes.release()); })); bool Notify = false; { std::scoped_lock Lock(M); if (--Outstanding == 0 && Shutdown) Notify = true; } if (Notify) ShutdownCV.notify_all(); } void sendWrapperResult(uint64_t CallId, WrapperFunctionBuffer ResultBytes) override { // Respond to a simulated call by the controller. OnCallHandlerCompleteFn OnComplete; { std::scoped_lock Lock(M); if (Shutdown) { assert(Pending.empty() && "Shut down but results still pending?"); return; } auto I = Pending.find(CallId); assert(I != Pending.end()); OnComplete = std::move(I->second); Pending.erase(I); ++Outstanding; } SS.dispatch( makeGenericTask([OnComplete = std::move(OnComplete), ResultBytes = std::move(ResultBytes)]() mutable { OnComplete(std::move(ResultBytes)); })); bool Notify = false; { std::scoped_lock Lock(M); if (--Outstanding == 0 && Shutdown) Notify = true; } if (Notify) ShutdownCV.notify_all(); } void callFromController(OnCallHandlerCompleteFn OnComplete, orc_rt_WrapperFunction Fn, WrapperFunctionBuffer ArgBytes) { size_t CId = 0; bool BailOut = false; { std::scoped_lock Lock(M); if (!Shutdown) { CId = CallId++; Pending[CId] = std::move(OnComplete); ++Outstanding; } else BailOut = true; } if (BailOut) return OnComplete(WrapperFunctionBuffer::createOutOfBandError( "Controller disconnected")); handleWrapperCall(CId, Fn, std::move(ArgBytes)); bool Notify = false; { std::scoped_lock Lock(M); if (--Outstanding == 0 && Shutdown) Notify = true; } if (Notify) ShutdownCV.notify_all(); } /// Simulate start of outstanding operation. void incOutstanding() { std::scoped_lock Lock(M); ++Outstanding; } /// Simulate end of outstanding operation. void decOutstanding() { bool Notify = false; { std::scoped_lock Lock(M); if (--Outstanding == 0 && Shutdown) Notify = true; } if (Notify) ShutdownCV.notify_all(); } private: static void wfReturn(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionBuffer ResultBytes) { // Abuse "session" to refer to the ControllerAccess object. // We can just re-use sendFunctionResult for this. reinterpret_cast(S)->sendWrapperResult( CallId, WrapperFunctionBuffer(ResultBytes)); } Session &SS; std::mutex M; bool Shutdown = false; size_t Outstanding = 0; size_t CallId = 0; std::unordered_map Pending; std::condition_variable ShutdownCV; }; class CallViaMockControllerAccess { public: CallViaMockControllerAccess(MockControllerAccess &CA, orc_rt_WrapperFunction Fn) : CA(CA), Fn(Fn) {} void operator()(Session::OnCallHandlerCompleteFn OnComplete, WrapperFunctionBuffer ArgBytes) { CA.callFromController(std::move(OnComplete), Fn, std::move(ArgBytes)); } private: MockControllerAccess &CA; orc_rt_WrapperFunction Fn; }; // Non-overloaded version of cantFail: allows easy construction of // move_only_functionss. static void noErrors(Error Err) { cantFail(std::move(Err)); } TEST(SessionTest, TrivialConstructionAndDestruction) { Session S(std::make_unique(), noErrors); } TEST(SessionTest, ReportError) { Error E = Error::success(); cantFail(std::move(E)); // Force error into checked state. Session S(std::make_unique(), [&](Error Err) { E = std::move(Err); }); S.reportError(make_error("foo")); if (E) EXPECT_EQ(toString(std::move(E)), "foo"); else ADD_FAILURE() << "Missing error value"; } TEST(SessionTest, DispatchTask) { int X = 0; std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); EXPECT_EQ(Tasks.size(), 0U); S.dispatch(makeGenericTask([&]() { ++X; })); EXPECT_EQ(Tasks.size(), 1U); auto T = std::move(Tasks.front()); Tasks.pop_front(); T->run(); EXPECT_EQ(X, 1); } TEST(SessionTest, SingleService) { size_t OpIdx = 0; std::optional DetachOpIdx; std::optional ShutdownOpIdx; { Session S(std::make_unique(), noErrors); S.addService( std::make_unique(DetachOpIdx, ShutdownOpIdx, OpIdx)); } EXPECT_EQ(OpIdx, 1U); EXPECT_EQ(DetachOpIdx, std::nullopt); EXPECT_THAT(ShutdownOpIdx, Optional(Eq(0))); } TEST(SessionTest, MultipleServices) { size_t OpIdx = 0; std::optional DetachOpIdx[3]; std::optional ShutdownOpIdx[3]; { Session S(std::make_unique(), noErrors); for (size_t I = 0; I != 3; ++I) S.addService(std::make_unique(DetachOpIdx[I], ShutdownOpIdx[I], OpIdx)); } EXPECT_EQ(OpIdx, 3U); // Expect shutdown in reverse order. for (size_t I = 0; I != 3; ++I) { EXPECT_EQ(DetachOpIdx[I], std::nullopt); EXPECT_THAT(ShutdownOpIdx[I], Optional(Eq(2 - I))); } } TEST(SessionTest, ExpectedShutdownSequence) { // Check that Session shutdown results in... // 1. Services being shut down. // 2. The TaskDispatcher being shut down. // 3. A call to OnShutdownComplete. size_t OpIdx = 0; std::optional DetachOpIdx; std::optional ShutdownOpIdx; bool DispatcherShutDown = false; bool SessionShutdownComplete = false; std::deque> Tasks; Session S(std::make_unique( Tasks, [&]() { EXPECT_TRUE(ShutdownOpIdx); EXPECT_EQ(*ShutdownOpIdx, 0); EXPECT_FALSE(SessionShutdownComplete); DispatcherShutDown = true; }), noErrors); S.addService( std::make_unique(DetachOpIdx, ShutdownOpIdx, OpIdx)); S.shutdown([&]() { EXPECT_TRUE(DispatcherShutDown); SessionShutdownComplete = true; }); S.waitForShutdown(); EXPECT_TRUE(SessionShutdownComplete); } TEST(SessionTest, AddServiceAndUseRef) { Session S(std::make_unique(), noErrors); auto &CS = S.addService(std::make_unique(42)); CS.doMoreConfig(1); } TEST(SessionTest, CreateServiceAndUseRef) { Session S(std::make_unique(), noErrors); auto &CS = S.createService(42); CS.doMoreConfig(1); } TEST(ControllerAccessTest, Basics) { // Test that we can set the ControllerAccess implementation and still shut // down as expected. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); auto CA = std::make_shared(S); S.setController(CA); EnqueueingDispatcher::runTasksFromFront(Tasks); S.waitForShutdown(); } static void add_sps_wrapper(orc_rt_SessionRef S, uint64_t CallId, orc_rt_WrapperFunctionReturn Return, orc_rt_WrapperFunctionBuffer ArgBytes) { SPSWrapperFunction::handle( S, CallId, Return, ArgBytes, [](move_only_function Return, int32_t X, int32_t Y) { Return(X + Y); }); } TEST(ControllerAccessTest, ValidCallToController) { // Simulate a call to a controller handler. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); auto CA = std::make_shared(S); S.setController(CA); int32_t Result = 0; SPSWrapperFunction::call( CallViaSession(S, reinterpret_cast(add_sps_wrapper)), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EnqueueingDispatcher::runTasksFromFront(Tasks); EXPECT_EQ(Result, 42); S.waitForShutdown(); } TEST(ControllerAccessTest, CallToControllerBeforeAttach) { // Expect calls to the controller prior to attaching to fail. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); Error Err = Error::success(); SPSWrapperFunction::call( CallViaSession(S, reinterpret_cast(add_sps_wrapper)), [&](Expected R) { ErrorAsOutParameter _(Err); Err = R.takeError(); }, 41, 1); EXPECT_EQ(toString(std::move(Err)), "no controller attached"); S.waitForShutdown(); } TEST(ControllerAccessTest, CallToControllerAfterDetach) { // Expect calls to the controller prior to attaching to fail. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); auto CA = std::make_shared(S); S.setController(CA); S.detachFromController(); Error Err = Error::success(); SPSWrapperFunction::call( CallViaSession(S, reinterpret_cast(add_sps_wrapper)), [&](Expected R) { ErrorAsOutParameter _(Err); Err = R.takeError(); }, 41, 1); EXPECT_EQ(toString(std::move(Err)), "no controller attached"); S.waitForShutdown(); } TEST(ControllerAccessTest, CallFromController) { // Simulate a call from the controller. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); auto CA = std::make_shared(S); S.setController(CA); int32_t Result = 0; SPSWrapperFunction::call( CallViaMockControllerAccess(*CA, add_sps_wrapper), [&](Expected R) { Result = cantFail(std::move(R)); }, 41, 1); EnqueueingDispatcher::runTasksFromFront(Tasks); EXPECT_EQ(Result, 42); S.waitForShutdown(); } TEST(ControllerAccessTest, RedundantAsyncShutdown) { // Check that redundant calls to shutdown have their callbacks run. std::deque> Tasks; Session S(std::make_unique(Tasks), noErrors); S.waitForShutdown(); bool RedundantCallbackRan = false; S.shutdown([&]() { RedundantCallbackRan = true; }); EXPECT_TRUE(RedundantCallbackRan); }