[LLVM][MLIR] Move LSP server support library from MLIR into LLVM (#157885)

This is a second PR on this patch (first #155572), that fixes the
linking problem for `flang-aarch64-dylib` test.

The SupportLSP library was made a component library.

---

This PR moves the generic Language Server Protocol (LSP) server support
code that was copied from clangd into MLIR, into the LLVM tree so it can
be reused by multiple subprojects.

Centralizing the generic LSP support in LLVM lowers the barrier to
building new LSP servers across the LLVM ecosystem and avoids each
subproject maintaining its own copy.

The code originated in clangd and was copied into MLIR for its LSP
server. MLIR had this code seperate to be reused by all of their LSP
server. This PR relocates the MLIR copy into LLVM as a shared component
into LLVM/Support. If this is not a suitable place, please suggest a
better one.

A follow up to this move could be deduplication with the original clangd
implementation and converge on a single shared LSP support library used
by clangd, MLIR, and future servers.
What changes

mlir/include/mlir/Tools/lsp-server-support/{Logging, Protocol,
Transport}.h moved to llvm/include/llvm/Support/LSP
mlir/lib/Tools/lsp-server-support/{Logging, Protocol, Transport}.cpp
moved to llvm/lib/Support/LSP

and their namespace was changed from mlir to llvm

I ran clang-tidy --fix and clang-format on the whole moved files (last
two commits), as they are basically new files and should hold up to the
code style used by LLVM.

MLIR LSP servers where updated to include these files from their new
location and account for the namespace change.

This PR is made as part of the LLVM IR LSP project
([RFC](https://discourse.llvm.org/t/rfc-ir-visualization-with-vs-code-extension-using-an-lsp-server/87773))
This commit is contained in:
Bertik23 2025-09-11 19:17:52 +02:00 committed by GitHub
parent 8da3ab12ce
commit a3a25996b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
46 changed files with 2413 additions and 2232 deletions

View File

@ -1,4 +1,4 @@
//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===//
//===- Logging.h - LSP Server Logging ----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,16 +6,15 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
#ifndef LLVM_SUPPORT_LSP_LOGGING_H
#define LLVM_SUPPORT_LSP_LOGGING_H
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <mutex>
namespace mlir {
namespace llvm {
namespace lsp {
/// This class represents the main interface for logging, and allows for
@ -26,21 +25,18 @@ public:
enum class Level { Debug, Info, Error };
/// Set the severity level of the logger.
static void setLogLevel(Level logLevel);
static void setLogLevel(Level LogLevel);
/// Initiate a log message at various severity levels. These should be called
/// after a call to `initialize`.
template <typename... Ts>
static void debug(const char *fmt, Ts &&...vals) {
log(Level::Debug, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
template <typename... Ts> static void debug(const char *Fmt, Ts &&...Vals) {
log(Level::Debug, Fmt, llvm::formatv(Fmt, std::forward<Ts>(Vals)...));
}
template <typename... Ts>
static void info(const char *fmt, Ts &&...vals) {
log(Level::Info, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
template <typename... Ts> static void info(const char *Fmt, Ts &&...Vals) {
log(Level::Info, Fmt, llvm::formatv(Fmt, std::forward<Ts>(Vals)...));
}
template <typename... Ts>
static void error(const char *fmt, Ts &&...vals) {
log(Level::Error, fmt, llvm::formatv(fmt, std::forward<Ts>(vals)...));
template <typename... Ts> static void error(const char *Fmt, Ts &&...Vals) {
log(Level::Error, Fmt, llvm::formatv(Fmt, std::forward<Ts>(Vals)...));
}
private:
@ -50,16 +46,16 @@ private:
static Logger &get();
/// Start a log message with the given severity level.
static void log(Level logLevel, const char *fmt,
const llvm::formatv_object_base &message);
static void log(Level LogLevel, const char *Fmt,
const llvm::formatv_object_base &Message);
/// The minimum logging level. Messages with lower level are ignored.
Level logLevel = Level::Error;
Level LogLevel = Level::Error;
/// A mutex used to guard logging.
std::mutex mutex;
std::mutex Mutex;
};
} // namespace lsp
} // namespace mlir
} // namespace llvm
#endif // MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H
#endif // LLVM_SUPPORT_LSP_LOGGING_H

View File

@ -20,20 +20,24 @@
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H
#ifndef LLVM_SUPPORT_LSP_PROTOCOL_H
#define LLVM_SUPPORT_LSP_PROTOCOL_H
#include "mlir/Support/LLVM.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <bitset>
#include <optional>
#include <string>
#include <utility>
#include <vector>
namespace mlir {
// This file is using the LSP syntax for identifier names which is different
// from the LLVM coding standard. To avoid the clang-tidy warnings, we're
// disabling one check here.
// NOLINTBEGIN(readability-identifier-naming)
namespace llvm {
namespace lsp {
enum class ErrorCode {
@ -1241,12 +1245,11 @@ struct CodeAction {
llvm::json::Value toJSON(const CodeAction &);
} // namespace lsp
} // namespace mlir
} // namespace llvm
namespace llvm {
template <>
struct format_provider<mlir::lsp::Position> {
static void format(const mlir::lsp::Position &pos, raw_ostream &os,
template <> struct format_provider<llvm::lsp::Position> {
static void format(const llvm::lsp::Position &pos, raw_ostream &os,
StringRef style) {
assert(style.empty() && "style modifiers for this type are not supported");
os << pos;
@ -1255,3 +1258,5 @@ struct format_provider<mlir::lsp::Position> {
} // namespace llvm
#endif
// NOLINTEND(readability-identifier-naming)

View File

@ -0,0 +1,289 @@
//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The language server protocol is usually implemented by writing messages as
// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON
// transport interface that handles this communication.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_SUPPORT_LSP_TRANSPORT_H
#define LLVM_SUPPORT_LSP_TRANSPORT_H
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
namespace llvm {
// Simple helper function that returns a string as printed from a op.
template <typename T> static std::string debugString(T &&Op) {
std::string InstrStr;
llvm::raw_string_ostream Os(InstrStr);
Os << Op;
return Os.str();
}
namespace lsp {
class MessageHandler;
//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//
/// The encoding style of the JSON-RPC messages (both input and output).
enum JSONStreamStyle {
/// Encoding per the LSP specification, with mandatory Content-Length header.
Standard,
/// Messages are delimited by a '// -----' line. Comment lines start with //.
Delimited
};
/// An abstract class used by the JSONTransport to read JSON message.
class JSONTransportInput {
public:
explicit JSONTransportInput(JSONStreamStyle Style = JSONStreamStyle::Standard)
: Style(Style) {}
virtual ~JSONTransportInput() = default;
virtual bool hasError() const = 0;
virtual bool isEndOfInput() const = 0;
/// Read in a message from the input stream.
LogicalResult readMessage(std::string &Json) {
return Style == JSONStreamStyle::Delimited ? readDelimitedMessage(Json)
: readStandardMessage(Json);
}
virtual LogicalResult readDelimitedMessage(std::string &Json) = 0;
virtual LogicalResult readStandardMessage(std::string &Json) = 0;
private:
/// The JSON stream style to use.
JSONStreamStyle Style;
};
/// Concrete implementation of the JSONTransportInput that reads from a file.
class JSONTransportInputOverFile : public JSONTransportInput {
public:
explicit JSONTransportInputOverFile(
std::FILE *In, JSONStreamStyle Style = JSONStreamStyle::Standard)
: JSONTransportInput(Style), In(In) {}
bool hasError() const final { return ferror(In); }
bool isEndOfInput() const final { return feof(In); }
LogicalResult readDelimitedMessage(std::string &Json) final;
LogicalResult readStandardMessage(std::string &Json) final;
private:
std::FILE *In;
};
/// A transport class that performs the JSON-RPC communication with the LSP
/// client.
class JSONTransport {
public:
JSONTransport(std::unique_ptr<JSONTransportInput> In, raw_ostream &Out,
bool PrettyOutput = false)
: In(std::move(In)), Out(Out), PrettyOutput(PrettyOutput) {}
JSONTransport(std::FILE *In, raw_ostream &Out,
JSONStreamStyle Style = JSONStreamStyle::Standard,
bool PrettyOutput = false)
: In(std::make_unique<JSONTransportInputOverFile>(In, Style)), Out(Out),
PrettyOutput(PrettyOutput) {}
/// The following methods are used to send a message to the LSP client.
void notify(StringRef Method, llvm::json::Value Params);
void call(StringRef Method, llvm::json::Value Params, llvm::json::Value Id);
void reply(llvm::json::Value Id, llvm::Expected<llvm::json::Value> Result);
/// Start executing the JSON-RPC transport.
llvm::Error run(MessageHandler &Handler);
private:
/// Dispatches the given incoming json message to the message handler.
bool handleMessage(llvm::json::Value Msg, MessageHandler &Handler);
/// Writes the given message to the output stream.
void sendMessage(llvm::json::Value Msg);
private:
/// The input to read a message from.
std::unique_ptr<JSONTransportInput> In;
SmallVector<char, 0> OutputBuffer;
/// The output file stream.
raw_ostream &Out;
/// If the output JSON should be formatted for easier readability.
bool PrettyOutput;
};
//===----------------------------------------------------------------------===//
// MessageHandler
//===----------------------------------------------------------------------===//
/// A Callback<T> is a void function that accepts Expected<T>. This is
/// accepted by functions that logically return T.
template <typename T>
using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
/// An OutgoingNotification<T> is a function used for outgoing notifications
/// send to the client.
template <typename T>
using OutgoingNotification = llvm::unique_function<void(const T &)>;
/// An OutgoingRequest<T> is a function used for outgoing requests to send to
/// the client.
template <typename T>
using OutgoingRequest =
llvm::unique_function<void(const T &, llvm::json::Value Id)>;
/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
/// client receives a response in turn. It is passed the original request's ID,
/// as well as the response result.
template <typename T>
using OutgoingRequestCallback =
std::function<void(llvm::json::Value, llvm::Expected<T>)>;
/// A handler used to process the incoming transport messages.
class MessageHandler {
public:
MessageHandler(JSONTransport &Transport) : Transport(Transport) {}
bool onNotify(StringRef Method, llvm::json::Value Value);
bool onCall(StringRef Method, llvm::json::Value Params, llvm::json::Value Id);
bool onReply(llvm::json::Value Id, llvm::Expected<llvm::json::Value> Result);
template <typename T>
static llvm::Expected<T> parse(const llvm::json::Value &Raw,
StringRef PayloadName, StringRef PayloadKind) {
T Result;
llvm::json::Path::Root Root;
if (fromJSON(Raw, Result, Root))
return std::move(Result);
// Dump the relevant parts of the broken message.
std::string Context;
llvm::raw_string_ostream Os(Context);
Root.printErrorContext(Raw, Os);
// Report the error (e.g. to the client).
return llvm::make_error<LSPError>(
llvm::formatv("failed to decode {0} {1}: {2}", PayloadName, PayloadKind,
fmt_consume(Root.getError())),
ErrorCode::InvalidParams);
}
template <typename Param, typename Result, typename ThisT>
void method(llvm::StringLiteral Method, ThisT *ThisPtr,
void (ThisT::*Handler)(const Param &, Callback<Result>)) {
MethodHandlers[Method] = [Method, Handler,
ThisPtr](llvm::json::Value RawParams,
Callback<llvm::json::Value> Reply) {
llvm::Expected<Param> Parameter =
parse<Param>(RawParams, Method, "request");
if (!Parameter)
return Reply(Parameter.takeError());
(ThisPtr->*Handler)(*Parameter, std::move(Reply));
};
}
template <typename Param, typename ThisT>
void notification(llvm::StringLiteral Method, ThisT *ThisPtr,
void (ThisT::*Handler)(const Param &)) {
NotificationHandlers[Method] = [Method, Handler,
ThisPtr](llvm::json::Value RawParams) {
llvm::Expected<Param> Parameter =
parse<Param>(RawParams, Method, "notification");
if (!Parameter) {
return llvm::consumeError(llvm::handleErrors(
Parameter.takeError(), [](const LSPError &LspError) {
Logger::error("JSON parsing error: {0}",
LspError.message.c_str());
}));
}
(ThisPtr->*Handler)(*Parameter);
};
}
/// Create an OutgoingNotification object used for the given method.
template <typename T>
OutgoingNotification<T> outgoingNotification(llvm::StringLiteral Method) {
return [&, Method](const T &Params) {
std::lock_guard<std::mutex> TransportLock(TransportOutputMutex);
Logger::info("--> {0}", Method);
Transport.notify(Method, llvm::json::Value(Params));
};
}
/// Create an OutgoingRequest function that, when called, sends a request with
/// the given method via the transport. Should the outgoing request be
/// met with a response, the result JSON is parsed and the response callback
/// is invoked.
template <typename Param, typename Result>
OutgoingRequest<Param>
outgoingRequest(llvm::StringLiteral Method,
OutgoingRequestCallback<Result> Callback) {
return [&, Method, Callback](const Param &Parameter, llvm::json::Value Id) {
auto CallbackWrapper = [Method, Callback = std::move(Callback)](
llvm::json::Value Id,
llvm::Expected<llvm::json::Value> Value) {
if (!Value)
return Callback(std::move(Id), Value.takeError());
std::string ResponseName = llvm::formatv("reply:{0}({1})", Method, Id);
llvm::Expected<Result> ParseResult =
parse<Result>(*Value, ResponseName, "response");
if (!ParseResult)
return Callback(std::move(Id), ParseResult.takeError());
return Callback(std::move(Id), *ParseResult);
};
{
std::lock_guard<std::mutex> Lock(ResponseHandlersMutex);
ResponseHandlers.insert(
{debugString(Id), std::make_pair(Method.str(), CallbackWrapper)});
}
std::lock_guard<std::mutex> TransportLock(TransportOutputMutex);
Logger::info("--> {0}({1})", Method, Id);
Transport.call(Method, llvm::json::Value(Parameter), Id);
};
}
private:
template <typename HandlerT>
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
HandlerMap<void(llvm::json::Value)> NotificationHandlers;
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
MethodHandlers;
/// A pair of (1) the original request's method name, and (2) the callback
/// function to be invoked for responses.
using ResponseHandlerTy =
std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
/// A mapping from request/response ID to response handler.
llvm::StringMap<ResponseHandlerTy> ResponseHandlers;
/// Mutex to guard insertion into the response handler map.
std::mutex ResponseHandlersMutex;
JSONTransport &Transport;
/// Mutex to guard sending output messages to the transport.
std::mutex TransportOutputMutex;
};
} // namespace lsp
} // namespace llvm
#endif

View File

@ -135,6 +135,7 @@ if (UNIX AND "${CMAKE_SYSTEM_NAME}" MATCHES "AIX")
endif()
add_subdirectory(BLAKE3)
add_subdirectory(LSP)
add_llvm_component_library(LLVMSupport
ABIBreak.cpp

View File

@ -0,0 +1,8 @@
add_llvm_component_library(LLVMSupportLSP
Protocol.cpp
Transport.cpp
Logging.cpp
DEPENDS
LLVMSupport
)

View File

@ -6,36 +6,36 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/Chrono.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::lsp;
using namespace llvm;
using namespace llvm::lsp;
void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; }
void Logger::setLogLevel(Level LogLevel) { get().LogLevel = LogLevel; }
Logger &Logger::get() {
static Logger logger;
return logger;
static Logger Logger;
return Logger;
}
void Logger::log(Level logLevel, const char *fmt,
const llvm::formatv_object_base &message) {
Logger &logger = get();
void Logger::log(Level LogLevel, const char *Fmt,
const llvm::formatv_object_base &Message) {
Logger &Logger = get();
// Ignore messages with log levels below the current setting in the logger.
if (logLevel < logger.logLevel)
if (LogLevel < Logger.LogLevel)
return;
// An indicator character for each log level.
const char *logLevelIndicators = "DIE";
const char *LogLevelIndicators = "DIE";
// Format the message and print to errs.
llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now();
std::lock_guard<std::mutex> logGuard(logger.mutex);
llvm::sys::TimePoint<> Timestamp = std::chrono::system_clock::now();
std::lock_guard<std::mutex> LogGuard(Logger.Mutex);
llvm::errs() << llvm::formatv(
"{0}[{1:%H:%M:%S.%L}] {2}\n",
logLevelIndicators[static_cast<unsigned>(logLevel)], timestamp, message);
LogLevelIndicators[static_cast<unsigned>(LogLevel)], Timestamp, Message);
llvm::errs().flush();
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,369 @@
//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "llvm/Support/LSP/Transport.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include <atomic>
#include <optional>
#include <system_error>
#include <utility>
using namespace llvm;
using namespace llvm::lsp;
//===----------------------------------------------------------------------===//
// Reply
//===----------------------------------------------------------------------===//
namespace {
/// Function object to reply to an LSP call.
/// Each instance must be called exactly once, otherwise:
/// - if there was no reply, an error reply is sent
/// - if there were multiple replies, only the first is sent
class Reply {
public:
Reply(const llvm::json::Value &Id, StringRef Method, JSONTransport &Transport,
std::mutex &TransportOutputMutex);
Reply(Reply &&Other);
Reply &operator=(Reply &&) = delete;
Reply(const Reply &) = delete;
Reply &operator=(const Reply &) = delete;
void operator()(llvm::Expected<llvm::json::Value> Reply);
private:
std::string Method;
std::atomic<bool> Replied = {false};
llvm::json::Value Id;
JSONTransport *Transport;
std::mutex &TransportOutputMutex;
};
} // namespace
Reply::Reply(const llvm::json::Value &Id, llvm::StringRef Method,
JSONTransport &Transport, std::mutex &TransportOutputMutex)
: Method(Method), Id(Id), Transport(&Transport),
TransportOutputMutex(TransportOutputMutex) {}
Reply::Reply(Reply &&Other)
: Method(Other.Method), Replied(Other.Replied.load()),
Id(std::move(Other.Id)), Transport(Other.Transport),
TransportOutputMutex(Other.TransportOutputMutex) {
Other.Transport = nullptr;
}
void Reply::operator()(llvm::Expected<llvm::json::Value> Reply) {
if (Replied.exchange(true)) {
Logger::error("Replied twice to message {0}({1})", Method, Id);
assert(false && "must reply to each call only once!");
return;
}
assert(Transport && "expected valid transport to reply to");
std::lock_guard<std::mutex> TransportLock(TransportOutputMutex);
if (Reply) {
Logger::info("--> reply:{0}({1})", Method, Id);
Transport->reply(std::move(Id), std::move(Reply));
} else {
llvm::Error Error = Reply.takeError();
Logger::info("--> reply:{0}({1}): {2}", Method, Id, Error);
Transport->reply(std::move(Id), std::move(Error));
}
}
//===----------------------------------------------------------------------===//
// MessageHandler
//===----------------------------------------------------------------------===//
bool MessageHandler::onNotify(llvm::StringRef Method, llvm::json::Value Value) {
Logger::info("--> {0}", Method);
if (Method == "exit")
return false;
if (Method == "$cancel") {
// TODO: Add support for cancelling requests.
} else {
auto It = NotificationHandlers.find(Method);
if (It != NotificationHandlers.end())
It->second(std::move(Value));
}
return true;
}
bool MessageHandler::onCall(llvm::StringRef Method, llvm::json::Value Params,
llvm::json::Value Id) {
Logger::info("--> {0}({1})", Method, Id);
Reply Reply(Id, Method, Transport, TransportOutputMutex);
auto It = MethodHandlers.find(Method);
if (It != MethodHandlers.end()) {
It->second(std::move(Params), std::move(Reply));
} else {
Reply(llvm::make_error<LSPError>("method not found: " + Method.str(),
ErrorCode::MethodNotFound));
}
return true;
}
bool MessageHandler::onReply(llvm::json::Value Id,
llvm::Expected<llvm::json::Value> Result) {
// Find the response handler in the mapping. If it exists, move it out of the
// mapping and erase it.
ResponseHandlerTy ResponseHandler;
{
std::lock_guard<std::mutex> responseHandlersLock(ResponseHandlerTy);
auto It = ResponseHandlers.find(debugString(Id));
if (It != ResponseHandlers.end()) {
ResponseHandler = std::move(It->second);
ResponseHandlers.erase(It);
}
}
// If we found a response handler, invoke it. Otherwise, log an error.
if (ResponseHandler.second) {
Logger::info("--> reply:{0}({1})", ResponseHandler.first, Id);
ResponseHandler.second(std::move(Id), std::move(Result));
} else {
Logger::error(
"received a reply with ID {0}, but there was no such outgoing request",
Id);
if (!Result)
llvm::consumeError(Result.takeError());
}
return true;
}
//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//
/// Encode the given error as a JSON object.
static llvm::json::Object encodeError(llvm::Error Error) {
std::string Message;
ErrorCode Code = ErrorCode::UnknownErrorCode;
auto HandlerFn = [&](const LSPError &LspError) -> llvm::Error {
Message = LspError.message;
Code = LspError.code;
return llvm::Error::success();
};
if (llvm::Error Unhandled = llvm::handleErrors(std::move(Error), HandlerFn))
Message = llvm::toString(std::move(Unhandled));
return llvm::json::Object{
{"message", std::move(Message)},
{"code", int64_t(Code)},
};
}
/// Decode the given JSON object into an error.
llvm::Error decodeError(const llvm::json::Object &O) {
StringRef Msg = O.getString("message").value_or("Unspecified error");
if (std::optional<int64_t> Code = O.getInteger("code"))
return llvm::make_error<LSPError>(Msg.str(), ErrorCode(*Code));
return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
Msg.str());
}
void JSONTransport::notify(StringRef Method, llvm::json::Value Params) {
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"method", Method},
{"params", std::move(Params)},
});
}
void JSONTransport::call(StringRef Method, llvm::json::Value Params,
llvm::json::Value Id) {
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(Id)},
{"method", Method},
{"params", std::move(Params)},
});
}
void JSONTransport::reply(llvm::json::Value Id,
llvm::Expected<llvm::json::Value> Result) {
if (Result) {
return sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(Id)},
{"result", std::move(*Result)},
});
}
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(Id)},
{"error", encodeError(Result.takeError())},
});
}
llvm::Error JSONTransport::run(MessageHandler &Handler) {
std::string Json;
while (!In->isEndOfInput()) {
if (In->hasError()) {
return llvm::errorCodeToError(
std::error_code(errno, std::system_category()));
}
if (succeeded(In->readMessage(Json))) {
if (llvm::Expected<llvm::json::Value> Doc = llvm::json::parse(Json)) {
if (!handleMessage(std::move(*Doc), Handler))
return llvm::Error::success();
} else {
Logger::error("JSON parse error: {0}", llvm::toString(Doc.takeError()));
}
}
}
return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
}
void JSONTransport::sendMessage(llvm::json::Value Msg) {
OutputBuffer.clear();
llvm::raw_svector_ostream os(OutputBuffer);
os << llvm::formatv(PrettyOutput ? "{0:2}\n" : "{0}", Msg);
Out << "Content-Length: " << OutputBuffer.size() << "\r\n\r\n"
<< OutputBuffer;
Out.flush();
Logger::debug(">>> {0}\n", OutputBuffer);
}
bool JSONTransport::handleMessage(llvm::json::Value Msg,
MessageHandler &Handler) {
// Message must be an object with "jsonrpc":"2.0".
llvm::json::Object *Object = Msg.getAsObject();
if (!Object ||
Object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
return false;
// `id` may be any JSON value. If absent, this is a notification.
std::optional<llvm::json::Value> Id;
if (llvm::json::Value *I = Object->get("id"))
Id = std::move(*I);
std::optional<StringRef> Method = Object->getString("method");
// This is a response.
if (!Method) {
if (!Id)
return false;
if (auto *Err = Object->getObject("error"))
return Handler.onReply(std::move(*Id), decodeError(*Err));
// result should be given, use null if not.
llvm::json::Value Result = nullptr;
if (llvm::json::Value *R = Object->get("result"))
Result = std::move(*R);
return Handler.onReply(std::move(*Id), std::move(Result));
}
// Params should be given, use null if not.
llvm::json::Value Params = nullptr;
if (llvm::json::Value *P = Object->get("params"))
Params = std::move(*P);
if (Id)
return Handler.onCall(*Method, std::move(Params), std::move(*Id));
return Handler.onNotify(*Method, std::move(Params));
}
/// Tries to read a line up to and including \n.
/// If failing, feof(), ferror(), or shutdownRequested() will be set.
LogicalResult readLine(std::FILE *In, SmallVectorImpl<char> &Out) {
// Big enough to hold any reasonable header line. May not fit content lines
// in delimited mode, but performance doesn't matter for that mode.
static constexpr int BufSize = 128;
size_t Size = 0;
Out.clear();
for (;;) {
Out.resize_for_overwrite(Size + BufSize);
if (!std::fgets(&Out[Size], BufSize, In))
return failure();
clearerr(In);
// If the line contained null bytes, anything after it (including \n) will
// be ignored. Fortunately this is not a legal header or JSON.
size_t Read = std::strlen(&Out[Size]);
if (Read > 0 && Out[Size + Read - 1] == '\n') {
Out.resize(Size + Read);
return success();
}
Size += Read;
}
}
// Returns std::nullopt when:
// - ferror(), feof(), or shutdownRequested() are set.
// - Content-Length is missing or empty (protocol error)
LogicalResult
JSONTransportInputOverFile::readStandardMessage(std::string &Json) {
// A Language Server Protocol message starts with a set of HTTP headers,
// delimited by \r\n, and terminated by an empty line (\r\n).
unsigned long long ContentLength = 0;
llvm::SmallString<128> Line;
while (true) {
if (feof(In) || hasError() || failed(readLine(In, Line)))
return failure();
// Content-Length is a mandatory header, and the only one we handle.
StringRef LineRef = Line;
if (LineRef.consume_front("Content-Length: ")) {
llvm::getAsUnsignedInteger(LineRef.trim(), 0, ContentLength);
} else if (!LineRef.trim().empty()) {
// It's another header, ignore it.
continue;
} else {
// An empty line indicates the end of headers. Go ahead and read the JSON.
break;
}
}
// The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
if (ContentLength == 0 || ContentLength > 1 << 30)
return failure();
Json.resize(ContentLength);
for (size_t Pos = 0, Read; Pos < ContentLength; Pos += Read) {
Read = std::fread(&Json[Pos], 1, ContentLength - Pos, In);
if (Read == 0)
return failure();
// If we're done, the error was transient. If we're not done, either it was
// transient or we'll see it again on retry.
clearerr(In);
Pos += Read;
}
return success();
}
/// For lit tests we support a simplified syntax:
/// - messages are delimited by '// -----' on a line by itself
/// - lines starting with // are ignored.
/// This is a testing path, so favor simplicity over performance here.
/// When returning failure: feof(), ferror(), or shutdownRequested() will be
/// set.
LogicalResult
JSONTransportInputOverFile::readDelimitedMessage(std::string &Json) {
Json.clear();
llvm::SmallString<128> Line;
while (succeeded(readLine(In, Line))) {
StringRef LineRef = Line.str().trim();
if (LineRef.starts_with("//")) {
// Found a delimiter for the message.
if (LineRef == "// -----")
break;
continue;
}
Json += Line;
}
return failure(ferror(In));
}

View File

@ -125,6 +125,8 @@ add_llvm_unittest(SupportTests
intrinsics_gen
)
add_subdirectory(LSP)
target_link_libraries(SupportTests PRIVATE LLVMTestingSupport)
# Disable all warning for AlignOfTest.cpp,

View File

@ -0,0 +1,8 @@
set(LLVM_LINK_COMPONENTS
SupportLSP
)
add_llvm_unittest(LLVMSupportLSPTests
Protocol.cpp
Transport.cpp
)

View File

@ -6,12 +6,12 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/Support/LSP/Protocol.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::lsp;
using namespace llvm;
using namespace llvm::lsp;
using namespace testing;
namespace {

View File

@ -6,15 +6,15 @@
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
using namespace mlir;
using namespace mlir::lsp;
using namespace llvm;
using namespace llvm::lsp;
using namespace testing;
namespace {
@ -88,7 +88,7 @@ protected:
TEST_F(TransportInputTest, RequestWithInvalidParams) {
struct Handler {
void onMethod(const TextDocumentItem &params,
mlir::lsp::Callback<TextDocumentIdentifier> callback) {}
llvm::lsp::Callback<TextDocumentIdentifier> callback) {}
} handler;
getMessageHandler().method("invalid-params-request", &handler,
&Handler::onMethod);

View File

@ -14,7 +14,8 @@
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/SourceMgr.h"
#include <optional>
@ -45,17 +46,18 @@ bool contains(SMRange range, SMLoc loc);
/// This class represents a single include within a root file.
struct SourceMgrInclude {
SourceMgrInclude(const lsp::URIForFile &uri, const lsp::Range &range)
SourceMgrInclude(const llvm::lsp::URIForFile &uri,
const llvm::lsp::Range &range)
: uri(uri), range(range) {}
/// Build a hover for the current include file.
Hover buildHover() const;
llvm::lsp::Hover buildHover() const;
/// The URI of the file that is included.
lsp::URIForFile uri;
llvm::lsp::URIForFile uri;
/// The range of the include directive.
lsp::Range range;
llvm::lsp::Range range;
};
/// Given a source manager, gather all of the processed include files. These are

View File

@ -1,283 +0,0 @@
//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The language server protocol is usually implemented by writing messages as
// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON
// transport interface that handles this communication.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/ADT/FunctionExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatAdapters.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/raw_ostream.h"
#include <atomic>
namespace mlir {
namespace lsp {
class MessageHandler;
//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//
/// The encoding style of the JSON-RPC messages (both input and output).
enum JSONStreamStyle {
/// Encoding per the LSP specification, with mandatory Content-Length header.
Standard,
/// Messages are delimited by a '// -----' line. Comment lines start with //.
Delimited
};
/// An abstract class used by the JSONTransport to read JSON message.
class JSONTransportInput {
public:
explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard)
: style(style) {}
virtual ~JSONTransportInput() = default;
virtual bool hasError() const = 0;
virtual bool isEndOfInput() const = 0;
/// Read in a message from the input stream.
LogicalResult readMessage(std::string &json) {
return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json)
: readStandardMessage(json);
}
virtual LogicalResult readDelimitedMessage(std::string &json) = 0;
virtual LogicalResult readStandardMessage(std::string &json) = 0;
private:
/// The JSON stream style to use.
JSONStreamStyle style;
};
/// Concrete implementation of the JSONTransportInput that reads from a file.
class JSONTransportInputOverFile : public JSONTransportInput {
public:
explicit JSONTransportInputOverFile(
std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard)
: JSONTransportInput(style), in(in) {}
bool hasError() const final { return ferror(in); }
bool isEndOfInput() const final { return feof(in); }
LogicalResult readDelimitedMessage(std::string &json) final;
LogicalResult readStandardMessage(std::string &json) final;
private:
std::FILE *in;
};
/// A transport class that performs the JSON-RPC communication with the LSP
/// client.
class JSONTransport {
public:
JSONTransport(std::unique_ptr<JSONTransportInput> in, raw_ostream &out,
bool prettyOutput = false)
: in(std::move(in)), out(out), prettyOutput(prettyOutput) {}
JSONTransport(std::FILE *in, raw_ostream &out,
JSONStreamStyle style = JSONStreamStyle::Standard,
bool prettyOutput = false)
: in(std::make_unique<JSONTransportInputOverFile>(in, style)), out(out),
prettyOutput(prettyOutput) {}
/// The following methods are used to send a message to the LSP client.
void notify(StringRef method, llvm::json::Value params);
void call(StringRef method, llvm::json::Value params, llvm::json::Value id);
void reply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
/// Start executing the JSON-RPC transport.
llvm::Error run(MessageHandler &handler);
private:
/// Dispatches the given incoming json message to the message handler.
bool handleMessage(llvm::json::Value msg, MessageHandler &handler);
/// Writes the given message to the output stream.
void sendMessage(llvm::json::Value msg);
private:
/// The input to read a message from.
std::unique_ptr<JSONTransportInput> in;
SmallVector<char, 0> outputBuffer;
/// The output file stream.
raw_ostream &out;
/// If the output JSON should be formatted for easier readability.
bool prettyOutput;
};
//===----------------------------------------------------------------------===//
// MessageHandler
//===----------------------------------------------------------------------===//
/// A Callback<T> is a void function that accepts Expected<T>. This is
/// accepted by functions that logically return T.
template <typename T>
using Callback = llvm::unique_function<void(llvm::Expected<T>)>;
/// An OutgoingNotification<T> is a function used for outgoing notifications
/// send to the client.
template <typename T>
using OutgoingNotification = llvm::unique_function<void(const T &)>;
/// An OutgoingRequest<T> is a function used for outgoing requests to send to
/// the client.
template <typename T>
using OutgoingRequest =
llvm::unique_function<void(const T &, llvm::json::Value id)>;
/// An `OutgoingRequestCallback` is invoked when an outgoing request to the
/// client receives a response in turn. It is passed the original request's ID,
/// as well as the response result.
template <typename T>
using OutgoingRequestCallback =
std::function<void(llvm::json::Value, llvm::Expected<T>)>;
/// A handler used to process the incoming transport messages.
class MessageHandler {
public:
MessageHandler(JSONTransport &transport) : transport(transport) {}
bool onNotify(StringRef method, llvm::json::Value value);
bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id);
bool onReply(llvm::json::Value id, llvm::Expected<llvm::json::Value> result);
template <typename T>
static llvm::Expected<T> parse(const llvm::json::Value &raw,
StringRef payloadName, StringRef payloadKind) {
T result;
llvm::json::Path::Root root;
if (fromJSON(raw, result, root))
return std::move(result);
// Dump the relevant parts of the broken message.
std::string context;
llvm::raw_string_ostream os(context);
root.printErrorContext(raw, os);
// Report the error (e.g. to the client).
return llvm::make_error<LSPError>(
llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind,
fmt_consume(root.getError())),
ErrorCode::InvalidParams);
}
template <typename Param, typename Result, typename ThisT>
void method(llvm::StringLiteral method, ThisT *thisPtr,
void (ThisT::*handler)(const Param &, Callback<Result>)) {
methodHandlers[method] = [method, handler,
thisPtr](llvm::json::Value rawParams,
Callback<llvm::json::Value> reply) {
llvm::Expected<Param> param = parse<Param>(rawParams, method, "request");
if (!param)
return reply(param.takeError());
(thisPtr->*handler)(*param, std::move(reply));
};
}
template <typename Param, typename ThisT>
void notification(llvm::StringLiteral method, ThisT *thisPtr,
void (ThisT::*handler)(const Param &)) {
notificationHandlers[method] = [method, handler,
thisPtr](llvm::json::Value rawParams) {
llvm::Expected<Param> param =
parse<Param>(rawParams, method, "notification");
if (!param) {
return llvm::consumeError(
llvm::handleErrors(param.takeError(), [](const LSPError &lspError) {
Logger::error("JSON parsing error: {0}",
lspError.message.c_str());
}));
}
(thisPtr->*handler)(*param);
};
}
/// Create an OutgoingNotification object used for the given method.
template <typename T>
OutgoingNotification<T> outgoingNotification(llvm::StringLiteral method) {
return [&, method](const T &params) {
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
Logger::info("--> {0}", method);
transport.notify(method, llvm::json::Value(params));
};
}
/// Create an OutgoingRequest function that, when called, sends a request with
/// the given method via the transport. Should the outgoing request be
/// met with a response, the result JSON is parsed and the response callback
/// is invoked.
template <typename Param, typename Result>
OutgoingRequest<Param>
outgoingRequest(llvm::StringLiteral method,
OutgoingRequestCallback<Result> callback) {
return [&, method, callback](const Param &param, llvm::json::Value id) {
auto callbackWrapper = [method, callback = std::move(callback)](
llvm::json::Value id,
llvm::Expected<llvm::json::Value> value) {
if (!value)
return callback(std::move(id), value.takeError());
std::string responseName = llvm::formatv("reply:{0}({1})", method, id);
llvm::Expected<Result> result =
parse<Result>(*value, responseName, "response");
if (!result)
return callback(std::move(id), result.takeError());
return callback(std::move(id), *result);
};
{
std::lock_guard<std::mutex> lock(responseHandlersMutex);
responseHandlers.insert(
{debugString(id), std::make_pair(method.str(), callbackWrapper)});
}
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
Logger::info("--> {0}({1})", method, id);
transport.call(method, llvm::json::Value(param), id);
};
}
private:
template <typename HandlerT>
using HandlerMap = llvm::StringMap<llvm::unique_function<HandlerT>>;
HandlerMap<void(llvm::json::Value)> notificationHandlers;
HandlerMap<void(llvm::json::Value, Callback<llvm::json::Value>)>
methodHandlers;
/// A pair of (1) the original request's method name, and (2) the callback
/// function to be invoked for responses.
using ResponseHandlerTy =
std::pair<std::string, OutgoingRequestCallback<llvm::json::Value>>;
/// A mapping from request/response ID to response handler.
llvm::StringMap<ResponseHandlerTy> responseHandlers;
/// Mutex to guard insertion into the response handler map.
std::mutex responseHandlersMutex;
JSONTransport &transport;
/// Mutex to guard sending output messages to the transport.
std::mutex transportOutputMutex;
};
} // namespace lsp
} // namespace mlir
#endif

View File

@ -16,14 +16,16 @@
namespace llvm {
template <typename Fn>
class function_ref;
namespace lsp {
class URIForFile;
} // namespace lsp
} // namespace llvm
namespace mlir {
class DialectRegistry;
namespace lsp {
class URIForFile;
using DialectRegistryFn =
llvm::function_ref<DialectRegistry &(const URIForFile &uri)>;
llvm::function_ref<DialectRegistry &(const llvm::lsp::URIForFile &uri)>;
} // namespace lsp
} // namespace mlir

View File

@ -1,13 +1,13 @@
add_mlir_library(MLIRLspServerSupportLib
CompilationDatabase.cpp
Logging.cpp
Protocol.cpp
SourceMgrUtils.cpp
Transport.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/lsp-server-support
LINK_COMPONENTS
SupportLSP
LINK_LIBS PUBLIC
MLIRSupport
)
)

View File

@ -8,14 +8,15 @@
#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/YAMLTraits.h"
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::Logger;
//===----------------------------------------------------------------------===//
// YamlFileInfo

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,10 @@
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::Hover;
using llvm::lsp::Range;
using llvm::lsp::URIForFile;
//===----------------------------------------------------------------------===//
// Utils
//===----------------------------------------------------------------------===//

View File

@ -1,369 +0,0 @@
//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Error.h"
#include <optional>
#include <system_error>
#include <utility>
using namespace mlir;
using namespace mlir::lsp;
//===----------------------------------------------------------------------===//
// Reply
//===----------------------------------------------------------------------===//
namespace {
/// Function object to reply to an LSP call.
/// Each instance must be called exactly once, otherwise:
/// - if there was no reply, an error reply is sent
/// - if there were multiple replies, only the first is sent
class Reply {
public:
Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport,
std::mutex &transportOutputMutex);
Reply(Reply &&other);
Reply &operator=(Reply &&) = delete;
Reply(const Reply &) = delete;
Reply &operator=(const Reply &) = delete;
void operator()(llvm::Expected<llvm::json::Value> reply);
private:
std::string method;
std::atomic<bool> replied = {false};
llvm::json::Value id;
JSONTransport *transport;
std::mutex &transportOutputMutex;
};
} // namespace
Reply::Reply(const llvm::json::Value &id, llvm::StringRef method,
JSONTransport &transport, std::mutex &transportOutputMutex)
: method(method), id(id), transport(&transport),
transportOutputMutex(transportOutputMutex) {}
Reply::Reply(Reply &&other)
: method(other.method), replied(other.replied.load()),
id(std::move(other.id)), transport(other.transport),
transportOutputMutex(other.transportOutputMutex) {
other.transport = nullptr;
}
void Reply::operator()(llvm::Expected<llvm::json::Value> reply) {
if (replied.exchange(true)) {
Logger::error("Replied twice to message {0}({1})", method, id);
assert(false && "must reply to each call only once!");
return;
}
assert(transport && "expected valid transport to reply to");
std::lock_guard<std::mutex> transportLock(transportOutputMutex);
if (reply) {
Logger::info("--> reply:{0}({1})", method, id);
transport->reply(std::move(id), std::move(reply));
} else {
llvm::Error error = reply.takeError();
Logger::info("--> reply:{0}({1}): {2}", method, id, error);
transport->reply(std::move(id), std::move(error));
}
}
//===----------------------------------------------------------------------===//
// MessageHandler
//===----------------------------------------------------------------------===//
bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) {
Logger::info("--> {0}", method);
if (method == "exit")
return false;
if (method == "$cancel") {
// TODO: Add support for cancelling requests.
} else {
auto it = notificationHandlers.find(method);
if (it != notificationHandlers.end())
it->second(std::move(value));
}
return true;
}
bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params,
llvm::json::Value id) {
Logger::info("--> {0}({1})", method, id);
Reply reply(id, method, transport, transportOutputMutex);
auto it = methodHandlers.find(method);
if (it != methodHandlers.end()) {
it->second(std::move(params), std::move(reply));
} else {
reply(llvm::make_error<LSPError>("method not found: " + method.str(),
ErrorCode::MethodNotFound));
}
return true;
}
bool MessageHandler::onReply(llvm::json::Value id,
llvm::Expected<llvm::json::Value> result) {
// Find the response handler in the mapping. If it exists, move it out of the
// mapping and erase it.
ResponseHandlerTy responseHandler;
{
std::lock_guard<std::mutex> responseHandlersLock(responseHandlersMutex);
auto it = responseHandlers.find(debugString(id));
if (it != responseHandlers.end()) {
responseHandler = std::move(it->second);
responseHandlers.erase(it);
}
}
// If we found a response handler, invoke it. Otherwise, log an error.
if (responseHandler.second) {
Logger::info("--> reply:{0}({1})", responseHandler.first, id);
responseHandler.second(std::move(id), std::move(result));
} else {
Logger::error(
"received a reply with ID {0}, but there was no such outgoing request",
id);
if (!result)
llvm::consumeError(result.takeError());
}
return true;
}
//===----------------------------------------------------------------------===//
// JSONTransport
//===----------------------------------------------------------------------===//
/// Encode the given error as a JSON object.
static llvm::json::Object encodeError(llvm::Error error) {
std::string message;
ErrorCode code = ErrorCode::UnknownErrorCode;
auto handlerFn = [&](const LSPError &lspError) -> llvm::Error {
message = lspError.message;
code = lspError.code;
return llvm::Error::success();
};
if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn))
message = llvm::toString(std::move(unhandled));
return llvm::json::Object{
{"message", std::move(message)},
{"code", int64_t(code)},
};
}
/// Decode the given JSON object into an error.
llvm::Error decodeError(const llvm::json::Object &o) {
StringRef msg = o.getString("message").value_or("Unspecified error");
if (std::optional<int64_t> code = o.getInteger("code"))
return llvm::make_error<LSPError>(msg.str(), ErrorCode(*code));
return llvm::make_error<llvm::StringError>(llvm::inconvertibleErrorCode(),
msg.str());
}
void JSONTransport::notify(StringRef method, llvm::json::Value params) {
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"method", method},
{"params", std::move(params)},
});
}
void JSONTransport::call(StringRef method, llvm::json::Value params,
llvm::json::Value id) {
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(id)},
{"method", method},
{"params", std::move(params)},
});
}
void JSONTransport::reply(llvm::json::Value id,
llvm::Expected<llvm::json::Value> result) {
if (result) {
return sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(id)},
{"result", std::move(*result)},
});
}
sendMessage(llvm::json::Object{
{"jsonrpc", "2.0"},
{"id", std::move(id)},
{"error", encodeError(result.takeError())},
});
}
llvm::Error JSONTransport::run(MessageHandler &handler) {
std::string json;
while (!in->isEndOfInput()) {
if (in->hasError()) {
return llvm::errorCodeToError(
std::error_code(errno, std::system_category()));
}
if (succeeded(in->readMessage(json))) {
if (llvm::Expected<llvm::json::Value> doc = llvm::json::parse(json)) {
if (!handleMessage(std::move(*doc), handler))
return llvm::Error::success();
} else {
Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError()));
}
}
}
return llvm::errorCodeToError(std::make_error_code(std::errc::io_error));
}
void JSONTransport::sendMessage(llvm::json::Value msg) {
outputBuffer.clear();
llvm::raw_svector_ostream os(outputBuffer);
os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg);
out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n"
<< outputBuffer;
out.flush();
Logger::debug(">>> {0}\n", outputBuffer);
}
bool JSONTransport::handleMessage(llvm::json::Value msg,
MessageHandler &handler) {
// Message must be an object with "jsonrpc":"2.0".
llvm::json::Object *object = msg.getAsObject();
if (!object ||
object->getString("jsonrpc") != std::optional<StringRef>("2.0"))
return false;
// `id` may be any JSON value. If absent, this is a notification.
std::optional<llvm::json::Value> id;
if (llvm::json::Value *i = object->get("id"))
id = std::move(*i);
std::optional<StringRef> method = object->getString("method");
// This is a response.
if (!method) {
if (!id)
return false;
if (auto *err = object->getObject("error"))
return handler.onReply(std::move(*id), decodeError(*err));
// result should be given, use null if not.
llvm::json::Value result = nullptr;
if (llvm::json::Value *r = object->get("result"))
result = std::move(*r);
return handler.onReply(std::move(*id), std::move(result));
}
// Params should be given, use null if not.
llvm::json::Value params = nullptr;
if (llvm::json::Value *p = object->get("params"))
params = std::move(*p);
if (id)
return handler.onCall(*method, std::move(params), std::move(*id));
return handler.onNotify(*method, std::move(params));
}
/// Tries to read a line up to and including \n.
/// If failing, feof(), ferror(), or shutdownRequested() will be set.
LogicalResult readLine(std::FILE *in, SmallVectorImpl<char> &out) {
// Big enough to hold any reasonable header line. May not fit content lines
// in delimited mode, but performance doesn't matter for that mode.
static constexpr int bufSize = 128;
size_t size = 0;
out.clear();
for (;;) {
out.resize_for_overwrite(size + bufSize);
if (!std::fgets(&out[size], bufSize, in))
return failure();
clearerr(in);
// If the line contained null bytes, anything after it (including \n) will
// be ignored. Fortunately this is not a legal header or JSON.
size_t read = std::strlen(&out[size]);
if (read > 0 && out[size + read - 1] == '\n') {
out.resize(size + read);
return success();
}
size += read;
}
}
// Returns std::nullopt when:
// - ferror(), feof(), or shutdownRequested() are set.
// - Content-Length is missing or empty (protocol error)
LogicalResult
JSONTransportInputOverFile::readStandardMessage(std::string &json) {
// A Language Server Protocol message starts with a set of HTTP headers,
// delimited by \r\n, and terminated by an empty line (\r\n).
unsigned long long contentLength = 0;
llvm::SmallString<128> line;
while (true) {
if (feof(in) || hasError() || failed(readLine(in, line)))
return failure();
// Content-Length is a mandatory header, and the only one we handle.
StringRef lineRef = line;
if (lineRef.consume_front("Content-Length: ")) {
llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength);
} else if (!lineRef.trim().empty()) {
// It's another header, ignore it.
continue;
} else {
// An empty line indicates the end of headers. Go ahead and read the JSON.
break;
}
}
// The fuzzer likes crashing us by sending "Content-Length: 9999999999999999"
if (contentLength == 0 || contentLength > 1 << 30)
return failure();
json.resize(contentLength);
for (size_t pos = 0, read; pos < contentLength; pos += read) {
read = std::fread(&json[pos], 1, contentLength - pos, in);
if (read == 0)
return failure();
// If we're done, the error was transient. If we're not done, either it was
// transient or we'll see it again on retry.
clearerr(in);
pos += read;
}
return success();
}
/// For lit tests we support a simplified syntax:
/// - messages are delimited by '// -----' on a line by itself
/// - lines starting with // are ignored.
/// This is a testing path, so favor simplicity over performance here.
/// When returning failure: feof(), ferror(), or shutdownRequested() will be
/// set.
LogicalResult
JSONTransportInputOverFile::readDelimitedMessage(std::string &json) {
json.clear();
llvm::SmallString<128> line;
while (succeeded(readLine(in, line))) {
StringRef lineRef = line.str().trim();
if (lineRef.starts_with("//")) {
// Found a delimiter for the message.
if (lineRef == kDefaultSplitMarker)
break;
continue;
}
json += line;
}
return failure(ferror(in));
}

