[lldb] Adjusting the base MCP protocol types per the spec. (#153297)

* This adjusts the `Request`/`Response` types to have an `id` that is
either a string or a number.
* Merges 'Error' into 'Response' to have a single response type that
represents both errors and results.
* Adjusts the `Error.data` field to by any JSON value.
* Adds `operator==` support to the base protocol types and simplifies
the tests.
This commit is contained in:
John Harrison 2025-08-12 17:56:52 -07:00 committed by GitHub
parent c14ca4520f
commit 350f6abb83
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 174 additions and 124 deletions

View File

@ -26,7 +26,7 @@ public:
const std::string &getMessage() const { return m_message; }
lldb_protocol::mcp::Error toProtcolError() const;
lldb_protocol::mcp::Error toProtocolError() const;
static constexpr int64_t kResourceNotFound = -32002;
static constexpr int64_t kInternalError = -32603;

View File

@ -23,50 +23,72 @@ namespace lldb_protocol::mcp {
static llvm::StringLiteral kProtocolVersion = "2024-11-05";
/// A Request or Response 'id'.
///
/// NOTE: This differs from the JSON-RPC 2.0 spec. The MCP spec says this must
/// be a string or number, excluding a json 'null' as a valid id.
using Id = std::variant<int64_t, std::string>;
/// A request that expects a response.
struct Request {
uint64_t id = 0;
/// The request id.
Id id = 0;
/// The method to be invoked.
std::string method;
/// The method's params.
std::optional<llvm::json::Value> params;
};
llvm::json::Value toJSON(const Request &);
bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path);
struct ErrorInfo {
int64_t code = 0;
std::string message;
std::string data;
};
llvm::json::Value toJSON(const ErrorInfo &);
bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path);
bool operator==(const Request &, const Request &);
struct Error {
uint64_t id = 0;
ErrorInfo error;
/// The error type that occurred.
int64_t code = 0;
/// A short description of the error. The message SHOULD be limited to a
/// concise single sentence.
std::string message;
/// Additional information about the error. The value of this member is
/// defined by the sender (e.g. detailed error information, nested errors
/// etc.).
std::optional<llvm::json::Value> data;
};
llvm::json::Value toJSON(const Error &);
bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path);
bool operator==(const Error &, const Error &);
/// A response to a request, either an error or a result.
struct Response {
uint64_t id = 0;
std::optional<llvm::json::Value> result;
std::optional<ErrorInfo> error;
/// The request id.
Id id = 0;
/// The result of the request, either an Error or the JSON value of the
/// response.
std::variant<Error, llvm::json::Value> result;
};
llvm::json::Value toJSON(const Response &);
bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path);
bool operator==(const Response &, const Response &);
/// A notification which does not expect a response.
struct Notification {
/// The method to be invoked.
std::string method;
/// The notification's params.
std::optional<llvm::json::Value> params;
};
llvm::json::Value toJSON(const Notification &);
bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path);
bool operator==(const Notification &, const Notification &);
/// A general message as defined by the JSON-RPC 2.0 spec.
using Message = std::variant<Request, Response, Notification>;
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
llvm::json::Value toJSON(const Message &);
struct ToolCapability {
/// Whether this server supports notifications for changes to the tool list.
@ -176,11 +198,6 @@ struct ToolDefinition {
llvm::json::Value toJSON(const ToolDefinition &);
bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path);
using Message = std::variant<Request, Response, Notification, Error>;
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
llvm::json::Value toJSON(const Message &);
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
} // namespace lldb_protocol::mcp

View File

