diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index bd41d077c903..6b114ee497a8 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -166,8 +166,7 @@ public: /// /// If an unexpected error occurs, the MainLoop will be terminated and a log /// message will include additional information about the termination reason. - virtual llvm::Expected - RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; + virtual llvm::Error RegisterMessageHandler(MessageHandler &handler) = 0; protected: template inline auto Logv(const char *Fmt, Ts &&...Vals) { @@ -182,29 +181,27 @@ public: using Message = typename JSONTransport::Message; using MessageHandler = typename JSONTransport::MessageHandler; - IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) - : m_in(in), m_out(out) {} + IOTransport(MainLoop &loop, lldb::IOObjectSP in, lldb::IOObjectSP out) + : m_loop(loop), m_in(in), m_out(out) {} llvm::Error Send(const typename Proto::Evt &evt) override { return Write(evt); } + llvm::Error Send(const typename Proto::Req &req) override { return Write(req); } + llvm::Error Send(const typename Proto::Resp &resp) override { return Write(resp); } - llvm::Expected - RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { + llvm::Error RegisterMessageHandler(MessageHandler &handler) override { Status status; - MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( + m_read_handle = m_loop.RegisterReadObject( m_in, [this, &handler](MainLoopBase &base) { OnRead(base, handler); }, status); - if (status.Fail()) { - return status.takeError(); - } - return read_handle; + return status.takeError(); } /// Public for testing purposes, otherwise this should be an implementation @@ -263,11 +260,15 @@ private: handler.OnError(llvm::make_error( std::string(m_buffer.str()))); handler.OnClosed(); + // On EOF, remove the read handle from the MainLoop. + m_read_handle.reset(); } } + MainLoop &m_loop; lldb::IOObjectSP m_in; lldb::IOObjectSP m_out; + MainLoop::ReadHandleUP m_read_handle; }; /// A transport class for JSON with a HTTP header. diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index f185d51f4119..498c54bed780 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -21,6 +21,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/Signals.h" +#include #include #include #include @@ -40,7 +41,7 @@ public: void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); - llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP); + llvm::Error Accept(MCPTransportUP); protected: MCPBinderUP Bind(MCPTransport &); @@ -70,7 +71,6 @@ private: LogCallback m_log_callback; struct Client { - ReadHandleUP handle; MCPTransportUP transport; MCPBinderUP binder; }; diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index b7a1eb778d66..ceadf1dbd82b 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -83,8 +83,8 @@ using LogCallback = llvm::unique_function; class Transport final : public lldb_private::transport::JSONRPCTransport { public: - Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, - LogCallback log_callback = {}); + Transport(lldb_private::MainLoop &loop, lldb::IOObjectSP in, + lldb::IOObjectSP out, LogCallback log_callback = {}); virtual ~Transport() = default; /// Transport is not copyable. diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index 77a3ba6574cd..c92f80bc166b 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -66,11 +66,11 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { lldb::IOObjectSP io_sp = std::move(socket); auto transport_up = std::make_unique( - io_sp, io_sp, [client_name](llvm::StringRef message) { + m_loop, io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); - if (auto error = m_server->Accept(m_loop, std::move(transport_up))) + if (auto error = m_server->Accept(std::move(transport_up))) LLDB_LOG_ERROR(log, std::move(error), "{0}:"); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index e0f2a6ccea1f..abcd25133705 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -13,11 +13,12 @@ #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" #include "lldb/Protocol/MCP/Server.h" -#include "lldb/Protocol/MCP/Transport.h" -#include +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include #include +#include #include -#include #include namespace lldb_private::mcp { diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a8871ff5c39f..ecbea4b9022c 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -144,7 +144,7 @@ MCPBinderUP Server::Bind(MCPTransport &transport) { return binder_up; } -llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { +llvm::Error Server::Accept(MCPTransportUP transport) { MCPBinderUP binder = Bind(*transport); MCPTransport *transport_ptr = transport.get(); binder->OnDisconnect([this, transport_ptr]() { @@ -156,12 +156,10 @@ llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { Logv("Transport error: {0}", llvm::toString(std::move(err))); }); - auto handle = transport->RegisterMessageHandler(loop, *binder); - if (!handle) - return handle.takeError(); + if (llvm::Error err = transport->RegisterMessageHandler(*binder)) + return err; - m_instances[transport_ptr] = - Client{std::move(*handle), std::move(transport), std::move(binder)}; + m_instances[transport_ptr] = Client{std::move(transport), std::move(binder)}; return llvm::Error::success(); } diff --git a/lldb/source/Protocol/MCP/Transport.cpp b/lldb/source/Protocol/MCP/Transport.cpp index cccdc3b5bd65..1dc01a9f5900 100644 --- a/lldb/source/Protocol/MCP/Transport.cpp +++ b/lldb/source/Protocol/MCP/Transport.cpp @@ -13,9 +13,10 @@ using namespace lldb_protocol::mcp; using namespace llvm; -Transport::Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, - LogCallback log_callback) - : JSONRPCTransport(in, out), m_log_callback(std::move(log_callback)) {} +Transport::Transport(lldb_private::MainLoop &loop, lldb::IOObjectSP in, + lldb::IOObjectSP out, LogCallback log_callback) + : JSONRPCTransport(loop, in, out), m_log_callback(std::move(log_callback)) { +} void Transport::Log(StringRef message) { if (m_log_callback) diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index 2eee318237d4..b76b05c5d145 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -1051,9 +1051,8 @@ void DAP::TransportHandler() { m_queue_cv.notify_all(); }); - auto handle = transport.RegisterMessageHandler(m_loop, *this); - if (!handle) { - DAP_LOG_ERROR(log, handle.takeError(), + if (llvm::Error err = transport.RegisterMessageHandler(*this)) { + DAP_LOG_ERROR(log, std::move(err), "registering message handler failed: {0}"); std::lock_guard guard(m_queue_mutex); m_error_occurred = true; diff --git a/lldb/tools/lldb-dap/Transport.cpp b/lldb/tools/lldb-dap/Transport.cpp index b3512385d657..b149a8ee8f02 100644 --- a/lldb/tools/lldb-dap/Transport.cpp +++ b/lldb/tools/lldb-dap/Transport.cpp @@ -17,9 +17,9 @@ using namespace lldb_private; namespace lldb_dap { -Transport::Transport(lldb_dap::Log &log, lldb::IOObjectSP input, - lldb::IOObjectSP output) - : HTTPDelimitedJSONTransport(input, output), m_log(log) {} +Transport::Transport(lldb_dap::Log &log, lldb_private::MainLoop &loop, + lldb::IOObjectSP input, lldb::IOObjectSP output) + : HTTPDelimitedJSONTransport(loop, input, output), m_log(log) {} void Transport::Log(llvm::StringRef message) { // Emit the message directly, since this log was forwarded. diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index b20a93475d2d..42f7caf93831 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -35,8 +35,8 @@ class Transport final : public lldb_private::transport::HTTPDelimitedJSONTransport< ProtocolDescriptor> { public: - Transport(lldb_dap::Log &log, lldb::IOObjectSP input, - lldb::IOObjectSP output); + Transport(lldb_dap::Log &log, lldb_private::MainLoop &loop, + lldb::IOObjectSP input, lldb::IOObjectSP output); virtual ~Transport() = default; void Log(llvm::StringRef message) override; diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index 15c63543e86f..babc3c98646c 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -47,7 +47,6 @@ #include "llvm/Support/Threading.h" #include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" -#include #include #include #include @@ -463,7 +462,7 @@ static llvm::Error serveConnection( DAP_LOG(client_log, "client connected"); MainLoop loop; - Transport transport(client_log, io, io); + Transport transport(client_log, loop, io, io); DAP dap(client_log, default_repl_mode, pre_init_commands, no_lldbinit, client_name, transport, loop); @@ -738,7 +737,7 @@ int main(int argc, char *argv[]) { constexpr llvm::StringLiteral client_name = "stdio"; MainLoop loop; Log client_log = log.WithPrefix("(stdio)"); - Transport transport(client_log, input, output); + Transport transport(client_log, loop, input, output); DAP dap(client_log, default_repl_mode, pre_init_commands, no_lldbinit, client_name, transport, loop); diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index a9231085637c..1afac18833a0 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -35,7 +35,7 @@ using lldb_private::MainLoop; using lldb_private::Pipe; void TransportBase::SetUp() { - std::tie(to_client, to_server) = TestDAPTransport::createPair(); + std::tie(to_client, to_server) = TestDAPTransport::createPair(loop); log = std::make_unique(llvm::outs(), log_mutex); dap = std::make_unique( @@ -46,13 +46,8 @@ void TransportBase::SetUp() { /*client_name=*/"test_client", /*transport=*/*to_client, /*loop=*/loop); - auto server_handle = to_server->RegisterMessageHandler(loop, *dap); - EXPECT_THAT_EXPECTED(server_handle, Succeeded()); - handles[0] = std::move(*server_handle); - - auto client_handle = to_client->RegisterMessageHandler(loop, client); - EXPECT_THAT_EXPECTED(client_handle, Succeeded()); - handles[1] = std::move(*client_handle); + EXPECT_THAT_ERROR(to_server->RegisterMessageHandler(*dap), Succeeded()); + EXPECT_THAT_ERROR(to_client->RegisterMessageHandler(client), Succeeded()); } void TransportBase::Run() { diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index f1c7e6b98972..c35482937743 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -59,7 +59,6 @@ protected: lldb_private::SubsystemRAII subsystems; lldb_private::MainLoop loop; - lldb_private::MainLoop::ReadHandleUP handles[2]; std::unique_ptr log; lldb_dap::Log::Mutex log_mutex; diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 710907af3794..2c26f9421377 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -247,19 +247,22 @@ template class JSONTransportTest : public PipePairTest { protected: SubsystemRAII subsystems; + MainLoop loop; test_protocol::MessageHandler message_handler; std::unique_ptr transport; - MainLoop loop; void SetUp() override { PipePairTest::SetUp(); transport = std::make_unique( - std::make_shared(input.GetReadFileDescriptor(), + loop, + std::make_shared(input.ReleaseReadFileDescriptor(), File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), + NativeFile::Owned), + std::make_shared(output.ReleaseWriteFileDescriptor(), File::eOpenOptionWriteOnly, - NativeFile::Unowned)); + NativeFile::Owned)); + EXPECT_THAT_ERROR(transport->RegisterMessageHandler(message_handler), + Succeeded()); } /// Run the transport MainLoop and return any messages received. @@ -272,17 +275,13 @@ protected: loop.RequestTermination(); }); } - bool addition_succeeded = loop.AddCallback( + bool registered_timeout = loop.AddCallback( [](MainLoopBase &loop) { loop.RequestTermination(); FAIL() << "timeout"; }, timeout); - EXPECT_TRUE(addition_succeeded); - auto handle = transport->RegisterMessageHandler(loop, message_handler); - if (!handle) - return handle.takeError(); - + EXPECT_TRUE(registered_timeout); return loop.Run().takeError(); } @@ -360,14 +359,13 @@ protected: MainLoop loop; void SetUp() override { - std::tie(to_remote, from_remote) = test_protocol::Transport::createPair(); + std::tie(to_remote, from_remote) = + test_protocol::Transport::createPair(loop); binder = std::make_unique(*to_remote); - auto binder_handle = to_remote->RegisterMessageHandler(loop, remote); - EXPECT_THAT_EXPECTED(binder_handle, Succeeded()); - - auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder); - EXPECT_THAT_EXPECTED(remote_handle, Succeeded()); + EXPECT_THAT_ERROR(to_remote->RegisterMessageHandler(remote), Succeeded()); + EXPECT_THAT_ERROR(from_remote->RegisterMessageHandler(*binder), + Succeeded()); } void Run() { @@ -502,8 +500,8 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { transport = - std::make_unique(nullptr, nullptr); - ASSERT_THAT_ERROR(Run(/*close_input=*/false), + std::make_unique(loop, nullptr, nullptr); + ASSERT_THAT_ERROR(transport->RegisterMessageHandler(message_handler), FailedWithMessage("IO object is not valid.")); } @@ -624,8 +622,8 @@ TEST_F(JSONRPCTransportTest, Write) { } TEST_F(JSONRPCTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - ASSERT_THAT_ERROR(Run(/*close_input=*/false), + transport = std::make_unique(loop, nullptr, nullptr); + ASSERT_THAT_ERROR(transport->RegisterMessageHandler(message_handler), FailedWithMessage("IO object is not valid.")); } diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 97f32e2fbb1b..9a5b75edeeb9 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -157,22 +157,18 @@ public: } void SetUp() override { - std::tie(to_client, to_server) = Transport::createPair(); + std::tie(to_client, to_server) = Transport::createPair(loop); server_up = std::make_unique( "lldb-mcp", "0.1.0", [this](StringRef msg) { logged_messages.push_back(msg.str()); }); binder = server_up->Bind(*to_client); - auto server_handle = to_server->RegisterMessageHandler(loop, *binder); - EXPECT_THAT_EXPECTED(server_handle, Succeeded()); binder->OnError([](llvm::Error error) { llvm::errs() << formatv("Server transport error: {0}", error); }); - handles[0] = std::move(*server_handle); - auto client_handle = to_client->RegisterMessageHandler(loop, client); - EXPECT_THAT_EXPECTED(client_handle, Succeeded()); - handles[1] = std::move(*client_handle); + EXPECT_THAT_ERROR(to_server->RegisterMessageHandler(*binder), Succeeded()); + EXPECT_THAT_ERROR(to_client->RegisterMessageHandler(client), Succeeded()); } template diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index bacf8ca36aa0..4623c365c960 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -30,72 +30,51 @@ public: static std::pair>, std::unique_ptr>> - createPair() { + createPair(lldb_private::MainLoop &loop) { std::unique_ptr> transports[2] = { - std::make_unique>(), - std::make_unique>()}; + std::make_unique>(loop), + std::make_unique>(loop)}; return std::make_pair(std::move(transports[0]), std::move(transports[1])); } - explicit TestTransport() { - llvm::Expected dummy_file = - lldb_private::FileSystem::Instance().Open( - lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL), - lldb_private::File::eOpenOptionReadWrite); - EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded()); - m_dummy_file = std::move(*dummy_file); - } + explicit TestTransport(lldb_private::MainLoop &loop) : m_loop(loop) {} llvm::Error Send(const typename Proto::Evt &evt) override { - EXPECT_TRUE(m_loop && m_handler) - << "Send called before RegisterMessageHandler"; - m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { + EXPECT_TRUE(m_handler) << "Send called before RegisterMessageHandler"; + m_loop.AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { m_handler->Received(evt); }); return llvm::Error::success(); } llvm::Error Send(const typename Proto::Req &req) override { - EXPECT_TRUE(m_loop && m_handler) - << "Send called before RegisterMessageHandler"; - m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { + EXPECT_TRUE(m_handler) << "Send called before RegisterMessageHandler"; + m_loop.AddPendingCallback([this, req](lldb_private::MainLoopBase &) { m_handler->Received(req); }); return llvm::Error::success(); } llvm::Error Send(const typename Proto::Resp &resp) override { - EXPECT_TRUE(m_loop && m_handler) - << "Send called before RegisterMessageHandler"; - m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { + EXPECT_TRUE(m_handler) << "Send called before RegisterMessageHandler"; + m_loop.AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { m_handler->Received(resp); }); return llvm::Error::success(); } - llvm::Expected - RegisterMessageHandler(lldb_private::MainLoop &loop, - MessageHandler &handler) override { - if (!m_loop) - m_loop = &loop; + llvm::Error RegisterMessageHandler(MessageHandler &handler) override { if (!m_handler) m_handler = &handler; - lldb_private::Status status; - auto handle = loop.RegisterReadObject( - m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); - if (status.Fail()) - return status.takeError(); - return handle; + return llvm::Error::success(); } protected: void Log(llvm::StringRef message) override {}; private: - lldb_private::MainLoop *m_loop = nullptr; + lldb_private::MainLoop &m_loop; MessageHandler *m_handler = nullptr; - // Dummy file for registering with the MainLoop. - lldb::FileSP m_dummy_file = nullptr; }; template