View File

@ -7,6 +7,9 @@ add_mlir_library(MLIRLspServerLib
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server
LINK_COMPONENTS
SupportLSP
LINK_LIBS PUBLIC
MLIRBytecodeWriter
MLIRFunctionInterfaces

View File

@ -9,8 +9,8 @@
#include "LSPServer.h"
#include "MLIRServer.h"
#include "Protocol.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Transport.h"
#include <optional>
#define DEBUG_TYPE "mlir-lsp-server"
@ -18,6 +18,33 @@
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::Callback;
using llvm::lsp::CodeAction;
using llvm::lsp::CodeActionParams;
using llvm::lsp::CompletionList;
using llvm::lsp::CompletionParams;
using llvm::lsp::DidChangeTextDocumentParams;
using llvm::lsp::DidCloseTextDocumentParams;
using llvm::lsp::DidOpenTextDocumentParams;
using llvm::lsp::DocumentSymbol;
using llvm::lsp::DocumentSymbolParams;
using llvm::lsp::Hover;
using llvm::lsp::InitializedParams;
using llvm::lsp::InitializeParams;
using llvm::lsp::JSONTransport;
using llvm::lsp::Location;
using llvm::lsp::Logger;
using llvm::lsp::MessageHandler;
using llvm::lsp::MLIRConvertBytecodeParams;
using llvm::lsp::MLIRConvertBytecodeResult;
using llvm::lsp::NoParams;
using llvm::lsp::OutgoingNotification;
using llvm::lsp::PublishDiagnosticsParams;
using llvm::lsp::ReferenceParams;
using llvm::lsp::TextDocumentPositionParams;
using llvm::lsp::TextDocumentSyncKind;
using llvm::lsp::URIForFile;
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//

View File

@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
namespace lsp {
class JSONTransport;
} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
class JSONTransport;
class MLIRServer;
/// Run the main loop of the LSP server using the given MLIR server and
/// transport.
llvm::LogicalResult runMlirLSPServer(MLIRServer &server,
JSONTransport &transport);
llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir

View File

@ -16,10 +16,10 @@
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/SourceMgr.h"
#include <optional>
@ -39,9 +39,9 @@ static std::optional<lsp::Location> getLocationFromLoc(StringRef uriScheme,
llvm::Expected<lsp::URIForFile> sourceURI =
lsp::URIForFile::fromFile(loc.getFilename(), uriScheme);
if (!sourceURI) {
lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
loc.getFilename(),
llvm::toString(sourceURI.takeError()));
llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}",
loc.getFilename(),
llvm::toString(sourceURI.takeError()));
return std::nullopt;
}
@ -217,22 +217,22 @@ static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
// Convert the severity for the diagnostic.
switch (diag.getSeverity()) {
case DiagnosticSeverity::Note:
case mlir::DiagnosticSeverity::Note:
llvm_unreachable("expected notes to be handled separately");
case DiagnosticSeverity::Warning:
lspDiag.severity = lsp::DiagnosticSeverity::Warning;
case mlir::DiagnosticSeverity::Warning:
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
break;
case DiagnosticSeverity::Error:
lspDiag.severity = lsp::DiagnosticSeverity::Error;
case mlir::DiagnosticSeverity::Error:
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
break;
case DiagnosticSeverity::Remark:
lspDiag.severity = lsp::DiagnosticSeverity::Information;
case mlir::DiagnosticSeverity::Remark:
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
break;
}
lspDiag.message = diag.str();
// Attach any notes to the main diagnostic as related information.
std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
std::vector<llvm::lsp::DiagnosticRelatedInformation> relatedDiags;
for (Diagnostic &note : diag.getNotes()) {
lsp::Location noteLoc;
if (std::optional<lsp::Location> loc =
@ -317,7 +317,7 @@ struct MLIRDocument {
void getCodeActionForDiagnostic(const lsp::URIForFile &uri,
lsp::Position &pos, StringRef severity,
StringRef message,
std::vector<lsp::TextEdit> &edits);
std::vector<llvm::lsp::TextEdit> &edits);
//===--------------------------------------------------------------------===//
// Bytecode
@ -355,7 +355,8 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
// Try to parsed the given IR string.
auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file());
if (!memBuffer) {
lsp::Logger::error("Failed to create memory buffer for file", uri.file());
llvm::lsp::Logger::error("Failed to create memory buffer for file",
uri.file());
return;
}
@ -695,8 +696,8 @@ void MLIRDocument::findDocumentSymbols(
if (SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op)) {
symbols.emplace_back(symbol.getName(),
isa<FunctionOpInterface>(op)
? lsp::SymbolKind::Function
: lsp::SymbolKind::Class,
? llvm::lsp::SymbolKind::Function
: llvm::lsp::SymbolKind::Class,
lsp::Range(sourceMgr, def->scopeLoc),
lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;
@ -704,9 +705,9 @@ void MLIRDocument::findDocumentSymbols(
} else if (op->hasTrait<OpTrait::SymbolTable>()) {
// Otherwise, if this is a symbol table push an anonymous document symbol.
symbols.emplace_back("<" + op->getName().getStringRef() + ">",
lsp::SymbolKind::Namespace,
lsp::Range(sourceMgr, def->scopeLoc),
lsp::Range(sourceMgr, def->loc));
llvm::lsp::SymbolKind::Namespace,
llvm::lsp::Range(sourceMgr, def->scopeLoc),
llvm::lsp::Range(sourceMgr, def->loc));
childSymbols = &symbols.back().children;
}
}
@ -734,9 +735,9 @@ public:
/// Signal code completion for a dialect name, with an optional prefix.
void completeDialectName(StringRef prefix) final {
for (StringRef dialect : ctx->getAvailableDialects()) {
lsp::CompletionItem item(prefix + dialect,
lsp::CompletionItemKind::Module,
/*sortText=*/"3");
llvm::lsp::CompletionItem item(prefix + dialect,
llvm::lsp::CompletionItemKind::Module,
/*sortText=*/"3");
item.detail = "dialect";
completionList.items.emplace_back(item);
}
@ -753,9 +754,9 @@ public:
if (&op.getDialect() != dialect)
continue;
lsp::CompletionItem item(
llvm::lsp::CompletionItem item(
op.getStringRef().drop_front(dialectName.size() + 1),
lsp::CompletionItemKind::Field,
llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
item.detail = "operation";
completionList.items.emplace_back(item);
@ -768,7 +769,8 @@ public:
// Check if we need to insert the `%` or not.
bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%';
lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable);
llvm::lsp::CompletionItem item(name,
llvm::lsp::CompletionItemKind::Variable);
if (stripPrefix)
item.insertText = name.drop_front(1).str();
item.detail = std::move(typeData);
@ -781,7 +783,7 @@ public:
// Check if we need to insert the `^` or not.
bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^';
lsp::CompletionItem item(name, lsp::CompletionItemKind::Field);
llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field);
if (stripPrefix)
item.insertText = name.drop_front(1).str();
completionList.items.emplace_back(item);
@ -790,8 +792,9 @@ public:
/// Signal a completion for the given expected token.
void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
for (StringRef token : tokens) {
lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
/*sortText=*/"0");
llvm::lsp::CompletionItem item(token,
llvm::lsp::CompletionItemKind::Keyword,
/*sortText=*/"0");
item.detail = optional ? "optional" : "";
completionList.items.emplace_back(item);
}
@ -802,7 +805,7 @@ public:
appendSimpleCompletions({"affine_set", "affine_map", "dense",
"dense_resource", "false", "loc", "sparse", "true",
"unit"},
lsp::CompletionItemKind::Field,
llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
completeDialectName("#");
@ -820,13 +823,14 @@ public:
appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
"bf16", "f16", "f32", "f64", "f80", "f128",
"index", "none"},
lsp::CompletionItemKind::Field,
llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
// Handle the builtin integer types.
for (StringRef type : {"i", "si", "ui"}) {
lsp::CompletionItem item(type + "<N>", lsp::CompletionItemKind::Field,
/*sortText=*/"1");
llvm::lsp::CompletionItem item(type + "<N>",
llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"1");
item.insertText = type.str();
completionList.items.emplace_back(item);
}
@ -846,9 +850,9 @@ public:
void completeAliases(const llvm::StringMap<T> &aliases,
StringRef prefix = "") {
for (const auto &alias : aliases) {
lsp::CompletionItem item(prefix + alias.getKey(),
lsp::CompletionItemKind::Field,
/*sortText=*/"2");
llvm::lsp::CompletionItem item(prefix + alias.getKey(),
llvm::lsp::CompletionItemKind::Field,
/*sortText=*/"2");
llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
completionList.items.emplace_back(item);
}
@ -856,7 +860,7 @@ public:
/// Add a set of simple completions that all have the same kind.
void appendSimpleCompletions(ArrayRef<StringRef> completions,
lsp::CompletionItemKind kind,
llvm::lsp::CompletionItemKind kind,
StringRef sortText = "") {
for (StringRef completion : completions)
completionList.items.emplace_back(completion, kind, sortText);
@ -897,7 +901,7 @@ MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri,
void MLIRDocument::getCodeActionForDiagnostic(
const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity,
StringRef message, std::vector<lsp::TextEdit> &edits) {
StringRef message, std::vector<llvm::lsp::TextEdit> &edits) {
// Ignore diagnostics that print the current operation. These are always
// enabled for the language server, but not generally during normal
// parsing/verification.
@ -913,7 +917,7 @@ void MLIRDocument::getCodeActionForDiagnostic(
// Add a text edit for adding an expected-* diagnostic check for this
// diagnostic.
lsp::TextEdit edit;
llvm::lsp::TextEdit edit;
edit.range = lsp::Range(lsp::Position(pos.line, 0));
// Use the indent of the current line for the expected-* diagnostic.
@ -937,13 +941,14 @@ MLIRDocument::convertToBytecode() {
// conceptually be relaxed.
if (!llvm::hasSingleElement(parsedIR)) {
if (parsedIR.empty()) {
return llvm::make_error<lsp::LSPError>(
return llvm::make_error<llvm::lsp::LSPError>(
"expected a single and valid top-level operation, please ensure "
"there are no errors",
lsp::ErrorCode::RequestFailed);
llvm::lsp::ErrorCode::RequestFailed);
}
return llvm::make_error<lsp::LSPError>(
"expected a single top-level operation", lsp::ErrorCode::RequestFailed);
return llvm::make_error<llvm::lsp::LSPError>(
"expected a single top-level operation",
llvm::lsp::ErrorCode::RequestFailed);
}
lsp::MLIRConvertBytecodeResult result;
@ -1134,7 +1139,7 @@ void MLIRTextFile::findDocumentSymbols(
lsp::Position endPos((i == e - 1) ? totalNumLines - 1
: chunks[i + 1]->lineOffset);
lsp::DocumentSymbol symbol("<file-split-" + Twine(i) + ">",
lsp::SymbolKind::Namespace,
llvm::lsp::SymbolKind::Namespace,
/*range=*/lsp::Range(startPos, endPos),
/*selectionRange=*/lsp::Range(startPos));
chunk.document.findDocumentSymbols(symbol.children);
@ -1167,10 +1172,10 @@ lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri,
uri, completePos, context.getDialectRegistry());
// Adjust any completion locations.
for (lsp::CompletionItem &item : completionList.items) {
for (llvm::lsp::CompletionItem &item : completionList.items) {
if (item.textEdit)
chunk.adjustLocForChunkOffset(item.textEdit->range);
for (lsp::TextEdit &edit : item.additionalTextEdits)
for (llvm::lsp::TextEdit &edit : item.additionalTextEdits)
chunk.adjustLocForChunkOffset(edit.range);
}
return completionList;
@ -1194,10 +1199,10 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
StringRef severity;
switch (diag.severity) {
case lsp::DiagnosticSeverity::Error:
case llvm::lsp::DiagnosticSeverity::Error:
severity = "error";
break;
case lsp::DiagnosticSeverity::Warning:
case llvm::lsp::DiagnosticSeverity::Warning:
severity = "warning";
break;
default:
@ -1205,7 +1210,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
}
// Get edits for the diagnostic.
std::vector<lsp::TextEdit> edits;
std::vector<llvm::lsp::TextEdit> edits;
chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity,
diag.message, edits);
@ -1221,7 +1226,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri,
}
}
// Fixup the locations for any edits.
for (lsp::TextEdit &edit : edits)
for (llvm::lsp::TextEdit &edit : edits)
chunk.adjustLocForChunkOffset(edit.range);
action.edit.emplace();
@ -1236,9 +1241,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult>
MLIRTextFile::convertToBytecode() {
// Bail out if there is more than one chunk, bytecode wants a single module.
if (chunks.size() != 1) {
return llvm::make_error<lsp::LSPError>(
return llvm::make_error<llvm::lsp::LSPError>(
"unexpected split file, please remove all `// -----`",
lsp::ErrorCode::RequestFailed);
llvm::lsp::ErrorCode::RequestFailed);
}
return chunks.front()->document.convertToBytecode();
}
@ -1283,7 +1288,7 @@ lsp::MLIRServer::~MLIRServer() = default;
void lsp::MLIRServer::addOrUpdateDocument(
const URIForFile &uri, StringRef contents, int64_t version,
std::vector<Diagnostic> &diagnostics) {
std::vector<llvm::lsp::Diagnostic> &diagnostics) {
impl->files[uri.file()] = std::make_unique<MLIRTextFile>(
uri, contents, version, impl->registry_fn, diagnostics);
}
@ -1298,17 +1303,17 @@ std::optional<int64_t> lsp::MLIRServer::removeDocument(const URIForFile &uri) {
return version;
}
void lsp::MLIRServer::getLocationsOf(const URIForFile &uri,
const Position &defPos,
std::vector<Location> &locations) {
void lsp::MLIRServer::getLocationsOf(
const URIForFile &uri, const Position &defPos,
std::vector<llvm::lsp::Location> &locations) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->getLocationsOf(uri, defPos, locations);
}
void lsp::MLIRServer::findReferencesOf(const URIForFile &uri,
const Position &pos,
std::vector<Location> &references) {
void lsp::MLIRServer::findReferencesOf(
const URIForFile &uri, const Position &pos,
std::vector<llvm::lsp::Location> &references) {
auto fileIt = impl->files.find(uri.file());
if (fileIt != impl->files.end())
fileIt->second->findReferencesOf(uri, pos, references);
@ -1367,17 +1372,17 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
// Try to parse the given source file.
Block parsedBlock;
if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
return llvm::make_error<lsp::LSPError>(
return llvm::make_error<llvm::lsp::LSPError>(
"failed to parse bytecode source file: " + errorMsg,
lsp::ErrorCode::RequestFailed);
llvm::lsp::ErrorCode::RequestFailed);
}
// TODO: We currently expect a single top-level operation, but this could
// conceptually be relaxed.
if (!llvm::hasSingleElement(parsedBlock)) {
return llvm::make_error<lsp::LSPError>(
return llvm::make_error<llvm::lsp::LSPError>(
"expected bytecode to contain a single top-level operation",
lsp::ErrorCode::RequestFailed);
llvm::lsp::ErrorCode::RequestFailed);
}
// Print the module to a buffer.
@ -1401,9 +1406,9 @@ llvm::Expected<lsp::MLIRConvertBytecodeResult>
lsp::MLIRServer::convertToBytecode(const URIForFile &uri) {
auto fileIt = impl->files.find(uri.file());
if (fileIt == impl->files.end()) {
return llvm::make_error<lsp::LSPError>(
return llvm::make_error<llvm::lsp::LSPError>(
"language server does not contain an entry for this source file",
lsp::ErrorCode::RequestFailed);
llvm::lsp::ErrorCode::RequestFailed);
}
return fileIt->second->convertToBytecode();
}

View File

@ -9,6 +9,7 @@
#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_
#include "Protocol.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h"
#include "llvm/Support/Error.h"
@ -19,16 +20,17 @@ namespace mlir {
class DialectRegistry;
namespace lsp {
struct CodeAction;
struct CodeActionContext;
struct CompletionList;
struct Diagnostic;
struct DocumentSymbol;
struct Hover;
struct Location;
struct MLIRConvertBytecodeResult;
struct Position;
struct Range;
using llvm::lsp::CodeAction;
using llvm::lsp::CodeActionContext;
using llvm::lsp::CompletionList;
using llvm::lsp::Diagnostic;
using llvm::lsp::DocumentSymbol;
using llvm::lsp::Hover;
using llvm::lsp::Location;
using llvm::lsp::MLIRConvertBytecodeResult;
using llvm::lsp::Position;
using llvm::lsp::Range;
using llvm::lsp::URIForFile;
/// This class implements all of the MLIR related functionality necessary for a
/// language server. This class allows for keeping the MLIR specific logic

View File

@ -9,14 +9,18 @@
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
#include "LSPServer.h"
#include "MLIRServer.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::JSONStreamStyle;
using llvm::lsp::JSONTransport;
using llvm::lsp::Logger;
LogicalResult mlir::MlirLspServerMain(int argc, char **argv,
DialectRegistryFn registry_fn) {
llvm::cl::opt<JSONStreamStyle> inputStyle{

View File

@ -13,14 +13,11 @@
#include "Protocol.h"
#include "llvm/Support/JSON.h"
using namespace mlir;
using namespace mlir::lsp;
//===----------------------------------------------------------------------===//
// MLIRConvertBytecodeParams
//===----------------------------------------------------------------------===//
bool mlir::lsp::fromJSON(const llvm::json::Value &value,
bool llvm::lsp::fromJSON(const llvm::json::Value &value,
MLIRConvertBytecodeParams &result,
llvm::json::Path path) {
llvm::json::ObjectMapper o(value, path);
@ -31,6 +28,6 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value,
// MLIRConvertBytecodeResult
//===----------------------------------------------------------------------===//
llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) {
llvm::json::Value llvm::lsp::toJSON(const MLIRConvertBytecodeResult &value) {
return llvm::json::Object{{"output", value.output}};
}

View File

@ -20,9 +20,9 @@
#ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_
#define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/Support/LSP/Protocol.h"
namespace mlir {
namespace llvm {
namespace lsp {
//===----------------------------------------------------------------------===//
// MLIRConvertBytecodeParams
@ -54,6 +54,6 @@ struct MLIRConvertBytecodeResult {
llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value);
} // namespace lsp
} // namespace mlir
} // namespace llvm
#endif

View File

@ -7,6 +7,9 @@ llvm_add_library(MLIRPdllLspServerLib
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server
LINK_COMPONENTS
SupportLSP
LINK_LIBS PUBLIC
MLIRPDLLCodeGen
MLIRPDLLParser

View File

@ -10,8 +10,9 @@
#include "PDLLServer.h"
#include "Protocol.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/LSP/Transport.h"
#include <optional>
#define DEBUG_TYPE "pdll-lsp-server"
@ -19,6 +20,30 @@
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::Callback;
using llvm::lsp::CompletionList;
using llvm::lsp::CompletionParams;
using llvm::lsp::DidChangeTextDocumentParams;
using llvm::lsp::DidCloseTextDocumentParams;
using llvm::lsp::DidOpenTextDocumentParams;
using llvm::lsp::DocumentLinkParams;
using llvm::lsp::DocumentSymbol;
using llvm::lsp::DocumentSymbolParams;
using llvm::lsp::Hover;
using llvm::lsp::InitializedParams;
using llvm::lsp::InitializeParams;
using llvm::lsp::InlayHintsParams;
using llvm::lsp::JSONTransport;
using llvm::lsp::Location;
using llvm::lsp::Logger;
using llvm::lsp::MessageHandler;
using llvm::lsp::NoParams;
using llvm::lsp::OutgoingNotification;
using llvm::lsp::PublishDiagnosticsParams;
using llvm::lsp::ReferenceParams;
using llvm::lsp::TextDocumentPositionParams;
using llvm::lsp::TextDocumentSyncKind;
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//

View File

@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
namespace lsp {
class JSONTransport;
} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
class JSONTransport;
class PDLLServer;
/// Run the main loop of the LSP server using the given PDLL server and
/// transport.
llvm::LogicalResult runPdllLSPServer(PDLLServer &server,
JSONTransport &transport);
llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir

View File

@ -9,14 +9,17 @@
#include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h"
#include "LSPServer.h"
#include "PDLLServer.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::JSONStreamStyle;
using llvm::lsp::Logger;
LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",
@ -72,7 +75,8 @@ LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) {
// Configure the transport used for communication.
llvm::sys::ChangeStdinToBinary();
JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint);
llvm::lsp::JSONTransport transport(stdin, llvm::outs(), inputStyle,
prettyPrint);
// Configure the servers and start the main language server.
PDLLServer::Options options(compilationDatabases, extraIncludeDirs);

File diff suppressed because it is too large Load Diff

View File

@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LSP/Protocol.h"
#include <memory>
#include <optional>
#include <string>
@ -18,21 +19,22 @@
namespace mlir {
namespace lsp {
struct Diagnostic;
using llvm::lsp::CompletionList;
using llvm::lsp::Diagnostic;
using llvm::lsp::DocumentLink;
using llvm::lsp::DocumentSymbol;
using llvm::lsp::Hover;
using llvm::lsp::InlayHint;
using llvm::lsp::Location;
using llvm::lsp::Position;
using llvm::lsp::Range;
using llvm::lsp::SignatureHelp;
using llvm::lsp::TextDocumentContentChangeEvent;
using llvm::lsp::URIForFile;
class CompilationDatabase;
struct PDLLViewOutputResult;
enum class PDLLViewOutputKind;
struct CompletionList;
struct DocumentLink;
struct DocumentSymbol;
struct Hover;
struct InlayHint;
struct Location;
struct Position;
struct Range;
struct SignatureHelp;
struct TextDocumentContentChangeEvent;
class URIForFile;
/// This class implements all of the PDLL related functionality necessary for a
/// language server. This class allows for keeping the PDLL specific logic

View File

@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "Protocol.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/JSON.h"

View File

@ -20,10 +20,12 @@
#ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
#define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "llvm/Support/LSP/Protocol.h"
namespace mlir {
namespace lsp {
using llvm::lsp::URIForFile;
//===----------------------------------------------------------------------===//
// PDLLViewOutputParams
//===----------------------------------------------------------------------===//

View File

@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS
Demangle
Support
TableGen
SupportLSP
)
llvm_add_library(TableGenLspServerLib

View File

@ -9,14 +9,33 @@
#include "LSPServer.h"
#include "TableGenServer.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/LSP/Transport.h"
#include <optional>
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::Callback;
using llvm::lsp::DidChangeTextDocumentParams;
using llvm::lsp::DidCloseTextDocumentParams;
using llvm::lsp::DidOpenTextDocumentParams;
using llvm::lsp::DocumentLinkParams;
using llvm::lsp::Hover;
using llvm::lsp::InitializedParams;
using llvm::lsp::InitializeParams;
using llvm::lsp::JSONTransport;
using llvm::lsp::Location;
using llvm::lsp::Logger;
using llvm::lsp::MessageHandler;
using llvm::lsp::NoParams;
using llvm::lsp::OutgoingNotification;
using llvm::lsp::PublishDiagnosticsParams;
using llvm::lsp::ReferenceParams;
using llvm::lsp::TextDocumentPositionParams;
using llvm::lsp::TextDocumentSyncKind;
//===----------------------------------------------------------------------===//
// LSPServer
//===----------------------------------------------------------------------===//

View File

@ -13,17 +13,19 @@
namespace llvm {
struct LogicalResult;
namespace lsp {
class JSONTransport;
} // namespace lsp
} // namespace llvm
namespace mlir {
namespace lsp {
class JSONTransport;
class TableGenServer;
/// Run the main loop of the LSP server using the given TableGen server and
/// transport.
llvm::LogicalResult runTableGenLSPServer(TableGenServer &server,
JSONTransport &transport);
llvm::lsp::JSONTransport &transport);
} // namespace lsp
} // namespace mlir

View File

@ -9,14 +9,18 @@
#include "mlir/Tools/tblgen-lsp-server/TableGenLspServerMain.h"
#include "LSPServer.h"
#include "TableGenServer.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Transport.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Transport.h"
#include "llvm/Support/Program.h"
using namespace mlir;
using namespace mlir::lsp;
using llvm::lsp::JSONStreamStyle;
using llvm::lsp::JSONTransport;
using llvm::lsp::Logger;
LogicalResult mlir::TableGenLspServerMain(int argc, char **argv) {
llvm::cl::opt<JSONStreamStyle> inputStyle{
"input-style",

View File

@ -10,12 +10,12 @@
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Tools/lsp-server-support/CompilationDatabase.h"
#include "mlir/Tools/lsp-server-support/Logging.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/lsp-server-support/SourceMgrUtils.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/LSP/Logging.h"
#include "llvm/Support/LSP/Protocol.h"
#include "llvm/Support/Path.h"
#include "llvm/TableGen/Parser.h"
#include "llvm/TableGen/Record.h"
@ -36,45 +36,49 @@ static SMRange convertTokenLocToRange(SMLoc loc) {
/// Returns a language server uri for the given source location. `mainFileURI`
/// corresponds to the uri for the main file of the source manager.
static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &mainFileURI) {
static llvm::lsp::URIForFile
getURIFromLoc(const SourceMgr &mgr, SMLoc loc,
const llvm::lsp::URIForFile &mainFileURI) {
int bufferId = mgr.FindBufferContainingLoc(loc);
if (bufferId == 0 || bufferId == static_cast<int>(mgr.getMainFileID()))
return mainFileURI;
llvm::Expected<lsp::URIForFile> fileForLoc = lsp::URIForFile::fromFile(
mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
llvm::Expected<llvm::lsp::URIForFile> fileForLoc =
llvm::lsp::URIForFile::fromFile(
mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier());
if (fileForLoc)
return *fileForLoc;
lsp::Logger::error("Failed to create URI for include file: {0}",
llvm::toString(fileForLoc.takeError()));
llvm::lsp::Logger::error("Failed to create URI for include file: {0}",
llvm::toString(fileForLoc.takeError()));
return mainFileURI;
}
/// Returns a language server location from the given source range.
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc,
const lsp::URIForFile &uri) {
return lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
lsp::Range(mgr, loc));
static llvm::lsp::Location
getLocationFromLoc(SourceMgr &mgr, SMRange loc,
const llvm::lsp::URIForFile &uri) {
return llvm::lsp::Location(getURIFromLoc(mgr, loc.Start, uri),
llvm::lsp::Range(mgr, loc));
}
static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
const lsp::URIForFile &uri) {
static llvm::lsp::Location
getLocationFromLoc(SourceMgr &mgr, SMLoc loc,
const llvm::lsp::URIForFile &uri) {
return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri);
}
/// Convert the given TableGen diagnostic to the LSP form.
static std::optional<lsp::Diagnostic>
static std::optional<llvm::lsp::Diagnostic>
getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
const lsp::URIForFile &uri) {
const llvm::lsp::URIForFile &uri) {
auto *sourceMgr = const_cast<SourceMgr *>(diag.getSourceMgr());
if (!sourceMgr || !diag.getLoc().isValid())
return std::nullopt;
lsp::Diagnostic lspDiag;
llvm::lsp::Diagnostic lspDiag;
lspDiag.source = "tablegen";
lspDiag.category = "Parse Error";
// Try to grab a file location for this diagnostic.
lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri);
llvm::lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri);
lspDiag.range = loc.range;
// Skip diagnostics that weren't emitted within the main file.
@ -84,17 +88,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag,
// Convert the severity for the diagnostic.
switch (diag.getKind()) {
case SourceMgr::DK_Warning:
lspDiag.severity = lsp::DiagnosticSeverity::Warning;
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning;
break;
case SourceMgr::DK_Error:
lspDiag.severity = lsp::DiagnosticSeverity::Error;
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error;
break;
case SourceMgr::DK_Note:
// Notes are emitted separately from the main diagnostic, so we just treat
// them as remarks given that we can't determine the diagnostic to relate
// them to.
case SourceMgr::DK_Remark:
lspDiag.severity = lsp::DiagnosticSeverity::Information;
lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information;
break;
}
lspDiag.message = diag.getMessage().str();
@ -322,54 +326,59 @@ namespace {
/// This class represents a text file containing one or more TableGen documents.
class TableGenTextFile {
public:
TableGenTextFile(const lsp::URIForFile &uri, StringRef fileContents,
TableGenTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents,
int64_t version,
const std::vector<std::string> &extraIncludeDirs,
std::vector<lsp::Diagnostic> &diagnostics);
std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// Return the current version of this text file.
int64_t getVersion() const { return version; }
/// Update the file to the new version using the provided set of content
/// changes. Returns failure if the update was unsuccessful.
LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion,
ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
std::vector<lsp::Diagnostic> &diagnostics);
LogicalResult
update(const llvm::lsp::URIForFile &uri, int64_t newVersion,
ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
std::vector<llvm::lsp::Diagnostic> &diagnostics);
//===--------------------------------------------------------------------===//
// Definitions and References
//===--------------------------------------------------------------------===//
void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos,
std::vector<lsp::Location> &locations);
void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos,
std::vector<lsp::Location> &references);
void getLocationsOf(const llvm::lsp::URIForFile &uri,
const llvm::lsp::Position &defPos,
std::vector<llvm::lsp::Location> &locations);
void findReferencesOf(const llvm::lsp::URIForFile &uri,
const llvm::lsp::Position &pos,
std::vector<llvm::lsp::Location> &references);
//===--------------------------------------------------------------------===//
// Document Links
//===--------------------------------------------------------------------===//
void getDocumentLinks(const lsp::URIForFile &uri,
std::vector<lsp::DocumentLink> &links);
void getDocumentLinks(const llvm::lsp::URIForFile &uri,
std::vector<llvm::lsp::DocumentLink> &links);
//===--------------------------------------------------------------------===//
// Hover
//===--------------------------------------------------------------------===//
std::optional<lsp::Hover> findHover(const lsp::URIForFile &uri,
const lsp::Position &hoverPos);
lsp::Hover buildHoverForRecord(const Record *record,
const SMRange &hoverRange);
lsp::Hover buildHoverForTemplateArg(const Record *record,
std::optional<llvm::lsp::Hover>
findHover(const llvm::lsp::URIForFile &uri,
const llvm::lsp::Position &hoverPos);
llvm::lsp::Hover buildHoverForRecord(const Record *record,
const SMRange &hoverRange);
llvm::lsp::Hover buildHoverForTemplateArg(const Record *record,
const RecordVal *value,
const SMRange &hoverRange);
llvm::lsp::Hover buildHoverForField(const Record *record,
const RecordVal *value,
const SMRange &hoverRange);
lsp::Hover buildHoverForField(const Record *record, const RecordVal *value,
const SMRange &hoverRange);
private:
/// Initialize the text file from the given file contents.
void initialize(const lsp::URIForFile &uri, int64_t newVersion,
std::vector<lsp::Diagnostic> &diagnostics);
void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion,
std::vector<llvm::lsp::Diagnostic> &diagnostics);
/// The full string contents of the file.
std::string contents;
@ -395,9 +404,9 @@ private:
} // namespace
TableGenTextFile::TableGenTextFile(
const lsp::URIForFile &uri, StringRef fileContents, int64_t version,
const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version,
const std::vector<std::string> &extraIncludeDirs,
std::vector<lsp::Diagnostic> &diagnostics)
std::vector<llvm::lsp::Diagnostic> &diagnostics)
: contents(fileContents.str()), version(version) {
// Build the set of include directories for this file.
llvm::SmallString<32> uriDirectory(uri.file());
@ -409,12 +418,13 @@ TableGenTextFile::TableGenTextFile(
initialize(uri, version, diagnostics);
}
LogicalResult
TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
ArrayRef<lsp::TextDocumentContentChangeEvent> changes,
std::vector<lsp::Diagnostic> &diagnostics) {
if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) {
lsp::Logger::error("Failed to update contents of {0}", uri.file());
LogicalResult TableGenTextFile::update(
const llvm::lsp::URIForFile &uri, int64_t newVersion,
ArrayRef<llvm::lsp::TextDocumentContentChangeEvent> changes,
std::vector<llvm::lsp::Diagnostic> &diagnostics) {
if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes,
contents))) {
llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file());
return failure();
}
@ -423,9 +433,9 @@ TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion,
return success();
}
void TableGenTextFile::initialize(const lsp::URIForFile &uri,
int64_t newVersion,
std::vector<lsp::Diagnostic> &diagnostics) {
void TableGenTextFile::initialize(
const llvm::lsp::URIForFile &uri, int64_t newVersion,
std::vector<llvm::lsp::Diagnostic> &diagnostics) {
version = newVersion;
sourceMgr = SourceMgr();
recordKeeper = std::make_unique<RecordKeeper>();
@ -433,7 +443,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// Build a buffer for this file.
auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file());
if (!memBuffer) {
lsp::Logger::error("Failed to create memory buffer for file", uri.file());
llvm::lsp::Logger::error("Failed to create memory buffer for file",
uri.file());
return;
}
sourceMgr.setIncludeDirs(includeDirs);
@ -442,8 +453,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// This class provides a context argument for the SourceMgr diagnostic
// handler.
struct DiagHandlerContext {
std::vector<lsp::Diagnostic> &diagnostics;
const lsp::URIForFile &uri;
std::vector<llvm::lsp::Diagnostic> &diagnostics;
const llvm::lsp::URIForFile &uri;
} handlerContext{diagnostics, uri};
// Set the diagnostic handler for the tablegen source manager.
@ -469,9 +480,9 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri,
// TableGenTextFile: Definitions and References
//===----------------------------------------------------------------------===//
void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri,
const lsp::Position &defPos,
std::vector<lsp::Location> &locations) {
void TableGenTextFile::getLocationsOf(
const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &defPos,
std::vector<llvm::lsp::Location> &locations) {
SMLoc posLoc = defPos.getAsSMLoc(sourceMgr);
const TableGenIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@ -492,8 +503,8 @@ void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri,
}
void TableGenTextFile::findReferencesOf(
const lsp::URIForFile &uri, const lsp::Position &pos,
std::vector<lsp::Location> &references) {
const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos,
std::vector<llvm::lsp::Location> &references) {
SMLoc posLoc = pos.getAsSMLoc(sourceMgr);
const TableGenIndexSymbol *symbol = index.lookup(posLoc);
if (!symbol)
@ -508,8 +519,9 @@ void TableGenTextFile::findReferencesOf(
// TableGenTextFile: Document Links
//===--------------------------------------------------------------------===//
void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri,
std::vector<lsp::DocumentLink> &links) {
void TableGenTextFile::getDocumentLinks(
const llvm::lsp::URIForFile &uri,
std::vector<llvm::lsp::DocumentLink> &links) {
for (const lsp::SourceMgrInclude &include : parsedIncludes)
links.emplace_back(include.range, include.uri);
}
@ -518,9 +530,9 @@ void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri,
// TableGenTextFile: Hover
//===----------------------------------------------------------------------===//
std::optional<lsp::Hover>
TableGenTextFile::findHover(const lsp::URIForFile &uri,
const lsp::Position &hoverPos) {
std::optional<llvm::lsp::Hover>
TableGenTextFile::findHover(const llvm::lsp::URIForFile &uri,
const llvm::lsp::Position &hoverPos) {
// Check for a reference to an include.
for (const lsp::SourceMgrInclude &include : parsedIncludes)
if (include.range.contains(hoverPos))
@ -546,9 +558,10 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri,
return buildHoverForField(recordVal->record, value, hoverRange);
}
lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::lsp::Hover
TableGenTextFile::buildHoverForRecord(const Record *record,
const SMRange &hoverRange) {
llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
@ -590,9 +603,9 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record,
return hover;
}
lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
llvm::lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
const Record *record, const RecordVal *value, const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
StringRef name = value->getName().rsplit(':').second;
@ -604,10 +617,9 @@ lsp::Hover TableGenTextFile::buildHoverForTemplateArg(
return hover;
}
lsp::Hover TableGenTextFile::buildHoverForField(const Record *record,
const RecordVal *value,
const SMRange &hoverRange) {
lsp::Hover hover(lsp::Range(sourceMgr, hoverRange));
llvm::lsp::Hover TableGenTextFile::buildHoverForField(
const Record *record, const RecordVal *value, const SMRange &hoverRange) {
llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange));
{
llvm::raw_string_ostream hoverOS(hover.contents.value);
hoverOS << "**field** `" << value->getName() << "`\n***\nType: `";
@ -722,7 +734,7 @@ void lsp::TableGenServer::getDocumentLinks(
return fileIt->second->getDocumentLinks(uri, documentLinks);
}
std::optional<lsp::Hover>
std::optional<llvm::lsp::Hover>
lsp::TableGenServer::findHover(const URIForFile &uri,
const Position &hoverPos) {
auto fileIt = impl->files.find(uri.file());

View File

@ -11,6 +11,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LSP/Protocol.h"
#include <memory>
#include <optional>
#include <string>
@ -18,13 +19,13 @@
namespace mlir {
namespace lsp {
struct Diagnostic;
struct DocumentLink;
struct Hover;
struct Location;
struct Position;
struct TextDocumentContentChangeEvent;
class URIForFile;
using llvm::lsp::Diagnostic;
using llvm::lsp::DocumentLink;
using llvm::lsp::Hover;
using llvm::lsp::Location;
using llvm::lsp::Position;
using llvm::lsp::TextDocumentContentChangeEvent;
using llvm::lsp::URIForFile;
/// This class implements all of the TableGen related functionality necessary
/// for a language server. This class allows for keeping the TableGen specific

View File

@ -10,8 +10,8 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/Tools/lsp-server-support/Protocol.h"
#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h"
#include "llvm/Support/LSP/Protocol.h"
using namespace mlir;
@ -37,8 +37,8 @@ int main(int argc, char **argv) {
// Returns the registry, except in testing mode when the URI contains
// "-disable-lsp-registration". Testing for/example of registering dialects
// based on URI.
auto registryFn = [&registry,
&empty](const lsp::URIForFile &uri) -> DialectRegistry & {
auto registryFn = [&registry, &empty](
const llvm::lsp::URIForFile &uri) -> DialectRegistry & {
(void)empty;
#ifdef MLIR_INCLUDE_TESTS
if (uri.uri().contains("-disable-lsp-registration"))

View File

@ -18,7 +18,6 @@ add_subdirectory(Support)
add_subdirectory(Rewrite)
add_subdirectory(TableGen)
add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Transforms)
if(MLIR_ENABLE_EXECUTION_ENGINE)

View File

@ -1 +0,0 @@
add_subdirectory(lsp-server-support)

View File

@ -1,7 +0,0 @@
add_mlir_unittest(MLIRLspServerSupportTests
Protocol.cpp
Transport.cpp
)
mlir_target_link_libraries(MLIRLspServerSupportTests
PRIVATE
MLIRLspServerSupportLib)