[lldb] Adjusting the base MCP protocol types per the spec. (#153297)
* This adjusts the `Request`/`Response` types to have an `id` that is either a string or a number. * Merges 'Error' into 'Response' to have a single response type that represents both errors and results. * Adjusts the `Error.data` field to by any JSON value. * Adds `operator==` support to the base protocol types and simplifies the tests.
This commit is contained in:
parent
c14ca4520f
commit
350f6abb83
@ -26,7 +26,7 @@ public:
|
||||
|
||||
const std::string &getMessage() const { return m_message; }
|
||||
|
||||
lldb_protocol::mcp::Error toProtcolError() const;
|
||||
lldb_protocol::mcp::Error toProtocolError() const;
|
||||
|
||||
static constexpr int64_t kResourceNotFound = -32002;
|
||||
static constexpr int64_t kInternalError = -32603;
|
||||
|
@ -23,50 +23,72 @@ namespace lldb_protocol::mcp {
|
||||
|
||||
static llvm::StringLiteral kProtocolVersion = "2024-11-05";
|
||||
|
||||
/// A Request or Response 'id'.
|
||||
///
|
||||
/// NOTE: This differs from the JSON-RPC 2.0 spec. The MCP spec says this must
|
||||
/// be a string or number, excluding a json 'null' as a valid id.
|
||||
using Id = std::variant<int64_t, std::string>;
|
||||
|
||||
/// A request that expects a response.
|
||||
struct Request {
|
||||
uint64_t id = 0;
|
||||
/// The request id.
|
||||
Id id = 0;
|
||||
/// The method to be invoked.
|
||||
std::string method;
|
||||
/// The method's params.
|
||||
std::optional<llvm::json::Value> params;
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const Request &);
|
||||
bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path);
|
||||
|
||||
struct ErrorInfo {
|
||||
int64_t code = 0;
|
||||
std::string message;
|
||||
std::string data;
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const ErrorInfo &);
|
||||
bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path);
|
||||
bool operator==(const Request &, const Request &);
|
||||
|
||||
struct Error {
|
||||
uint64_t id = 0;
|
||||
ErrorInfo error;
|
||||
/// The error type that occurred.
|
||||
int64_t code = 0;
|
||||
/// A short description of the error. The message SHOULD be limited to a
|
||||
/// concise single sentence.
|
||||
std::string message;
|
||||
/// Additional information about the error. The value of this member is
|
||||
/// defined by the sender (e.g. detailed error information, nested errors
|
||||
/// etc.).
|
||||
std::optional<llvm::json::Value> data;
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const Error &);
|
||||
bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path);
|
||||
bool operator==(const Error &, const Error &);
|
||||
|
||||
/// A response to a request, either an error or a result.
|
||||
struct Response {
|
||||
uint64_t id = 0;
|
||||
std::optional<llvm::json::Value> result;
|
||||
std::optional<ErrorInfo> error;
|
||||
/// The request id.
|
||||
Id id = 0;
|
||||
/// The result of the request, either an Error or the JSON value of the
|
||||
/// response.
|
||||
std::variant<Error, llvm::json::Value> result;
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const Response &);
|
||||
bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path);
|
||||
bool operator==(const Response &, const Response &);
|
||||
|
||||
/// A notification which does not expect a response.
|
||||
struct Notification {
|
||||
/// The method to be invoked.
|
||||
std::string method;
|
||||
/// The notification's params.
|
||||
std::optional<llvm::json::Value> params;
|
||||
};
|
||||
|
||||
llvm::json::Value toJSON(const Notification &);
|
||||
bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path);
|
||||
bool operator==(const Notification &, const Notification &);
|
||||
|
||||
/// A general message as defined by the JSON-RPC 2.0 spec.
|
||||
using Message = std::variant<Request, Response, Notification>;
|
||||
|
||||
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const Message &);
|
||||
|
||||
struct ToolCapability {
|
||||
/// Whether this server supports notifications for changes to the tool list.
|
||||
@ -176,11 +198,6 @@ struct ToolDefinition {
|
||||
llvm::json::Value toJSON(const ToolDefinition &);
|
||||
bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path);
|
||||
|
||||
using Message = std::variant<Request, Response, Notification, Error>;
|
||||
|
||||
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
|
||||
llvm::json::Value toJSON(const Message &);
|
||||
|
||||
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
|
||||
|
||||
} // namespace lldb_protocol::mcp
|
||||
|
@ -25,10 +25,10 @@ std::error_code MCPError::convertToErrorCode() const {
|
||||
return llvm::inconvertibleErrorCode();
|
||||
}
|
||||
|
||||
lldb_protocol::mcp::Error MCPError::toProtcolError() const {
|
||||
lldb_protocol::mcp::Error MCPError::toProtocolError() const {
|
||||
lldb_protocol::mcp::Error error;
|
||||
error.error.code = m_error_code;
|
||||
error.error.message = m_message;
|
||||
error.code = m_error_code;
|
||||
error.message = m_message;
|
||||
return error;
|
||||
}
|
||||
|
||||
|
@ -26,8 +26,45 @@ static bool mapRaw(const json::Value &Params, StringLiteral Prop,
|
||||
return true;
|
||||
}
|
||||
|
||||
static llvm::json::Value toJSON(const Id &Id) {
|
||||
if (const int64_t *I = std::get_if<int64_t>(&Id))
|
||||
return json::Value(*I);
|
||||
if (const std::string *S = std::get_if<std::string>(&Id))
|
||||
return json::Value(*S);
|
||||
llvm_unreachable("unexpected type in protocol::Id");
|
||||
}
|
||||
|
||||
static bool mapId(const llvm::json::Value &V, StringLiteral Prop, Id &Id,
|
||||
llvm::json::Path P) {
|
||||
const auto *O = V.getAsObject();
|
||||
if (!O) {
|
||||
P.report("expected object");
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto *E = O->get(Prop);
|
||||
if (!E) {
|
||||
P.field(Prop).report("not found");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (auto S = E->getAsString()) {
|
||||
Id = S->str();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (auto I = E->getAsInteger()) {
|
||||
Id = *I;
|
||||
return true;
|
||||
}
|
||||
|
||||
P.report("expected string or number");
|
||||
return false;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Request &R) {
|
||||
json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}};
|
||||
json::Object Result{
|
||||
{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}, {"method", R.method}};
|
||||
if (R.params)
|
||||
Result.insert({"params", R.params});
|
||||
return Result;
|
||||
@ -35,47 +72,75 @@ llvm::json::Value toJSON(const Request &R) {
|
||||
|
||||
bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) {
|
||||
llvm::json::ObjectMapper O(V, P);
|
||||
if (!O || !O.map("id", R.id) || !O.map("method", R.method))
|
||||
return false;
|
||||
return mapRaw(V, "params", R.params, P);
|
||||
return O && mapId(V, "id", R.id, P) && O.map("method", R.method) &&
|
||||
mapRaw(V, "params", R.params, P);
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const ErrorInfo &EI) {
|
||||
llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}};
|
||||
if (!EI.data.empty())
|
||||
Result.insert({"data", EI.data});
|
||||
return Result;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) {
|
||||
llvm::json::ObjectMapper O(V, P);
|
||||
return O && O.map("code", EI.code) && O.map("message", EI.message) &&
|
||||
O.mapOptional("data", EI.data);
|
||||
bool operator==(const Request &a, const Request &b) {
|
||||
return a.id == b.id && a.method == b.method && a.params == b.params;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Error &E) {
|
||||
return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}};
|
||||
llvm::json::Object Result{{"code", E.code}, {"message", E.message}};
|
||||
if (E.data)
|
||||
Result.insert({"data", *E.data});
|
||||
return Result;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) {
|
||||
llvm::json::ObjectMapper O(V, P);
|
||||
return O && O.map("id", E.id) && O.map("error", E.error);
|
||||
return O && O.map("code", E.code) && O.map("message", E.message) &&
|
||||
mapRaw(V, "data", E.data, P);
|
||||
}
|
||||
|
||||
bool operator==(const Error &a, const Error &b) {
|
||||
return a.code == b.code && a.message == b.message && a.data == b.data;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Response &R) {
|
||||
llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}};
|
||||
if (R.result)
|
||||
Result.insert({"result", R.result});
|
||||
if (R.error)
|
||||
Result.insert({"error", R.error});
|
||||
llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}};
|
||||
|
||||
if (const Error *error = std::get_if<Error>(&R.result))
|
||||
Result.insert({"error", *error});
|
||||
if (const json::Value *result = std::get_if<json::Value>(&R.result))
|
||||
Result.insert({"result", *result});
|
||||
return Result;
|
||||
}
|
||||
|
||||
bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) {
|
||||
llvm::json::ObjectMapper O(V, P);
|
||||
if (!O || !O.map("id", R.id) || !O.map("error", R.error))
|
||||
const json::Object *E = V.getAsObject();
|
||||
if (!E) {
|
||||
P.report("expected object");
|
||||
return false;
|
||||
return mapRaw(V, "result", R.result, P);
|
||||
}
|
||||
|
||||
const json::Value *result = E->get("result");
|
||||
const json::Value *raw_error = E->get("error");
|
||||
|
||||
if (result && raw_error) {
|
||||
P.report("'result' and 'error' fields are mutually exclusive");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!result && !raw_error) {
|
||||
P.report("'result' or 'error' fields are required'");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result) {
|
||||
R.result = std::move(*result);
|
||||
} else {
|
||||
Error error;
|
||||
if (!fromJSON(*raw_error, error, P))
|
||||
return false;
|
||||
R.result = std::move(error);
|
||||
}
|
||||
|
||||
return mapId(V, "id", R.id, P);
|
||||
}
|
||||
|
||||
bool operator==(const Response &a, const Response &b) {
|
||||
return a.id == b.id && a.result == b.result;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const Notification &N) {
|
||||
@ -97,6 +162,10 @@ bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool operator==(const Notification &a, const Notification &b) {
|
||||
return a.method == b.method && a.params == b.params;
|
||||
}
|
||||
|
||||
llvm::json::Value toJSON(const ToolCapability &TC) {
|
||||
return llvm::json::Object{{"listChanged", TC.listChanged}};
|
||||
}
|
||||
@ -235,24 +304,16 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (O->get("error")) {
|
||||
Error E;
|
||||
if (!fromJSON(V, E, P))
|
||||
return false;
|
||||
M = std::move(E);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (O->get("result")) {
|
||||
Response R;
|
||||
if (O->get("method")) {
|
||||
Request R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
M = std::move(R);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (O->get("method")) {
|
||||
Request R;
|
||||
if (O->get("result") || O->get("error")) {
|
||||
Response R;
|
||||
if (!fromJSON(V, R, P))
|
||||
return false;
|
||||
M = std::move(R);
|
||||
|
@ -66,13 +66,15 @@ Server::HandleData(llvm::StringRef data) {
|
||||
Error protocol_error;
|
||||
llvm::handleAllErrors(
|
||||
response.takeError(),
|
||||
[&](const MCPError &err) { protocol_error = err.toProtcolError(); },
|
||||
[&](const MCPError &err) { protocol_error = err.toProtocolError(); },
|
||||
[&](const llvm::ErrorInfoBase &err) {
|
||||
protocol_error.error.code = MCPError::kInternalError;
|
||||
protocol_error.error.message = err.message();
|
||||
protocol_error.code = MCPError::kInternalError;
|
||||
protocol_error.message = err.message();
|
||||
});
|
||||
protocol_error.id = request->id;
|
||||
return protocol_error;
|
||||
Response error_response;
|
||||
error_response.id = request->id;
|
||||
error_response.result = std::move(protocol_error);
|
||||
return error_response;
|
||||
}
|
||||
|
||||
return *response;
|
||||
@ -84,9 +86,6 @@ Server::HandleData(llvm::StringRef data) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (std::get_if<Error>(&(*message)))
|
||||
return llvm::createStringError("unexpected MCP message: error");
|
||||
|
||||
if (std::get_if<Response>(&(*message)))
|
||||
return llvm::createStringError("unexpected MCP message: response");
|
||||
|
||||
@ -123,11 +122,11 @@ void Server::AddNotificationHandler(llvm::StringRef method,
|
||||
|
||||
llvm::Expected<Response> Server::InitializeHandler(const Request &request) {
|
||||
Response response;
|
||||
response.result.emplace(llvm::json::Object{
|
||||
response.result = llvm::json::Object{
|
||||
{"protocolVersion", mcp::kProtocolVersion},
|
||||
{"capabilities", GetCapabilities()},
|
||||
{"serverInfo",
|
||||
llvm::json::Object{{"name", m_name}, {"version", m_version}}}});
|
||||
llvm::json::Object{{"name", m_name}, {"version", m_version}}}};
|
||||
return response;
|
||||
}
|
||||
|
||||
@ -138,7 +137,7 @@ llvm::Expected<Response> Server::ToolsListHandler(const Request &request) {
|
||||
for (const auto &tool : m_tools)
|
||||
tools.emplace_back(toJSON(tool.second->GetDefinition()));
|
||||
|
||||
response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}});
|
||||
response.result = llvm::json::Object{{"tools", std::move(tools)}};
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -173,7 +172,7 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) {
|
||||
if (!text_result)
|
||||
return text_result.takeError();
|
||||
|
||||
response.result.emplace(toJSON(*text_result));
|
||||
response.result = toJSON(*text_result);
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -189,8 +188,7 @@ llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) {
|
||||
for (const Resource &resource : resource_provider_up->GetResources())
|
||||
resources.push_back(resource);
|
||||
}
|
||||
response.result.emplace(
|
||||
llvm::json::Object{{"resources", std::move(resources)}});
|
||||
response.result = llvm::json::Object{{"resources", std::move(resources)}};
|
||||
|
||||
return response;
|
||||
}
|
||||
@ -226,7 +224,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) {
|
||||
return result.takeError();
|
||||
|
||||
Response response;
|
||||
response.result.emplace(std::move(*result));
|
||||
response.result = std::move(*result);
|
||||
return response;
|
||||
}
|
||||
|
||||
|
@ -149,9 +149,7 @@ TEST(ProtocolMCPTest, MessageWithRequest) {
|
||||
const Request &deserialized_request =
|
||||
std::get<Request>(*deserialized_message);
|
||||
|
||||
EXPECT_EQ(request.id, deserialized_request.id);
|
||||
EXPECT_EQ(request.method, deserialized_request.method);
|
||||
EXPECT_EQ(request.params, deserialized_request.params);
|
||||
EXPECT_EQ(request, deserialized_request);
|
||||
}
|
||||
|
||||
TEST(ProtocolMCPTest, MessageWithResponse) {
|
||||
@ -168,8 +166,7 @@ TEST(ProtocolMCPTest, MessageWithResponse) {
|
||||
const Response &deserialized_response =
|
||||
std::get<Response>(*deserialized_message);
|
||||
|
||||
EXPECT_EQ(response.id, deserialized_response.id);
|
||||
EXPECT_EQ(response.result, deserialized_response.result);
|
||||
EXPECT_EQ(response, deserialized_response);
|
||||
}
|
||||
|
||||
TEST(ProtocolMCPTest, MessageWithNotification) {
|
||||
@ -186,49 +183,28 @@ TEST(ProtocolMCPTest, MessageWithNotification) {
|
||||
const Notification &deserialized_notification =
|
||||
std::get<Notification>(*deserialized_message);
|
||||
|
||||
EXPECT_EQ(notification.method, deserialized_notification.method);
|
||||
EXPECT_EQ(notification.params, deserialized_notification.params);
|
||||
EXPECT_EQ(notification, deserialized_notification);
|
||||
}
|
||||
|
||||
TEST(ProtocolMCPTest, MessageWithError) {
|
||||
ErrorInfo error_info;
|
||||
error_info.code = -32603;
|
||||
error_info.message = "Internal error";
|
||||
|
||||
TEST(ProtocolMCPTest, MessageWithErrorResponse) {
|
||||
Error error;
|
||||
error.id = 3;
|
||||
error.error = error_info;
|
||||
error.code = -32603;
|
||||
error.message = "Internal error";
|
||||
|
||||
Message message = error;
|
||||
Response error_response;
|
||||
error_response.id = 3;
|
||||
error_response.result = error;
|
||||
|
||||
Message message = error_response;
|
||||
|
||||
llvm::Expected<Message> deserialized_message = roundtripJSON(message);
|
||||
ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded());
|
||||
|
||||
ASSERT_TRUE(std::holds_alternative<Error>(*deserialized_message));
|
||||
const Error &deserialized_error = std::get<Error>(*deserialized_message);
|
||||
ASSERT_TRUE(std::holds_alternative<Response>(*deserialized_message));
|
||||
const Response &deserialized_error =
|
||||
std::get<Response>(*deserialized_message);
|
||||
|
||||
EXPECT_EQ(error.id, deserialized_error.id);
|
||||
EXPECT_EQ(error.error.code, deserialized_error.error.code);
|
||||
EXPECT_EQ(error.error.message, deserialized_error.error.message);
|
||||
}
|
||||
|
||||
TEST(ProtocolMCPTest, ResponseWithError) {
|
||||
ErrorInfo error_info;
|
||||
error_info.code = -32700;
|
||||
error_info.message = "Parse error";
|
||||
|
||||
Response response;
|
||||
response.id = 4;
|
||||
response.error = error_info;
|
||||
|
||||
llvm::Expected<Response> deserialized_response = roundtripJSON(response);
|
||||
ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded());
|
||||
|
||||
EXPECT_EQ(response.id, deserialized_response->id);
|
||||
EXPECT_FALSE(deserialized_response->result.has_value());
|
||||
ASSERT_TRUE(deserialized_response->error.has_value());
|
||||
EXPECT_EQ(response.error->code, deserialized_response->error->code);
|
||||
EXPECT_EQ(response.error->message, deserialized_response->error->message);
|
||||
EXPECT_EQ(error_response, deserialized_error);
|
||||
}
|
||||
|
||||
TEST(ProtocolMCPTest, Resource) {
|
||||
|
@ -200,7 +200,7 @@ public:
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(ProtocolServerMCPTest, Intialization) {
|
||||
TEST_F(ProtocolServerMCPTest, Initialization) {
|
||||
llvm::StringLiteral request =
|
||||
R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json";
|
||||
llvm::StringLiteral response =
|
||||
|
@ -11,11 +11,11 @@
|
||||
|
||||
#include "lldb/Core/ModuleSpec.h"
|
||||
#include "lldb/Utility/DataBuffer.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/Twine.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Support/FileSystem.h"
|
||||
#include "llvm/Support/FileUtilities.h"
|
||||
#include "llvm/Support/JSON.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <string>
|
||||
|
||||
#define ASSERT_NO_ERROR(x) \
|
||||
@ -61,12 +61,10 @@ private:
|
||||
};
|
||||
|
||||
template <typename T> static llvm::Expected<T> roundtripJSON(const T &input) {
|
||||
llvm::json::Value value = toJSON(input);
|
||||
llvm::json::Path::Root root;
|
||||
T output;
|
||||
if (!fromJSON(value, output, root))
|
||||
return root.getError();
|
||||
return output;
|
||||
std::string encoded;
|
||||
llvm::raw_string_ostream OS(encoded);
|
||||
OS << toJSON(input);
|
||||
return llvm::json::parse<T>(encoded);
|
||||
}
|
||||
} // namespace lldb_private
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user