@ -25,10 +25,10 @@ std::error_code MCPError::convertToErrorCode() const {
return llvm::inconvertibleErrorCode();
}
lldb_protocol::mcp::Error MCPError::toProtcolError() const {
lldb_protocol::mcp::Error MCPError::toProtocolError() const {
lldb_protocol::mcp::Error error;
error.error.code = m_error_code;
error.error.message = m_message;
error.code = m_error_code;
error.message = m_message;
return error;
}

View File

@ -26,8 +26,45 @@ static bool mapRaw(const json::Value &Params, StringLiteral Prop,
return true;
}
static llvm::json::Value toJSON(const Id &Id) {
if (const int64_t *I = std::get_if<int64_t>(&Id))
return json::Value(*I);
if (const std::string *S = std::get_if<std::string>(&Id))
return json::Value(*S);
llvm_unreachable("unexpected type in protocol::Id");
}
static bool mapId(const llvm::json::Value &V, StringLiteral Prop, Id &Id,
llvm::json::Path P) {
const auto *O = V.getAsObject();
if (!O) {
P.report("expected object");
return false;
}
const auto *E = O->get(Prop);
if (!E) {
P.field(Prop).report("not found");
return false;
}
if (auto S = E->getAsString()) {
Id = S->str();
return true;
}
if (auto I = E->getAsInteger()) {
Id = *I;
return true;
}
P.report("expected string or number");
return false;
}
llvm::json::Value toJSON(const Request &R) {
json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}};
json::Object Result{
{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}, {"method", R.method}};
if (R.params)
Result.insert({"params", R.params});
return Result;
@ -35,47 +72,75 @@ llvm::json::Value toJSON(const Request &R) {
bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
if (!O || !O.map("id", R.id) || !O.map("method", R.method))
return false;
return mapRaw(V, "params", R.params, P);
return O && mapId(V, "id", R.id, P) && O.map("method", R.method) &&
mapRaw(V, "params", R.params, P);
}
llvm::json::Value toJSON(const ErrorInfo &EI) {
llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}};
if (!EI.data.empty())
Result.insert({"data", EI.data});
return Result;
}
bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
return O && O.map("code", EI.code) && O.map("message", EI.message) &&
O.mapOptional("data", EI.data);
bool operator==(const Request &a, const Request &b) {
return a.id == b.id && a.method == b.method && a.params == b.params;
}
llvm::json::Value toJSON(const Error &E) {
return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}};
llvm::json::Object Result{{"code", E.code}, {"message", E.message}};
if (E.data)
Result.insert({"data", *E.data});
return Result;
}
bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
return O && O.map("id", E.id) && O.map("error", E.error);
return O && O.map("code", E.code) && O.map("message", E.message) &&
mapRaw(V, "data", E.data, P);
}
bool operator==(const Error &a, const Error &b) {
return a.code == b.code && a.message == b.message && a.data == b.data;
}
llvm::json::Value toJSON(const Response &R) {
llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}};
if (R.result)
Result.insert({"result", R.result});
if (R.error)
Result.insert({"error", R.error});
llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}};
if (const Error *error = std::get_if<Error>(&R.result))
Result.insert({"error", *error});
if (const json::Value *result = std::get_if<json::Value>(&R.result))
Result.insert({"result", *result});
return Result;
}
bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) {
llvm::json::ObjectMapper O(V, P);
if (!O || !O.map("id", R.id) || !O.map("error", R.error))
const json::Object *E = V.getAsObject();
if (!E) {
P.report("expected object");
return false;
return mapRaw(V, "result", R.result, P);
}
const json::Value *result = E->get("result");
const json::Value *raw_error = E->get("error");
if (result && raw_error) {
P.report("'result' and 'error' fields are mutually exclusive");
return false;
}
if (!result && !raw_error) {
P.report("'result' or 'error' fields are required'");
return false;
}
if (result) {
R.result = std::move(*result);
} else {
Error error;
if (!fromJSON(*raw_error, error, P))
return false;
R.result = std::move(error);
}
return mapId(V, "id", R.id, P);
}
bool operator==(const Response &a, const Response &b) {
return a.id == b.id && a.result == b.result;
}
llvm::json::Value toJSON(const Notification &N) {
@ -97,6 +162,10 @@ bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) {
return true;
}
bool operator==(const Notification &a, const Notification &b) {
return a.method == b.method && a.params == b.params;
}
llvm::json::Value toJSON(const ToolCapability &TC) {
return llvm::json::Object{{"listChanged", TC.listChanged}};
}
@ -235,24 +304,16 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) {
return true;
}
if (O->get("error")) {
Error E;
if (!fromJSON(V, E, P))
return false;
M = std::move(E);
return true;
}
if (O->get("result")) {
Response R;
if (O->get("method")) {
Request R;
if (!fromJSON(V, R, P))
return false;
M = std::move(R);
return true;
}
if (O->get("method")) {
Request R;
if (O->get("result") || O->get("error")) {
Response R;
if (!fromJSON(V, R, P))
return false;
M = std::move(R);

View File

@ -66,13 +66,15 @@ Server::HandleData(llvm::StringRef data) {
Error protocol_error;
llvm::handleAllErrors(
response.takeError(),
[&](const MCPError &err) { protocol_error = err.toProtcolError(); },
[&](const MCPError &err) { protocol_error = err.toProtocolError(); },
[&](const llvm::ErrorInfoBase &err) {
protocol_error.error.code = MCPError::kInternalError;
protocol_error.error.message = err.message();
protocol_error.code = MCPError::kInternalError;
protocol_error.message = err.message();
});
protocol_error.id = request->id;
return protocol_error;
Response error_response;
error_response.id = request->id;
error_response.result = std::move(protocol_error);
return error_response;
}
return *response;
@ -84,9 +86,6 @@ Server::HandleData(llvm::StringRef data) {
return std::nullopt;
}
if (std::get_if<Error>(&(*message)))
return llvm::createStringError("unexpected MCP message: error");
if (std::get_if<Response>(&(*message)))
return llvm::createStringError("unexpected MCP message: response");
@ -123,11 +122,11 @@ void Server::AddNotificationHandler(llvm::StringRef method,
llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
Response response;
response.result.emplace(llvm::json::Object{
response.result = llvm::json::Object{
{"protocolVersion", mcp::kProtocolVersion},
{"capabilities", GetCapabilities()},
{"serverInfo",
llvm::json::Object{{"name", m_name}, {"version", m_version}}}});
llvm::json::Object{{"name", m_name}, {"version", m_version}}}};
return response;
}
@ -138,7 +137,7 @@ llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
for (const auto &tool : m_tools)
tools.emplace_back(toJSON(tool.second->GetDefinition()));
response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}});
response.result = llvm::json::Object{{"tools", std::move(tools)}};
return response;
}
@ -173,7 +172,7 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
if (!text_result)
return text_result.takeError();
response.result.emplace(toJSON(*text_result));
response.result = toJSON(*text_result);
return response;
}
@ -189,8 +188,7 @@ llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
for (const Resource &resource : resource_provider_up->GetResources())
resources.push_back(resource);
}
response.result.emplace(
llvm::json::Object{{"resources", std::move(resources)}});
response.result = llvm::json::Object{{"resources", std::move(resources)}};
return response;
}
@ -226,7 +224,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
return result.takeError();
Response response;
response.result.emplace(std::move(*result));
response.result = std::move(*result);
return response;
}

View File

@ -149,9 +149,7 @@ TEST(ProtocolMCPTest, MessageWithRequest) {
const Request &deserialized_request =
std::get<Request>(*deserialized_message);
EXPECT_EQ(request.id, deserialized_request.id);
EXPECT_EQ(request.method, deserialized_request.method);
EXPECT_EQ(request.params, deserialized_request.params);
EXPECT_EQ(request, deserialized_request);
}
TEST(ProtocolMCPTest, MessageWithResponse) {
@ -168,8 +166,7 @@ TEST(ProtocolMCPTest, MessageWithResponse) {
const Response &deserialized_response =
std::get<Response>(*deserialized_message);
EXPECT_EQ(response.id, deserialized_response.id);
EXPECT_EQ(response.result, deserialized_response.result);
EXPECT_EQ(response, deserialized_response);
}
TEST(ProtocolMCPTest, MessageWithNotification) {
@ -186,49 +183,28 @@ TEST(ProtocolMCPTest, MessageWithNotification) {
const Notification &deserialized_notification =
std::get<Notification>(*deserialized_message);
EXPECT_EQ(notification.method, deserialized_notification.method);
EXPECT_EQ(notification.params, deserialized_notification.params);
EXPECT_EQ(notification, deserialized_notification);
}
TEST(ProtocolMCPTest, MessageWithError) {
ErrorInfo error_info;
error_info.code = -32603;
error_info.message = "Internal error";
TEST(ProtocolMCPTest, MessageWithErrorResponse) {
Error error;
error.id = 3;
error.error = error_info;
error.code = -32603;
error.message = "Internal error";
Message message = error;
Response error_response;
error_response.id = 3;
error_response.result = error;
Message message = error_response;
llvm::Expected<Message> deserialized_message = roundtripJSON(message);
ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded());
ASSERT_TRUE(std::holds_alternative<Error>(*deserialized_message));
const Error &deserialized_error = std::get<Error>(*deserialized_message);
ASSERT_TRUE(std::holds_alternative<Response>(*deserialized_message));
const Response &deserialized_error =
std::get<Response>(*deserialized_message);
EXPECT_EQ(error.id, deserialized_error.id);
EXPECT_EQ(error.error.code, deserialized_error.error.code);
EXPECT_EQ(error.error.message, deserialized_error.error.message);
}
TEST(ProtocolMCPTest, ResponseWithError) {
ErrorInfo error_info;
error_info.code = -32700;
error_info.message = "Parse error";
Response response;
response.id = 4;
response.error = error_info;
llvm::Expected<Response> deserialized_response = roundtripJSON(response);
ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded());
EXPECT_EQ(response.id, deserialized_response->id);
EXPECT_FALSE(deserialized_response->result.has_value());
ASSERT_TRUE(deserialized_response->error.has_value());
EXPECT_EQ(response.error->code, deserialized_response->error->code);
EXPECT_EQ(response.error->message, deserialized_response->error->message);
EXPECT_EQ(error_response, deserialized_error);
}
TEST(ProtocolMCPTest, Resource) {

View File

@ -200,7 +200,7 @@ public:
} // namespace
TEST_F(ProtocolServerMCPTest, Intialization) {
TEST_F(ProtocolServerMCPTest, Initialization) {
llvm::StringLiteral request =
R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json";
llvm::StringLiteral response =

View File

@ -11,11 +11,11 @@
#include "lldb/Core/ModuleSpec.h"
#include "lldb/Utility/DataBuffer.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <string>
#define ASSERT_NO_ERROR(x) \
@ -61,12 +61,10 @@ private:
};
template <typename T> static llvm::Expected<T> roundtripJSON(const T &input) {
llvm::json::Value value = toJSON(input);
llvm::json::Path::Root root;
T output;
if (!fromJSON(value, output, root))
return root.getError();
return output;
std::string encoded;
llvm::raw_string_ostream OS(encoded);
OS << toJSON(input);
return llvm::json::parse<T>(encoded);
}
} // namespace lldb_private