From a3a25996b11401f7589d1429225dc048d8720da9 Mon Sep 17 00:00:00 2001 From: Bertik23 <39457484+Bertik23@users.noreply.github.com> Date: Thu, 11 Sep 2025 19:17:52 +0200 Subject: [PATCH] [LLVM][MLIR] Move LSP server support library from MLIR into LLVM (#157885) This is a second PR on this patch (first #155572), that fixes the linking problem for `flang-aarch64-dylib` test. The SupportLSP library was made a component library. --- This PR moves the generic Language Server Protocol (LSP) server support code that was copied from clangd into MLIR, into the LLVM tree so it can be reused by multiple subprojects. Centralizing the generic LSP support in LLVM lowers the barrier to building new LSP servers across the LLVM ecosystem and avoids each subproject maintaining its own copy. The code originated in clangd and was copied into MLIR for its LSP server. MLIR had this code seperate to be reused by all of their LSP server. This PR relocates the MLIR copy into LLVM as a shared component into LLVM/Support. If this is not a suitable place, please suggest a better one. A follow up to this move could be deduplication with the original clangd implementation and converge on a single shared LSP support library used by clangd, MLIR, and future servers. What changes mlir/include/mlir/Tools/lsp-server-support/{Logging, Protocol, Transport}.h moved to llvm/include/llvm/Support/LSP mlir/lib/Tools/lsp-server-support/{Logging, Protocol, Transport}.cpp moved to llvm/lib/Support/LSP and their namespace was changed from mlir to llvm I ran clang-tidy --fix and clang-format on the whole moved files (last two commits), as they are basically new files and should hold up to the code style used by LLVM. MLIR LSP servers where updated to include these files from their new location and account for the namespace change. This PR is made as part of the LLVM IR LSP project ([RFC](https://discourse.llvm.org/t/rfc-ir-visualization-with-vs-code-extension-using-an-lsp-server/87773)) --- .../include/llvm/Support/LSP}/Logging.h | 38 +- .../include/llvm/Support/LSP}/Protocol.h | 23 +- llvm/include/llvm/Support/LSP/Transport.h | 289 +++++ llvm/lib/Support/CMakeLists.txt | 1 + llvm/lib/Support/LSP/CMakeLists.txt | 8 + .../lib/Support/LSP}/Logging.cpp | 28 +- llvm/lib/Support/LSP/Protocol.cpp | 1043 +++++++++++++++++ llvm/lib/Support/LSP/Transport.cpp | 369 ++++++ llvm/unittests/Support/CMakeLists.txt | 2 + llvm/unittests/Support/LSP/CMakeLists.txt | 8 + .../unittests/Support/LSP}/Protocol.cpp | 6 +- .../unittests/Support/LSP}/Transport.cpp | 12 +- .../Tools/lsp-server-support/SourceMgrUtils.h | 12 +- .../mlir/Tools/lsp-server-support/Transport.h | 283 ----- .../mlir-lsp-server/MlirLspRegistryFunction.h | 6 +- .../Tools/lsp-server-support/CMakeLists.txt | 8 +- .../CompilationDatabase.cpp | 5 +- .../lib/Tools/lsp-server-support/Protocol.cpp | 1043 ----------------- .../lsp-server-support/SourceMgrUtils.cpp | 4 + .../Tools/lsp-server-support/Transport.cpp | 369 ------ mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt | 3 + mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp | 31 +- mlir/lib/Tools/mlir-lsp-server/LSPServer.h | 6 +- mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp | 133 ++- mlir/lib/Tools/mlir-lsp-server/MLIRServer.h | 22 +- .../mlir-lsp-server/MlirLspServerMain.cpp | 8 +- mlir/lib/Tools/mlir-lsp-server/Protocol.cpp | 7 +- mlir/lib/Tools/mlir-lsp-server/Protocol.h | 6 +- .../Tools/mlir-pdll-lsp-server/CMakeLists.txt | 3 + .../Tools/mlir-pdll-lsp-server/LSPServer.cpp | 29 +- .../Tools/mlir-pdll-lsp-server/LSPServer.h | 6 +- .../MlirPdllLspServerMain.cpp | 10 +- .../Tools/mlir-pdll-lsp-server/PDLLServer.cpp | 561 ++++----- .../Tools/mlir-pdll-lsp-server/PDLLServer.h | 26 +- .../Tools/mlir-pdll-lsp-server/Protocol.cpp | 1 + .../lib/Tools/mlir-pdll-lsp-server/Protocol.h | 4 +- .../Tools/tblgen-lsp-server/CMakeLists.txt | 1 + .../lib/Tools/tblgen-lsp-server/LSPServer.cpp | 25 +- mlir/lib/Tools/tblgen-lsp-server/LSPServer.h | 6 +- .../TableGenLspServerMain.cpp | 8 +- .../tblgen-lsp-server/TableGenServer.cpp | 162 +-- .../Tools/tblgen-lsp-server/TableGenServer.h | 15 +- .../tools/mlir-lsp-server/mlir-lsp-server.cpp | 6 +- mlir/unittests/CMakeLists.txt | 1 - mlir/unittests/Tools/CMakeLists.txt | 1 - .../Tools/lsp-server-support/CMakeLists.txt | 7 - 46 files changed, 2413 insertions(+), 2232 deletions(-) rename {mlir/include/mlir/Tools/lsp-server-support => llvm/include/llvm/Support/LSP}/Logging.h (55%) rename {mlir/include/mlir/Tools/lsp-server-support => llvm/include/llvm/Support/LSP}/Protocol.h (98%) create mode 100644 llvm/include/llvm/Support/LSP/Transport.h create mode 100644 llvm/lib/Support/LSP/CMakeLists.txt rename {mlir/lib/Tools/lsp-server-support => llvm/lib/Support/LSP}/Logging.cpp (55%) create mode 100644 llvm/lib/Support/LSP/Protocol.cpp create mode 100644 llvm/lib/Support/LSP/Transport.cpp create mode 100644 llvm/unittests/Support/LSP/CMakeLists.txt rename {mlir/unittests/Tools/lsp-server-support => llvm/unittests/Support/LSP}/Protocol.cpp (93%) rename {mlir/unittests/Tools/lsp-server-support => llvm/unittests/Support/LSP}/Transport.cpp (96%) delete mode 100644 mlir/include/mlir/Tools/lsp-server-support/Transport.h delete mode 100644 mlir/lib/Tools/lsp-server-support/Protocol.cpp delete mode 100644 mlir/lib/Tools/lsp-server-support/Transport.cpp delete mode 100644 mlir/unittests/Tools/CMakeLists.txt delete mode 100644 mlir/unittests/Tools/lsp-server-support/CMakeLists.txt diff --git a/mlir/include/mlir/Tools/lsp-server-support/Logging.h b/llvm/include/llvm/Support/LSP/Logging.h similarity index 55% rename from mlir/include/mlir/Tools/lsp-server-support/Logging.h rename to llvm/include/llvm/Support/LSP/Logging.h index 9b090d05f7fa..fe65899b1d4c 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Logging.h +++ b/llvm/include/llvm/Support/LSP/Logging.h @@ -1,4 +1,4 @@ -//===- Logging.h - MLIR LSP Server Logging ----------------------*- C++ -*-===// +//===- Logging.h - LSP Server Logging ----------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,16 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H +#ifndef LLVM_SUPPORT_LSP_LOGGING_H +#define LLVM_SUPPORT_LSP_LOGGING_H -#include "mlir/Support/LLVM.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include #include -namespace mlir { +namespace llvm { namespace lsp { /// This class represents the main interface for logging, and allows for @@ -26,21 +25,18 @@ public: enum class Level { Debug, Info, Error }; /// Set the severity level of the logger. - static void setLogLevel(Level logLevel); + static void setLogLevel(Level LogLevel); /// Initiate a log message at various severity levels. These should be called /// after a call to `initialize`. - template - static void debug(const char *fmt, Ts &&...vals) { - log(Level::Debug, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void debug(const char *Fmt, Ts &&...Vals) { + log(Level::Debug, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } - template - static void info(const char *fmt, Ts &&...vals) { - log(Level::Info, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void info(const char *Fmt, Ts &&...Vals) { + log(Level::Info, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } - template - static void error(const char *fmt, Ts &&...vals) { - log(Level::Error, fmt, llvm::formatv(fmt, std::forward(vals)...)); + template static void error(const char *Fmt, Ts &&...Vals) { + log(Level::Error, Fmt, llvm::formatv(Fmt, std::forward(Vals)...)); } private: @@ -50,16 +46,16 @@ private: static Logger &get(); /// Start a log message with the given severity level. - static void log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message); + static void log(Level LogLevel, const char *Fmt, + const llvm::formatv_object_base &Message); /// The minimum logging level. Messages with lower level are ignored. - Level logLevel = Level::Error; + Level LogLevel = Level::Error; /// A mutex used to guard logging. - std::mutex mutex; + std::mutex Mutex; }; } // namespace lsp -} // namespace mlir +} // namespace llvm -#endif // MLIR_TOOLS_LSPSERVERSUPPORT_LOGGING_H +#endif // LLVM_SUPPORT_LSP_LOGGING_H diff --git a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h b/llvm/include/llvm/Support/LSP/Protocol.h similarity index 98% rename from mlir/include/mlir/Tools/lsp-server-support/Protocol.h rename to llvm/include/llvm/Support/LSP/Protocol.h index cc06dbfedb42..93b82f1e581f 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/Protocol.h +++ b/llvm/include/llvm/Support/LSP/Protocol.h @@ -20,20 +20,24 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_PROTOCOL_H +#ifndef LLVM_SUPPORT_LSP_PROTOCOL_H +#define LLVM_SUPPORT_LSP_PROTOCOL_H -#include "mlir/Support/LLVM.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include #include #include #include -#include -namespace mlir { +// This file is using the LSP syntax for identifier names which is different +// from the LLVM coding standard. To avoid the clang-tidy warnings, we're +// disabling one check here. +// NOLINTBEGIN(readability-identifier-naming) + +namespace llvm { namespace lsp { enum class ErrorCode { @@ -1241,12 +1245,11 @@ struct CodeAction { llvm::json::Value toJSON(const CodeAction &); } // namespace lsp -} // namespace mlir +} // namespace llvm namespace llvm { -template <> -struct format_provider { - static void format(const mlir::lsp::Position &pos, raw_ostream &os, +template <> struct format_provider { + static void format(const llvm::lsp::Position &pos, raw_ostream &os, StringRef style) { assert(style.empty() && "style modifiers for this type are not supported"); os << pos; @@ -1255,3 +1258,5 @@ struct format_provider { } // namespace llvm #endif + +// NOLINTEND(readability-identifier-naming) diff --git a/llvm/include/llvm/Support/LSP/Transport.h b/llvm/include/llvm/Support/LSP/Transport.h new file mode 100644 index 000000000000..ccd7f213aa27 --- /dev/null +++ b/llvm/include/llvm/Support/LSP/Transport.h @@ -0,0 +1,289 @@ +//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The language server protocol is usually implemented by writing messages as +// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON +// transport interface that handles this communication. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_SUPPORT_LSP_TRANSPORT_H +#define LLVM_SUPPORT_LSP_TRANSPORT_H + +#include "llvm/ADT/FunctionExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/FormatAdapters.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace llvm { +// Simple helper function that returns a string as printed from a op. +template static std::string debugString(T &&Op) { + std::string InstrStr; + llvm::raw_string_ostream Os(InstrStr); + Os << Op; + return Os.str(); +} +namespace lsp { +class MessageHandler; + +//===----------------------------------------------------------------------===// +// JSONTransport +//===----------------------------------------------------------------------===// + +/// The encoding style of the JSON-RPC messages (both input and output). +enum JSONStreamStyle { + /// Encoding per the LSP specification, with mandatory Content-Length header. + Standard, + /// Messages are delimited by a '// -----' line. Comment lines start with //. + Delimited +}; + +/// An abstract class used by the JSONTransport to read JSON message. +class JSONTransportInput { +public: + explicit JSONTransportInput(JSONStreamStyle Style = JSONStreamStyle::Standard) + : Style(Style) {} + virtual ~JSONTransportInput() = default; + + virtual bool hasError() const = 0; + virtual bool isEndOfInput() const = 0; + + /// Read in a message from the input stream. + LogicalResult readMessage(std::string &Json) { + return Style == JSONStreamStyle::Delimited ? readDelimitedMessage(Json) + : readStandardMessage(Json); + } + virtual LogicalResult readDelimitedMessage(std::string &Json) = 0; + virtual LogicalResult readStandardMessage(std::string &Json) = 0; + +private: + /// The JSON stream style to use. + JSONStreamStyle Style; +}; + +/// Concrete implementation of the JSONTransportInput that reads from a file. +class JSONTransportInputOverFile : public JSONTransportInput { +public: + explicit JSONTransportInputOverFile( + std::FILE *In, JSONStreamStyle Style = JSONStreamStyle::Standard) + : JSONTransportInput(Style), In(In) {} + + bool hasError() const final { return ferror(In); } + bool isEndOfInput() const final { return feof(In); } + + LogicalResult readDelimitedMessage(std::string &Json) final; + LogicalResult readStandardMessage(std::string &Json) final; + +private: + std::FILE *In; +}; + +/// A transport class that performs the JSON-RPC communication with the LSP +/// client. +class JSONTransport { +public: + JSONTransport(std::unique_ptr In, raw_ostream &Out, + bool PrettyOutput = false) + : In(std::move(In)), Out(Out), PrettyOutput(PrettyOutput) {} + + JSONTransport(std::FILE *In, raw_ostream &Out, + JSONStreamStyle Style = JSONStreamStyle::Standard, + bool PrettyOutput = false) + : In(std::make_unique(In, Style)), Out(Out), + PrettyOutput(PrettyOutput) {} + + /// The following methods are used to send a message to the LSP client. + void notify(StringRef Method, llvm::json::Value Params); + void call(StringRef Method, llvm::json::Value Params, llvm::json::Value Id); + void reply(llvm::json::Value Id, llvm::Expected Result); + + /// Start executing the JSON-RPC transport. + llvm::Error run(MessageHandler &Handler); + +private: + /// Dispatches the given incoming json message to the message handler. + bool handleMessage(llvm::json::Value Msg, MessageHandler &Handler); + /// Writes the given message to the output stream. + void sendMessage(llvm::json::Value Msg); + +private: + /// The input to read a message from. + std::unique_ptr In; + SmallVector OutputBuffer; + /// The output file stream. + raw_ostream &Out; + /// If the output JSON should be formatted for easier readability. + bool PrettyOutput; +}; + +//===----------------------------------------------------------------------===// +// MessageHandler +//===----------------------------------------------------------------------===// + +/// A Callback is a void function that accepts Expected. This is +/// accepted by functions that logically return T. +template +using Callback = llvm::unique_function)>; + +/// An OutgoingNotification is a function used for outgoing notifications +/// send to the client. +template +using OutgoingNotification = llvm::unique_function; + +/// An OutgoingRequest is a function used for outgoing requests to send to +/// the client. +template +using OutgoingRequest = + llvm::unique_function; + +/// An `OutgoingRequestCallback` is invoked when an outgoing request to the +/// client receives a response in turn. It is passed the original request's ID, +/// as well as the response result. +template +using OutgoingRequestCallback = + std::function)>; + +/// A handler used to process the incoming transport messages. +class MessageHandler { +public: + MessageHandler(JSONTransport &Transport) : Transport(Transport) {} + + bool onNotify(StringRef Method, llvm::json::Value Value); + bool onCall(StringRef Method, llvm::json::Value Params, llvm::json::Value Id); + bool onReply(llvm::json::Value Id, llvm::Expected Result); + + template + static llvm::Expected parse(const llvm::json::Value &Raw, + StringRef PayloadName, StringRef PayloadKind) { + T Result; + llvm::json::Path::Root Root; + if (fromJSON(Raw, Result, Root)) + return std::move(Result); + + // Dump the relevant parts of the broken message. + std::string Context; + llvm::raw_string_ostream Os(Context); + Root.printErrorContext(Raw, Os); + + // Report the error (e.g. to the client). + return llvm::make_error( + llvm::formatv("failed to decode {0} {1}: {2}", PayloadName, PayloadKind, + fmt_consume(Root.getError())), + ErrorCode::InvalidParams); + } + + template + void method(llvm::StringLiteral Method, ThisT *ThisPtr, + void (ThisT::*Handler)(const Param &, Callback)) { + MethodHandlers[Method] = [Method, Handler, + ThisPtr](llvm::json::Value RawParams, + Callback Reply) { + llvm::Expected Parameter = + parse(RawParams, Method, "request"); + if (!Parameter) + return Reply(Parameter.takeError()); + (ThisPtr->*Handler)(*Parameter, std::move(Reply)); + }; + } + + template + void notification(llvm::StringLiteral Method, ThisT *ThisPtr, + void (ThisT::*Handler)(const Param &)) { + NotificationHandlers[Method] = [Method, Handler, + ThisPtr](llvm::json::Value RawParams) { + llvm::Expected Parameter = + parse(RawParams, Method, "notification"); + if (!Parameter) { + return llvm::consumeError(llvm::handleErrors( + Parameter.takeError(), [](const LSPError &LspError) { + Logger::error("JSON parsing error: {0}", + LspError.message.c_str()); + })); + } + (ThisPtr->*Handler)(*Parameter); + }; + } + + /// Create an OutgoingNotification object used for the given method. + template + OutgoingNotification outgoingNotification(llvm::StringLiteral Method) { + return [&, Method](const T &Params) { + std::lock_guard TransportLock(TransportOutputMutex); + Logger::info("--> {0}", Method); + Transport.notify(Method, llvm::json::Value(Params)); + }; + } + + /// Create an OutgoingRequest function that, when called, sends a request with + /// the given method via the transport. Should the outgoing request be + /// met with a response, the result JSON is parsed and the response callback + /// is invoked. + template + OutgoingRequest + outgoingRequest(llvm::StringLiteral Method, + OutgoingRequestCallback Callback) { + return [&, Method, Callback](const Param &Parameter, llvm::json::Value Id) { + auto CallbackWrapper = [Method, Callback = std::move(Callback)]( + llvm::json::Value Id, + llvm::Expected Value) { + if (!Value) + return Callback(std::move(Id), Value.takeError()); + + std::string ResponseName = llvm::formatv("reply:{0}({1})", Method, Id); + llvm::Expected ParseResult = + parse(*Value, ResponseName, "response"); + if (!ParseResult) + return Callback(std::move(Id), ParseResult.takeError()); + + return Callback(std::move(Id), *ParseResult); + }; + + { + std::lock_guard Lock(ResponseHandlersMutex); + ResponseHandlers.insert( + {debugString(Id), std::make_pair(Method.str(), CallbackWrapper)}); + } + + std::lock_guard TransportLock(TransportOutputMutex); + Logger::info("--> {0}({1})", Method, Id); + Transport.call(Method, llvm::json::Value(Parameter), Id); + }; + } + +private: + template + using HandlerMap = llvm::StringMap>; + + HandlerMap NotificationHandlers; + HandlerMap)> + MethodHandlers; + + /// A pair of (1) the original request's method name, and (2) the callback + /// function to be invoked for responses. + using ResponseHandlerTy = + std::pair>; + /// A mapping from request/response ID to response handler. + llvm::StringMap ResponseHandlers; + /// Mutex to guard insertion into the response handler map. + std::mutex ResponseHandlersMutex; + + JSONTransport &Transport; + + /// Mutex to guard sending output messages to the transport. + std::mutex TransportOutputMutex; +}; + +} // namespace lsp +} // namespace llvm + +#endif diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt index 2528e8bd1142..7da972f372c5 100644 --- a/llvm/lib/Support/CMakeLists.txt +++ b/llvm/lib/Support/CMakeLists.txt @@ -135,6 +135,7 @@ if (UNIX AND "${CMAKE_SYSTEM_NAME}" MATCHES "AIX") endif() add_subdirectory(BLAKE3) +add_subdirectory(LSP) add_llvm_component_library(LLVMSupport ABIBreak.cpp diff --git a/llvm/lib/Support/LSP/CMakeLists.txt b/llvm/lib/Support/LSP/CMakeLists.txt new file mode 100644 index 000000000000..6094d9ac315c --- /dev/null +++ b/llvm/lib/Support/LSP/CMakeLists.txt @@ -0,0 +1,8 @@ +add_llvm_component_library(LLVMSupportLSP + Protocol.cpp + Transport.cpp + Logging.cpp + + DEPENDS + LLVMSupport +) diff --git a/mlir/lib/Tools/lsp-server-support/Logging.cpp b/llvm/lib/Support/LSP/Logging.cpp similarity index 55% rename from mlir/lib/Tools/lsp-server-support/Logging.cpp rename to llvm/lib/Support/LSP/Logging.cpp index 373e2165c244..b36621ae1c6c 100644 --- a/mlir/lib/Tools/lsp-server-support/Logging.cpp +++ b/llvm/lib/Support/LSP/Logging.cpp @@ -6,36 +6,36 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Logging.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Chrono.h" #include "llvm/Support/raw_ostream.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; -void Logger::setLogLevel(Level logLevel) { get().logLevel = logLevel; } +void Logger::setLogLevel(Level LogLevel) { get().LogLevel = LogLevel; } Logger &Logger::get() { - static Logger logger; - return logger; + static Logger Logger; + return Logger; } -void Logger::log(Level logLevel, const char *fmt, - const llvm::formatv_object_base &message) { - Logger &logger = get(); +void Logger::log(Level LogLevel, const char *Fmt, + const llvm::formatv_object_base &Message) { + Logger &Logger = get(); // Ignore messages with log levels below the current setting in the logger. - if (logLevel < logger.logLevel) + if (LogLevel < Logger.LogLevel) return; // An indicator character for each log level. - const char *logLevelIndicators = "DIE"; + const char *LogLevelIndicators = "DIE"; // Format the message and print to errs. - llvm::sys::TimePoint<> timestamp = std::chrono::system_clock::now(); - std::lock_guard logGuard(logger.mutex); + llvm::sys::TimePoint<> Timestamp = std::chrono::system_clock::now(); + std::lock_guard LogGuard(Logger.Mutex); llvm::errs() << llvm::formatv( "{0}[{1:%H:%M:%S.%L}] {2}\n", - logLevelIndicators[static_cast(logLevel)], timestamp, message); + LogLevelIndicators[static_cast(LogLevel)], Timestamp, Message); llvm::errs().flush(); } diff --git a/llvm/lib/Support/LSP/Protocol.cpp b/llvm/lib/Support/LSP/Protocol.cpp new file mode 100644 index 000000000000..f22126345a43 --- /dev/null +++ b/llvm/lib/Support/LSP/Protocol.cpp @@ -0,0 +1,1043 @@ +//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the serialization code for the LSP structs. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace llvm::lsp; + +// Helper that doesn't treat `null` and absent fields as failures. +template +static bool mapOptOrNull(const llvm::json::Value &Params, + llvm::StringLiteral Prop, T &Out, + llvm::json::Path Path) { + const llvm::json::Object *O = Params.getAsObject(); + assert(O); + + // Field is missing or null. + auto *V = O->get(Prop); + if (!V || V->getAsNull()) + return true; + return fromJSON(*V, Out, Path.field(Prop)); +} + +//===----------------------------------------------------------------------===// +// LSPError +//===----------------------------------------------------------------------===// + +char LSPError::ID; + +//===----------------------------------------------------------------------===// +// URIForFile +//===----------------------------------------------------------------------===// + +static bool isWindowsPath(StringRef Path) { + return Path.size() > 1 && llvm::isAlpha(Path[0]) && Path[1] == ':'; +} + +static bool isNetworkPath(StringRef Path) { + return Path.size() > 2 && Path[0] == Path[1] && + llvm::sys::path::is_separator(Path[0]); +} + +static bool shouldEscapeInURI(unsigned char C) { + // Unreserved characters. + if ((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') || + (C >= '0' && C <= '9')) + return false; + + switch (C) { + case '-': + case '_': + case '.': + case '~': + // '/' is only reserved when parsing. + case '/': + // ':' is only reserved for relative URI paths, which we doesn't produce. + case ':': + return false; + } + return true; +} + +/// Encodes a string according to percent-encoding. +/// - Unreserved characters are not escaped. +/// - Reserved characters always escaped with exceptions like '/'. +/// - All other characters are escaped. +static void percentEncode(StringRef Content, std::string &Out) { + for (unsigned char C : Content) { + if (shouldEscapeInURI(C)) { + Out.push_back('%'); + Out.push_back(llvm::hexdigit(C / 16)); + Out.push_back(llvm::hexdigit(C % 16)); + } else { + Out.push_back(C); + } + } +} + +/// Decodes a string according to percent-encoding. +static std::string percentDecode(StringRef Content) { + std::string Result; + for (auto I = Content.begin(), E = Content.end(); I != E; ++I) { + if (*I != '%') { + Result += *I; + continue; + } + if (*I == '%' && I + 2 < Content.end() && llvm::isHexDigit(*(I + 1)) && + llvm::isHexDigit(*(I + 2))) { + Result.push_back(llvm::hexFromNibbles(*(I + 1), *(I + 2))); + I += 2; + } else { + Result.push_back(*I); + } + } + return Result; +} + +/// Return the set containing the supported URI schemes. +static StringSet<> &getSupportedSchemes() { + static StringSet<> Schemes({"file", "test"}); + return Schemes; +} + +/// Returns true if the given scheme is structurally valid, i.e. it does not +/// contain any invalid scheme characters. This does not check that the scheme +/// is actually supported. +static bool isStructurallyValidScheme(StringRef Scheme) { + if (Scheme.empty()) + return false; + if (!llvm::isAlpha(Scheme[0])) + return false; + return llvm::all_of(llvm::drop_begin(Scheme), [](char C) { + return llvm::isAlnum(C) || C == '+' || C == '.' || C == '-'; + }); +} + +static llvm::Expected uriFromAbsolutePath(StringRef AbsolutePath, + StringRef Scheme) { + std::string Body; + StringRef Authority; + StringRef Root = llvm::sys::path::root_name(AbsolutePath); + if (isNetworkPath(Root)) { + // Windows UNC paths e.g. \\server\share => file://server/share + Authority = Root.drop_front(2); + AbsolutePath.consume_front(Root); + } else if (isWindowsPath(Root)) { + // Windows paths e.g. X:\path => file:///X:/path + Body = "/"; + } + Body += llvm::sys::path::convert_to_slash(AbsolutePath); + + std::string Uri = Scheme.str() + ":"; + if (Authority.empty() && Body.empty()) + return Uri; + + // If authority if empty, we only print body if it starts with "/"; otherwise, + // the URI is invalid. + if (!Authority.empty() || StringRef(Body).starts_with("/")) { + Uri.append("//"); + percentEncode(Authority, Uri); + } + percentEncode(Body, Uri); + return Uri; +} + +static llvm::Expected getAbsolutePath(StringRef Authority, + StringRef Body) { + if (!Body.starts_with("/")) + return llvm::createStringError( + llvm::inconvertibleErrorCode(), + "File scheme: expect body to be an absolute path starting " + "with '/': " + + Body); + SmallString<128> Path; + if (!Authority.empty()) { + // Windows UNC paths e.g. file://server/share => \\server\share + ("//" + Authority).toVector(Path); + } else if (isWindowsPath(Body.substr(1))) { + // Windows paths e.g. file:///X:/path => X:\path + Body.consume_front("/"); + } + Path.append(Body); + llvm::sys::path::native(Path); + return std::string(Path); +} + +static llvm::Expected parseFilePathFromURI(StringRef OrigUri) { + StringRef Uri = OrigUri; + + // Decode the scheme of the URI. + size_t Pos = Uri.find(':'); + if (Pos == StringRef::npos) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Scheme must be provided in URI: " + + OrigUri); + StringRef SchemeStr = Uri.substr(0, Pos); + std::string UriScheme = percentDecode(SchemeStr); + if (!isStructurallyValidScheme(UriScheme)) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "Invalid scheme: " + SchemeStr + + " (decoded: " + UriScheme + ")"); + Uri = Uri.substr(Pos + 1); + + // Decode the authority of the URI. + std::string UriAuthority; + if (Uri.consume_front("//")) { + Pos = Uri.find('/'); + UriAuthority = percentDecode(Uri.substr(0, Pos)); + Uri = Uri.substr(Pos); + } + + // Decode the body of the URI. + std::string UriBody = percentDecode(Uri); + + // Compute the absolute path for this uri. + if (!getSupportedSchemes().contains(UriScheme)) { + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "unsupported URI scheme `" + UriScheme + + "' for workspace files"); + } + return getAbsolutePath(UriAuthority, UriBody); +} + +llvm::Expected URIForFile::fromURI(StringRef Uri) { + llvm::Expected FilePath = parseFilePathFromURI(Uri); + if (!FilePath) + return FilePath.takeError(); + return URIForFile(std::move(*FilePath), Uri.str()); +} + +llvm::Expected URIForFile::fromFile(StringRef AbsoluteFilepath, + StringRef Scheme) { + llvm::Expected Uri = + uriFromAbsolutePath(AbsoluteFilepath, Scheme); + if (!Uri) + return Uri.takeError(); + return fromURI(*Uri); +} + +StringRef URIForFile::scheme() const { return uri().split(':').first; } + +void URIForFile::registerSupportedScheme(StringRef Scheme) { + getSupportedSchemes().insert(Scheme); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, URIForFile &Result, + llvm::json::Path Path) { + if (std::optional Str = Value.getAsString()) { + llvm::Expected ExpectedUri = URIForFile::fromURI(*Str); + if (!ExpectedUri) { + Path.report("unresolvable URI"); + consumeError(ExpectedUri.takeError()); + return false; + } + Result = std::move(*ExpectedUri); + return true; + } + return false; +} + +llvm::json::Value llvm::lsp::toJSON(const URIForFile &Value) { + return Value.uri(); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const URIForFile &Value) { + return Os << Value.uri(); +} + +//===----------------------------------------------------------------------===// +// ClientCapabilities +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ClientCapabilities &Result, llvm::json::Path Path) { + const llvm::json::Object *O = Value.getAsObject(); + if (!O) { + Path.report("expected object"); + return false; + } + if (const llvm::json::Object *TextDocument = O->getObject("textDocument")) { + if (const llvm::json::Object *DocumentSymbol = + TextDocument->getObject("documentSymbol")) { + if (std::optional HierarchicalSupport = + DocumentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) + Result.hierarchicalDocumentSymbol = *HierarchicalSupport; + } + if (auto *CodeAction = TextDocument->getObject("codeAction")) { + if (CodeAction->getObject("codeActionLiteralSupport")) + Result.codeActionStructure = true; + } + } + if (auto *Window = O->getObject("window")) { + if (std::optional WorkDoneProgressSupport = + Window->getBoolean("workDoneProgress")) + Result.workDoneProgress = *WorkDoneProgressSupport; + } + return true; +} + +//===----------------------------------------------------------------------===// +// ClientInfo +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, ClientInfo &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O || !O.map("name", Result.name)) + return false; + + // Don't fail if we can't parse version. + O.map("version", Result.version); + return true; +} + +//===----------------------------------------------------------------------===// +// InitializeParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, TraceLevel &Result, + llvm::json::Path Path) { + if (std::optional Str = Value.getAsString()) { + if (*Str == "off") { + Result = TraceLevel::Off; + return true; + } + if (*Str == "messages") { + Result = TraceLevel::Messages; + return true; + } + if (*Str == "verbose") { + Result = TraceLevel::Verbose; + return true; + } + } + return false; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + InitializeParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O) + return false; + // We deliberately don't fail if we can't parse individual fields. + O.map("capabilities", Result.capabilities); + O.map("trace", Result.trace); + mapOptOrNull(Value, "clientInfo", Result.clientInfo, Path); + + return true; +} + +//===----------------------------------------------------------------------===// +// TextDocumentItem +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentItem &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && + O.map("languageId", Result.languageId) && O.map("text", Result.text) && + O.map("version", Result.version); +} + +//===----------------------------------------------------------------------===// +// TextDocumentIdentifier +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const TextDocumentIdentifier &Value) { + return llvm::json::Object{{"uri", Value.uri}}; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentIdentifier &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri); +} + +//===----------------------------------------------------------------------===// +// VersionedTextDocumentIdentifier +//===----------------------------------------------------------------------===// + +llvm::json::Value +llvm::lsp::toJSON(const VersionedTextDocumentIdentifier &Value) { + return llvm::json::Object{ + {"uri", Value.uri}, + {"version", Value.version}, + }; +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + VersionedTextDocumentIdentifier &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && O.map("version", Result.version); +} + +//===----------------------------------------------------------------------===// +// Position +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Position &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("line", Result.line) && + O.map("character", Result.character); +} + +llvm::json::Value llvm::lsp::toJSON(const Position &Value) { + return llvm::json::Object{ + {"line", Value.line}, + {"character", Value.character}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Position &Value) { + return Os << Value.line << ':' << Value.character; +} + +//===----------------------------------------------------------------------===// +// Range +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Range &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("start", Result.start) && O.map("end", Result.end); +} + +llvm::json::Value llvm::lsp::toJSON(const Range &Value) { + return llvm::json::Object{ + {"start", Value.start}, + {"end", Value.end}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Range &Value) { + return Os << Value.start << '-' << Value.end; +} + +//===----------------------------------------------------------------------===// +// Location +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Location &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("uri", Result.uri) && O.map("range", Result.range); +} + +llvm::json::Value llvm::lsp::toJSON(const Location &Value) { + return llvm::json::Object{ + {"uri", Value.uri}, + {"range", Value.range}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const Location &Value) { + return Os << Value.range << '@' << Value.uri; +} + +//===----------------------------------------------------------------------===// +// TextDocumentPositionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentPositionParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("position", Result.position); +} + +//===----------------------------------------------------------------------===// +// ReferenceParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ReferenceContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.mapOptional("includeDeclaration", Result.includeDeclaration); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + ReferenceParams &Result, llvm::json::Path Path) { + TextDocumentPositionParams &Base = Result; + llvm::json::ObjectMapper O(Value, Path); + return fromJSON(Value, Base, Path) && O && + O.mapOptional("context", Result.context); +} + +//===----------------------------------------------------------------------===// +// DidOpenTextDocumentParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidOpenTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DidCloseTextDocumentParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidCloseTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DidChangeTextDocumentParams +//===----------------------------------------------------------------------===// + +LogicalResult +TextDocumentContentChangeEvent::applyTo(std::string &Contents) const { + // If there is no range, the full document changed. + if (!range) { + Contents = text; + return success(); + } + + // Try to map the replacement range to the content. + llvm::SourceMgr TmpScrMgr; + TmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(Contents), + SMLoc()); + SMRange RangeLoc = range->getAsSMRange(TmpScrMgr); + if (!RangeLoc.isValid()) + return failure(); + + Contents.replace(RangeLoc.Start.getPointer() - Contents.data(), + RangeLoc.End.getPointer() - RangeLoc.Start.getPointer(), + text); + return success(); +} + +LogicalResult TextDocumentContentChangeEvent::applyTo( + ArrayRef Changes, std::string &Contents) { + for (const auto &Change : Changes) + if (failed(Change.applyTo(Contents))) + return failure(); + return success(); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + TextDocumentContentChangeEvent &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("range", Result.range) && + O.map("rangeLength", Result.rangeLength) && O.map("text", Result.text); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DidChangeTextDocumentParams &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("contentChanges", Result.contentChanges); +} + +//===----------------------------------------------------------------------===// +// MarkupContent +//===----------------------------------------------------------------------===// + +static llvm::StringRef toTextKind(MarkupKind Kind) { + switch (Kind) { + case MarkupKind::PlainText: + return "plaintext"; + case MarkupKind::Markdown: + return "markdown"; + } + llvm_unreachable("Invalid MarkupKind"); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, MarkupKind Kind) { + return Os << toTextKind(Kind); +} + +llvm::json::Value llvm::lsp::toJSON(const MarkupContent &Mc) { + if (Mc.value.empty()) + return nullptr; + + return llvm::json::Object{ + {"kind", toTextKind(Mc.kind)}, + {"value", Mc.value}, + }; +} + +//===----------------------------------------------------------------------===// +// Hover +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const Hover &Hover) { + llvm::json::Object Result{{"contents", toJSON(Hover.contents)}}; + if (Hover.range) + Result["range"] = toJSON(*Hover.range); + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// DocumentSymbol +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const DocumentSymbol &Symbol) { + llvm::json::Object Result{{"name", Symbol.name}, + {"kind", static_cast(Symbol.kind)}, + {"range", Symbol.range}, + {"selectionRange", Symbol.selectionRange}}; + + if (!Symbol.detail.empty()) + Result["detail"] = Symbol.detail; + if (!Symbol.children.empty()) + Result["children"] = Symbol.children; + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// DocumentSymbolParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DocumentSymbolParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DiagnosticRelatedInformation +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DiagnosticRelatedInformation &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("location", Result.location) && + O.map("message", Result.message); +} + +llvm::json::Value llvm::lsp::toJSON(const DiagnosticRelatedInformation &Info) { + return llvm::json::Object{ + {"location", Info.location}, + {"message", Info.message}, + }; +} + +//===----------------------------------------------------------------------===// +// Diagnostic +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(DiagnosticTag Tag) { + return static_cast(Tag); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, DiagnosticTag &Result, + llvm::json::Path Path) { + if (std::optional I = Value.getAsInteger()) { + Result = (DiagnosticTag)*I; + return true; + } + + return false; +} + +llvm::json::Value llvm::lsp::toJSON(const Diagnostic &Diag) { + llvm::json::Object Result{ + {"range", Diag.range}, + {"severity", (int)Diag.severity}, + {"message", Diag.message}, + }; + if (Diag.category) + Result["category"] = *Diag.category; + if (!Diag.source.empty()) + Result["source"] = Diag.source; + if (Diag.relatedInformation) + Result["relatedInformation"] = *Diag.relatedInformation; + if (!Diag.tags.empty()) + Result["tags"] = Diag.tags; + return std::move(Result); +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, Diagnostic &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O) + return false; + int Severity = 0; + if (!mapOptOrNull(Value, "severity", Severity, Path)) + return false; + Result.severity = (DiagnosticSeverity)Severity; + + return O.map("range", Result.range) && O.map("message", Result.message) && + mapOptOrNull(Value, "category", Result.category, Path) && + mapOptOrNull(Value, "source", Result.source, Path) && + mapOptOrNull(Value, "relatedInformation", Result.relatedInformation, + Path) && + mapOptOrNull(Value, "tags", Result.tags, Path); +} + +//===----------------------------------------------------------------------===// +// PublishDiagnosticsParams +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const PublishDiagnosticsParams &Params) { + return llvm::json::Object{ + {"uri", Params.uri}, + {"diagnostics", Params.diagnostics}, + {"version", Params.version}, + }; +} + +//===----------------------------------------------------------------------===// +// TextEdit +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, TextEdit &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("range", Result.range) && O.map("newText", Result.newText); +} + +llvm::json::Value llvm::lsp::toJSON(const TextEdit &Value) { + return llvm::json::Object{ + {"range", Value.range}, + {"newText", Value.newText}, + }; +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, const TextEdit &Value) { + Os << Value.range << " => \""; + llvm::printEscapedString(Value.newText, Os); + return Os << '"'; +} + +//===----------------------------------------------------------------------===// +// CompletionItemKind +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionItemKind &Result, llvm::json::Path Path) { + if (std::optional IntValue = Value.getAsInteger()) { + if (*IntValue < static_cast(CompletionItemKind::Text) || + *IntValue > static_cast(CompletionItemKind::TypeParameter)) + return false; + Result = static_cast(*IntValue); + return true; + } + return false; +} + +CompletionItemKind llvm::lsp::adjustKindToCapability( + CompletionItemKind Kind, + CompletionItemKindBitset &SupportedCompletionItemKinds) { + size_t KindVal = static_cast(Kind); + if (KindVal >= kCompletionItemKindMin && + KindVal <= SupportedCompletionItemKinds.size() && + SupportedCompletionItemKinds[KindVal]) + return Kind; + + // Provide some fall backs for common kinds that are close enough. + switch (Kind) { + case CompletionItemKind::Folder: + return CompletionItemKind::File; + case CompletionItemKind::EnumMember: + return CompletionItemKind::Enum; + case CompletionItemKind::Struct: + return CompletionItemKind::Class; + default: + return CompletionItemKind::Text; + } +} + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionItemKindBitset &Result, + llvm::json::Path Path) { + if (const llvm::json::Array *ArrayValue = Value.getAsArray()) { + for (size_t I = 0, E = ArrayValue->size(); I < E; ++I) { + CompletionItemKind KindOut; + if (fromJSON((*ArrayValue)[I], KindOut, Path.index(I))) + Result.set(size_t(KindOut)); + } + return true; + } + return false; +} + +//===----------------------------------------------------------------------===// +// CompletionItem +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const CompletionItem &Value) { + assert(!Value.label.empty() && "completion item label is required"); + llvm::json::Object Result{{"label", Value.label}}; + if (Value.kind != CompletionItemKind::Missing) + Result["kind"] = static_cast(Value.kind); + if (!Value.detail.empty()) + Result["detail"] = Value.detail; + if (Value.documentation) + Result["documentation"] = Value.documentation; + if (!Value.sortText.empty()) + Result["sortText"] = Value.sortText; + if (!Value.filterText.empty()) + Result["filterText"] = Value.filterText; + if (!Value.insertText.empty()) + Result["insertText"] = Value.insertText; + if (Value.insertTextFormat != InsertTextFormat::Missing) + Result["insertTextFormat"] = static_cast(Value.insertTextFormat); + if (Value.textEdit) + Result["textEdit"] = *Value.textEdit; + if (!Value.additionalTextEdits.empty()) { + Result["additionalTextEdits"] = + llvm::json::Array(Value.additionalTextEdits); + } + if (Value.deprecated) + Result["deprecated"] = Value.deprecated; + return std::move(Result); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, + const CompletionItem &Value) { + return Os << Value.label << " - " << toJSON(Value); +} + +bool llvm::lsp::operator<(const CompletionItem &Lhs, + const CompletionItem &Rhs) { + return (Lhs.sortText.empty() ? Lhs.label : Lhs.sortText) < + (Rhs.sortText.empty() ? Rhs.label : Rhs.sortText); +} + +//===----------------------------------------------------------------------===// +// CompletionList +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const CompletionList &Value) { + return llvm::json::Object{ + {"isIncomplete", Value.isIncomplete}, + {"items", llvm::json::Array(Value.items)}, + }; +} + +//===----------------------------------------------------------------------===// +// CompletionContext +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + int TriggerKind; + if (!O || !O.map("triggerKind", TriggerKind) || + !mapOptOrNull(Value, "triggerCharacter", Result.triggerCharacter, Path)) + return false; + Result.triggerKind = static_cast(TriggerKind); + return true; +} + +//===----------------------------------------------------------------------===// +// CompletionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CompletionParams &Result, llvm::json::Path Path) { + if (!fromJSON(Value, static_cast(Result), Path)) + return false; + if (const llvm::json::Value *Context = Value.getAsObject()->get("context")) + return fromJSON(*Context, Result.context, Path.field("context")); + return true; +} + +//===----------------------------------------------------------------------===// +// ParameterInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const ParameterInformation &Value) { + assert((Value.labelOffsets || !Value.labelString.empty()) && + "parameter information label is required"); + llvm::json::Object Result; + if (Value.labelOffsets) + Result["label"] = llvm::json::Array( + {Value.labelOffsets->first, Value.labelOffsets->second}); + else + Result["label"] = Value.labelString; + if (!Value.documentation.empty()) + Result["documentation"] = Value.documentation; + return std::move(Result); +} + +//===----------------------------------------------------------------------===// +// SignatureInformation +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const SignatureInformation &Value) { + assert(!Value.label.empty() && "signature information label is required"); + llvm::json::Object Result{ + {"label", Value.label}, + {"parameters", llvm::json::Array(Value.parameters)}, + }; + if (!Value.documentation.empty()) + Result["documentation"] = Value.documentation; + return std::move(Result); +} + +raw_ostream &llvm::lsp::operator<<(raw_ostream &Os, + const SignatureInformation &Value) { + return Os << Value.label << " - " << toJSON(Value); +} + +//===----------------------------------------------------------------------===// +// SignatureHelp +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const SignatureHelp &Value) { + assert(Value.activeSignature >= 0 && + "Unexpected negative value for number of active signatures."); + assert(Value.activeParameter >= 0 && + "Unexpected negative value for active parameter index"); + return llvm::json::Object{ + {"activeSignature", Value.activeSignature}, + {"activeParameter", Value.activeParameter}, + {"signatures", llvm::json::Array(Value.signatures)}, + }; +} + +//===----------------------------------------------------------------------===// +// DocumentLinkParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + DocumentLinkParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument); +} + +//===----------------------------------------------------------------------===// +// DocumentLink +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const DocumentLink &Value) { + return llvm::json::Object{ + {"range", Value.range}, + {"target", Value.target}, + }; +} + +//===----------------------------------------------------------------------===// +// InlayHintsParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + InlayHintsParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("range", Result.range); +} + +//===----------------------------------------------------------------------===// +// InlayHint +//===----------------------------------------------------------------------===// + +llvm::json::Value llvm::lsp::toJSON(const InlayHint &Value) { + return llvm::json::Object{{"position", Value.position}, + {"kind", (int)Value.kind}, + {"label", Value.label}, + {"paddingLeft", Value.paddingLeft}, + {"paddingRight", Value.paddingRight}}; +} +bool llvm::lsp::operator==(const InlayHint &Lhs, const InlayHint &Rhs) { + return std::tie(Lhs.position, Lhs.kind, Lhs.label) == + std::tie(Rhs.position, Rhs.kind, Rhs.label); +} +bool llvm::lsp::operator<(const InlayHint &Lhs, const InlayHint &Rhs) { + return std::tie(Lhs.position, Lhs.kind, Lhs.label) < + std::tie(Rhs.position, Rhs.kind, Rhs.label); +} + +llvm::raw_ostream &llvm::lsp::operator<<(llvm::raw_ostream &Os, + InlayHintKind Value) { + switch (Value) { + case InlayHintKind::Parameter: + return Os << "parameter"; + case InlayHintKind::Type: + return Os << "type"; + } + llvm_unreachable("Unknown InlayHintKind"); +} + +//===----------------------------------------------------------------------===// +// CodeActionContext +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CodeActionContext &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + if (!O || !O.map("diagnostics", Result.diagnostics)) + return false; + O.map("only", Result.only); + return true; +} + +//===----------------------------------------------------------------------===// +// CodeActionParams +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, + CodeActionParams &Result, llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("textDocument", Result.textDocument) && + O.map("range", Result.range) && O.map("context", Result.context); +} + +//===----------------------------------------------------------------------===// +// WorkspaceEdit +//===----------------------------------------------------------------------===// + +bool llvm::lsp::fromJSON(const llvm::json::Value &Value, WorkspaceEdit &Result, + llvm::json::Path Path) { + llvm::json::ObjectMapper O(Value, Path); + return O && O.map("changes", Result.changes); +} + +llvm::json::Value llvm::lsp::toJSON(const WorkspaceEdit &Value) { + llvm::json::Object FileChanges; + for (auto &Change : Value.changes) + FileChanges[Change.first] = llvm::json::Array(Change.second); + return llvm::json::Object{{"changes", std::move(FileChanges)}}; +} + +//===----------------------------------------------------------------------===// +// CodeAction +//===----------------------------------------------------------------------===// + +const llvm::StringLiteral CodeAction::kQuickFix = "quickfix"; +const llvm::StringLiteral CodeAction::kRefactor = "refactor"; +const llvm::StringLiteral CodeAction::kInfo = "info"; + +llvm::json::Value llvm::lsp::toJSON(const CodeAction &Value) { + llvm::json::Object CodeAction{{"title", Value.title}}; + if (Value.kind) + CodeAction["kind"] = *Value.kind; + if (Value.diagnostics) + CodeAction["diagnostics"] = llvm::json::Array(*Value.diagnostics); + if (Value.isPreferred) + CodeAction["isPreferred"] = true; + if (Value.edit) + CodeAction["edit"] = *Value.edit; + return std::move(CodeAction); +} diff --git a/llvm/lib/Support/LSP/Transport.cpp b/llvm/lib/Support/LSP/Transport.cpp new file mode 100644 index 000000000000..e71f17701636 --- /dev/null +++ b/llvm/lib/Support/LSP/Transport.cpp @@ -0,0 +1,369 @@ +//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/LSP/Transport.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include +#include +#include +#include + +using namespace llvm; +using namespace llvm::lsp; + +//===----------------------------------------------------------------------===// +// Reply +//===----------------------------------------------------------------------===// + +namespace { +/// Function object to reply to an LSP call. +/// Each instance must be called exactly once, otherwise: +/// - if there was no reply, an error reply is sent +/// - if there were multiple replies, only the first is sent +class Reply { +public: + Reply(const llvm::json::Value &Id, StringRef Method, JSONTransport &Transport, + std::mutex &TransportOutputMutex); + Reply(Reply &&Other); + Reply &operator=(Reply &&) = delete; + Reply(const Reply &) = delete; + Reply &operator=(const Reply &) = delete; + + void operator()(llvm::Expected Reply); + +private: + std::string Method; + std::atomic Replied = {false}; + llvm::json::Value Id; + JSONTransport *Transport; + std::mutex &TransportOutputMutex; +}; +} // namespace + +Reply::Reply(const llvm::json::Value &Id, llvm::StringRef Method, + JSONTransport &Transport, std::mutex &TransportOutputMutex) + : Method(Method), Id(Id), Transport(&Transport), + TransportOutputMutex(TransportOutputMutex) {} + +Reply::Reply(Reply &&Other) + : Method(Other.Method), Replied(Other.Replied.load()), + Id(std::move(Other.Id)), Transport(Other.Transport), + TransportOutputMutex(Other.TransportOutputMutex) { + Other.Transport = nullptr; +} + +void Reply::operator()(llvm::Expected Reply) { + if (Replied.exchange(true)) { + Logger::error("Replied twice to message {0}({1})", Method, Id); + assert(false && "must reply to each call only once!"); + return; + } + assert(Transport && "expected valid transport to reply to"); + + std::lock_guard TransportLock(TransportOutputMutex); + if (Reply) { + Logger::info("--> reply:{0}({1})", Method, Id); + Transport->reply(std::move(Id), std::move(Reply)); + } else { + llvm::Error Error = Reply.takeError(); + Logger::info("--> reply:{0}({1}): {2}", Method, Id, Error); + Transport->reply(std::move(Id), std::move(Error)); + } +} + +//===----------------------------------------------------------------------===// +// MessageHandler +//===----------------------------------------------------------------------===// + +bool MessageHandler::onNotify(llvm::StringRef Method, llvm::json::Value Value) { + Logger::info("--> {0}", Method); + + if (Method == "exit") + return false; + if (Method == "$cancel") { + // TODO: Add support for cancelling requests. + } else { + auto It = NotificationHandlers.find(Method); + if (It != NotificationHandlers.end()) + It->second(std::move(Value)); + } + return true; +} + +bool MessageHandler::onCall(llvm::StringRef Method, llvm::json::Value Params, + llvm::json::Value Id) { + Logger::info("--> {0}({1})", Method, Id); + + Reply Reply(Id, Method, Transport, TransportOutputMutex); + + auto It = MethodHandlers.find(Method); + if (It != MethodHandlers.end()) { + It->second(std::move(Params), std::move(Reply)); + } else { + Reply(llvm::make_error("method not found: " + Method.str(), + ErrorCode::MethodNotFound)); + } + return true; +} + +bool MessageHandler::onReply(llvm::json::Value Id, + llvm::Expected Result) { + // Find the response handler in the mapping. If it exists, move it out of the + // mapping and erase it. + ResponseHandlerTy ResponseHandler; + { + std::lock_guard responseHandlersLock(ResponseHandlerTy); + auto It = ResponseHandlers.find(debugString(Id)); + if (It != ResponseHandlers.end()) { + ResponseHandler = std::move(It->second); + ResponseHandlers.erase(It); + } + } + + // If we found a response handler, invoke it. Otherwise, log an error. + if (ResponseHandler.second) { + Logger::info("--> reply:{0}({1})", ResponseHandler.first, Id); + ResponseHandler.second(std::move(Id), std::move(Result)); + } else { + Logger::error( + "received a reply with ID {0}, but there was no such outgoing request", + Id); + if (!Result) + llvm::consumeError(Result.takeError()); + } + return true; +} + +//===----------------------------------------------------------------------===// +// JSONTransport +//===----------------------------------------------------------------------===// + +/// Encode the given error as a JSON object. +static llvm::json::Object encodeError(llvm::Error Error) { + std::string Message; + ErrorCode Code = ErrorCode::UnknownErrorCode; + auto HandlerFn = [&](const LSPError &LspError) -> llvm::Error { + Message = LspError.message; + Code = LspError.code; + return llvm::Error::success(); + }; + if (llvm::Error Unhandled = llvm::handleErrors(std::move(Error), HandlerFn)) + Message = llvm::toString(std::move(Unhandled)); + + return llvm::json::Object{ + {"message", std::move(Message)}, + {"code", int64_t(Code)}, + }; +} + +/// Decode the given JSON object into an error. +llvm::Error decodeError(const llvm::json::Object &O) { + StringRef Msg = O.getString("message").value_or("Unspecified error"); + if (std::optional Code = O.getInteger("code")) + return llvm::make_error(Msg.str(), ErrorCode(*Code)); + return llvm::make_error(llvm::inconvertibleErrorCode(), + Msg.str()); +} + +void JSONTransport::notify(StringRef Method, llvm::json::Value Params) { + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"method", Method}, + {"params", std::move(Params)}, + }); +} +void JSONTransport::call(StringRef Method, llvm::json::Value Params, + llvm::json::Value Id) { + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"method", Method}, + {"params", std::move(Params)}, + }); +} +void JSONTransport::reply(llvm::json::Value Id, + llvm::Expected Result) { + if (Result) { + return sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"result", std::move(*Result)}, + }); + } + + sendMessage(llvm::json::Object{ + {"jsonrpc", "2.0"}, + {"id", std::move(Id)}, + {"error", encodeError(Result.takeError())}, + }); +} + +llvm::Error JSONTransport::run(MessageHandler &Handler) { + std::string Json; + while (!In->isEndOfInput()) { + if (In->hasError()) { + return llvm::errorCodeToError( + std::error_code(errno, std::system_category())); + } + + if (succeeded(In->readMessage(Json))) { + if (llvm::Expected Doc = llvm::json::parse(Json)) { + if (!handleMessage(std::move(*Doc), Handler)) + return llvm::Error::success(); + } else { + Logger::error("JSON parse error: {0}", llvm::toString(Doc.takeError())); + } + } + } + return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); +} + +void JSONTransport::sendMessage(llvm::json::Value Msg) { + OutputBuffer.clear(); + llvm::raw_svector_ostream os(OutputBuffer); + os << llvm::formatv(PrettyOutput ? "{0:2}\n" : "{0}", Msg); + Out << "Content-Length: " << OutputBuffer.size() << "\r\n\r\n" + << OutputBuffer; + Out.flush(); + Logger::debug(">>> {0}\n", OutputBuffer); +} + +bool JSONTransport::handleMessage(llvm::json::Value Msg, + MessageHandler &Handler) { + // Message must be an object with "jsonrpc":"2.0". + llvm::json::Object *Object = Msg.getAsObject(); + if (!Object || + Object->getString("jsonrpc") != std::optional("2.0")) + return false; + + // `id` may be any JSON value. If absent, this is a notification. + std::optional Id; + if (llvm::json::Value *I = Object->get("id")) + Id = std::move(*I); + std::optional Method = Object->getString("method"); + + // This is a response. + if (!Method) { + if (!Id) + return false; + if (auto *Err = Object->getObject("error")) + return Handler.onReply(std::move(*Id), decodeError(*Err)); + // result should be given, use null if not. + llvm::json::Value Result = nullptr; + if (llvm::json::Value *R = Object->get("result")) + Result = std::move(*R); + return Handler.onReply(std::move(*Id), std::move(Result)); + } + + // Params should be given, use null if not. + llvm::json::Value Params = nullptr; + if (llvm::json::Value *P = Object->get("params")) + Params = std::move(*P); + + if (Id) + return Handler.onCall(*Method, std::move(Params), std::move(*Id)); + return Handler.onNotify(*Method, std::move(Params)); +} + +/// Tries to read a line up to and including \n. +/// If failing, feof(), ferror(), or shutdownRequested() will be set. +LogicalResult readLine(std::FILE *In, SmallVectorImpl &Out) { + // Big enough to hold any reasonable header line. May not fit content lines + // in delimited mode, but performance doesn't matter for that mode. + static constexpr int BufSize = 128; + size_t Size = 0; + Out.clear(); + for (;;) { + Out.resize_for_overwrite(Size + BufSize); + if (!std::fgets(&Out[Size], BufSize, In)) + return failure(); + + clearerr(In); + + // If the line contained null bytes, anything after it (including \n) will + // be ignored. Fortunately this is not a legal header or JSON. + size_t Read = std::strlen(&Out[Size]); + if (Read > 0 && Out[Size + Read - 1] == '\n') { + Out.resize(Size + Read); + return success(); + } + Size += Read; + } +} + +// Returns std::nullopt when: +// - ferror(), feof(), or shutdownRequested() are set. +// - Content-Length is missing or empty (protocol error) +LogicalResult +JSONTransportInputOverFile::readStandardMessage(std::string &Json) { + // A Language Server Protocol message starts with a set of HTTP headers, + // delimited by \r\n, and terminated by an empty line (\r\n). + unsigned long long ContentLength = 0; + llvm::SmallString<128> Line; + while (true) { + if (feof(In) || hasError() || failed(readLine(In, Line))) + return failure(); + + // Content-Length is a mandatory header, and the only one we handle. + StringRef LineRef = Line; + if (LineRef.consume_front("Content-Length: ")) { + llvm::getAsUnsignedInteger(LineRef.trim(), 0, ContentLength); + } else if (!LineRef.trim().empty()) { + // It's another header, ignore it. + continue; + } else { + // An empty line indicates the end of headers. Go ahead and read the JSON. + break; + } + } + + // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" + if (ContentLength == 0 || ContentLength > 1 << 30) + return failure(); + + Json.resize(ContentLength); + for (size_t Pos = 0, Read; Pos < ContentLength; Pos += Read) { + Read = std::fread(&Json[Pos], 1, ContentLength - Pos, In); + if (Read == 0) + return failure(); + + // If we're done, the error was transient. If we're not done, either it was + // transient or we'll see it again on retry. + clearerr(In); + Pos += Read; + } + return success(); +} + +/// For lit tests we support a simplified syntax: +/// - messages are delimited by '// -----' on a line by itself +/// - lines starting with // are ignored. +/// This is a testing path, so favor simplicity over performance here. +/// When returning failure: feof(), ferror(), or shutdownRequested() will be +/// set. +LogicalResult +JSONTransportInputOverFile::readDelimitedMessage(std::string &Json) { + Json.clear(); + llvm::SmallString<128> Line; + while (succeeded(readLine(In, Line))) { + StringRef LineRef = Line.str().trim(); + if (LineRef.starts_with("//")) { + // Found a delimiter for the message. + if (LineRef == "// -----") + break; + continue; + } + + Json += Line; + } + + return failure(ferror(In)); +} diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt index 0910a0b296dd..d1dfb1dc4a72 100644 --- a/llvm/unittests/Support/CMakeLists.txt +++ b/llvm/unittests/Support/CMakeLists.txt @@ -125,6 +125,8 @@ add_llvm_unittest(SupportTests intrinsics_gen ) +add_subdirectory(LSP) + target_link_libraries(SupportTests PRIVATE LLVMTestingSupport) # Disable all warning for AlignOfTest.cpp, diff --git a/llvm/unittests/Support/LSP/CMakeLists.txt b/llvm/unittests/Support/LSP/CMakeLists.txt new file mode 100644 index 000000000000..790a8b725469 --- /dev/null +++ b/llvm/unittests/Support/LSP/CMakeLists.txt @@ -0,0 +1,8 @@ +set(LLVM_LINK_COMPONENTS + SupportLSP +) + +add_llvm_unittest(LLVMSupportLSPTests + Protocol.cpp + Transport.cpp +) diff --git a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp b/llvm/unittests/Support/LSP/Protocol.cpp similarity index 93% rename from mlir/unittests/Tools/lsp-server-support/Protocol.cpp rename to llvm/unittests/Support/LSP/Protocol.cpp index 04d7b2fbb440..43c548c24b38 100644 --- a/mlir/unittests/Tools/lsp-server-support/Protocol.cpp +++ b/llvm/unittests/Support/LSP/Protocol.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" #include "gtest/gtest.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; using namespace testing; namespace { diff --git a/mlir/unittests/Tools/lsp-server-support/Transport.cpp b/llvm/unittests/Support/LSP/Transport.cpp similarity index 96% rename from mlir/unittests/Tools/lsp-server-support/Transport.cpp rename to llvm/unittests/Support/LSP/Transport.cpp index 92581bd2bad0..514e93e98352 100644 --- a/mlir/unittests/Tools/lsp-server-support/Transport.cpp +++ b/llvm/unittests/Support/LSP/Transport.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -using namespace mlir; -using namespace mlir::lsp; +using namespace llvm; +using namespace llvm::lsp; using namespace testing; namespace { @@ -88,7 +88,7 @@ protected: TEST_F(TransportInputTest, RequestWithInvalidParams) { struct Handler { void onMethod(const TextDocumentItem ¶ms, - mlir::lsp::Callback callback) {} + llvm::lsp::Callback callback) {} } handler; getMessageHandler().method("invalid-params-request", &handler, &Handler::onMethod); diff --git a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h index 9ed8326a602e..920ce831e42b 100644 --- a/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h +++ b/mlir/include/mlir/Tools/lsp-server-support/SourceMgrUtils.h @@ -14,7 +14,8 @@ #ifndef MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H #define MLIR_TOOLS_LSPSERVERSUPPORT_SOURCEMGRUTILS_H -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/SourceMgr.h" #include @@ -45,17 +46,18 @@ bool contains(SMRange range, SMLoc loc); /// This class represents a single include within a root file. struct SourceMgrInclude { - SourceMgrInclude(const lsp::URIForFile &uri, const lsp::Range &range) + SourceMgrInclude(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range) : uri(uri), range(range) {} /// Build a hover for the current include file. - Hover buildHover() const; + llvm::lsp::Hover buildHover() const; /// The URI of the file that is included. - lsp::URIForFile uri; + llvm::lsp::URIForFile uri; /// The range of the include directive. - lsp::Range range; + llvm::lsp::Range range; }; /// Given a source manager, gather all of the processed include files. These are diff --git a/mlir/include/mlir/Tools/lsp-server-support/Transport.h b/mlir/include/mlir/Tools/lsp-server-support/Transport.h deleted file mode 100644 index 0010a475fedd..000000000000 --- a/mlir/include/mlir/Tools/lsp-server-support/Transport.h +++ /dev/null @@ -1,283 +0,0 @@ -//===--- Transport.h - Sending and Receiving LSP messages -------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// The language server protocol is usually implemented by writing messages as -// JSON-RPC over the stdin/stdout of a subprocess. This file contains a JSON -// transport interface that handles this communication. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H -#define MLIR_TOOLS_LSPSERVERSUPPORT_TRANSPORT_H - -#include "mlir/Support/DebugStringHelper.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/FunctionExtras.h" -#include "llvm/ADT/StringMap.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/FormatAdapters.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/raw_ostream.h" -#include - -namespace mlir { -namespace lsp { -class MessageHandler; - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// The encoding style of the JSON-RPC messages (both input and output). -enum JSONStreamStyle { - /// Encoding per the LSP specification, with mandatory Content-Length header. - Standard, - /// Messages are delimited by a '// -----' line. Comment lines start with //. - Delimited -}; - -/// An abstract class used by the JSONTransport to read JSON message. -class JSONTransportInput { -public: - explicit JSONTransportInput(JSONStreamStyle style = JSONStreamStyle::Standard) - : style(style) {} - virtual ~JSONTransportInput() = default; - - virtual bool hasError() const = 0; - virtual bool isEndOfInput() const = 0; - - /// Read in a message from the input stream. - LogicalResult readMessage(std::string &json) { - return style == JSONStreamStyle::Delimited ? readDelimitedMessage(json) - : readStandardMessage(json); - } - virtual LogicalResult readDelimitedMessage(std::string &json) = 0; - virtual LogicalResult readStandardMessage(std::string &json) = 0; - -private: - /// The JSON stream style to use. - JSONStreamStyle style; -}; - -/// Concrete implementation of the JSONTransportInput that reads from a file. -class JSONTransportInputOverFile : public JSONTransportInput { -public: - explicit JSONTransportInputOverFile( - std::FILE *in, JSONStreamStyle style = JSONStreamStyle::Standard) - : JSONTransportInput(style), in(in) {} - - bool hasError() const final { return ferror(in); } - bool isEndOfInput() const final { return feof(in); } - - LogicalResult readDelimitedMessage(std::string &json) final; - LogicalResult readStandardMessage(std::string &json) final; - -private: - std::FILE *in; -}; - -/// A transport class that performs the JSON-RPC communication with the LSP -/// client. -class JSONTransport { -public: - JSONTransport(std::unique_ptr in, raw_ostream &out, - bool prettyOutput = false) - : in(std::move(in)), out(out), prettyOutput(prettyOutput) {} - - JSONTransport(std::FILE *in, raw_ostream &out, - JSONStreamStyle style = JSONStreamStyle::Standard, - bool prettyOutput = false) - : in(std::make_unique(in, style)), out(out), - prettyOutput(prettyOutput) {} - - /// The following methods are used to send a message to the LSP client. - void notify(StringRef method, llvm::json::Value params); - void call(StringRef method, llvm::json::Value params, llvm::json::Value id); - void reply(llvm::json::Value id, llvm::Expected result); - - /// Start executing the JSON-RPC transport. - llvm::Error run(MessageHandler &handler); - -private: - /// Dispatches the given incoming json message to the message handler. - bool handleMessage(llvm::json::Value msg, MessageHandler &handler); - /// Writes the given message to the output stream. - void sendMessage(llvm::json::Value msg); - -private: - /// The input to read a message from. - std::unique_ptr in; - SmallVector outputBuffer; - /// The output file stream. - raw_ostream &out; - /// If the output JSON should be formatted for easier readability. - bool prettyOutput; -}; - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -/// A Callback is a void function that accepts Expected. This is -/// accepted by functions that logically return T. -template -using Callback = llvm::unique_function)>; - -/// An OutgoingNotification is a function used for outgoing notifications -/// send to the client. -template -using OutgoingNotification = llvm::unique_function; - -/// An OutgoingRequest is a function used for outgoing requests to send to -/// the client. -template -using OutgoingRequest = - llvm::unique_function; - -/// An `OutgoingRequestCallback` is invoked when an outgoing request to the -/// client receives a response in turn. It is passed the original request's ID, -/// as well as the response result. -template -using OutgoingRequestCallback = - std::function)>; - -/// A handler used to process the incoming transport messages. -class MessageHandler { -public: - MessageHandler(JSONTransport &transport) : transport(transport) {} - - bool onNotify(StringRef method, llvm::json::Value value); - bool onCall(StringRef method, llvm::json::Value params, llvm::json::Value id); - bool onReply(llvm::json::Value id, llvm::Expected result); - - template - static llvm::Expected parse(const llvm::json::Value &raw, - StringRef payloadName, StringRef payloadKind) { - T result; - llvm::json::Path::Root root; - if (fromJSON(raw, result, root)) - return std::move(result); - - // Dump the relevant parts of the broken message. - std::string context; - llvm::raw_string_ostream os(context); - root.printErrorContext(raw, os); - - // Report the error (e.g. to the client). - return llvm::make_error( - llvm::formatv("failed to decode {0} {1}: {2}", payloadName, payloadKind, - fmt_consume(root.getError())), - ErrorCode::InvalidParams); - } - - template - void method(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &, Callback)) { - methodHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams, - Callback reply) { - llvm::Expected param = parse(rawParams, method, "request"); - if (!param) - return reply(param.takeError()); - (thisPtr->*handler)(*param, std::move(reply)); - }; - } - - template - void notification(llvm::StringLiteral method, ThisT *thisPtr, - void (ThisT::*handler)(const Param &)) { - notificationHandlers[method] = [method, handler, - thisPtr](llvm::json::Value rawParams) { - llvm::Expected param = - parse(rawParams, method, "notification"); - if (!param) { - return llvm::consumeError( - llvm::handleErrors(param.takeError(), [](const LSPError &lspError) { - Logger::error("JSON parsing error: {0}", - lspError.message.c_str()); - })); - } - (thisPtr->*handler)(*param); - }; - } - - /// Create an OutgoingNotification object used for the given method. - template - OutgoingNotification outgoingNotification(llvm::StringLiteral method) { - return [&, method](const T ¶ms) { - std::lock_guard transportLock(transportOutputMutex); - Logger::info("--> {0}", method); - transport.notify(method, llvm::json::Value(params)); - }; - } - - /// Create an OutgoingRequest function that, when called, sends a request with - /// the given method via the transport. Should the outgoing request be - /// met with a response, the result JSON is parsed and the response callback - /// is invoked. - template - OutgoingRequest - outgoingRequest(llvm::StringLiteral method, - OutgoingRequestCallback callback) { - return [&, method, callback](const Param ¶m, llvm::json::Value id) { - auto callbackWrapper = [method, callback = std::move(callback)]( - llvm::json::Value id, - llvm::Expected value) { - if (!value) - return callback(std::move(id), value.takeError()); - - std::string responseName = llvm::formatv("reply:{0}({1})", method, id); - llvm::Expected result = - parse(*value, responseName, "response"); - if (!result) - return callback(std::move(id), result.takeError()); - - return callback(std::move(id), *result); - }; - - { - std::lock_guard lock(responseHandlersMutex); - responseHandlers.insert( - {debugString(id), std::make_pair(method.str(), callbackWrapper)}); - } - - std::lock_guard transportLock(transportOutputMutex); - Logger::info("--> {0}({1})", method, id); - transport.call(method, llvm::json::Value(param), id); - }; - } - -private: - template - using HandlerMap = llvm::StringMap>; - - HandlerMap notificationHandlers; - HandlerMap)> - methodHandlers; - - /// A pair of (1) the original request's method name, and (2) the callback - /// function to be invoked for responses. - using ResponseHandlerTy = - std::pair>; - /// A mapping from request/response ID to response handler. - llvm::StringMap responseHandlers; - /// Mutex to guard insertion into the response handler map. - std::mutex responseHandlersMutex; - - JSONTransport &transport; - - /// Mutex to guard sending output messages to the transport. - std::mutex transportOutputMutex; -}; - -} // namespace lsp -} // namespace mlir - -#endif diff --git a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h index 4811ecb5e92b..0d9ba2a0d160 100644 --- a/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h +++ b/mlir/include/mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h @@ -16,14 +16,16 @@ namespace llvm { template class function_ref; +namespace lsp { +class URIForFile; +} // namespace lsp } // namespace llvm namespace mlir { class DialectRegistry; namespace lsp { -class URIForFile; using DialectRegistryFn = - llvm::function_ref; + llvm::function_ref; } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt index 48a96016b792..2fe29f1b9ec4 100644 --- a/mlir/lib/Tools/lsp-server-support/CMakeLists.txt +++ b/mlir/lib/Tools/lsp-server-support/CMakeLists.txt @@ -1,13 +1,13 @@ add_mlir_library(MLIRLspServerSupportLib CompilationDatabase.cpp - Logging.cpp - Protocol.cpp SourceMgrUtils.cpp - Transport.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/lsp-server-support + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRSupport - ) +) diff --git a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp index 9ae0674383a1..67b8ef6a256b 100644 --- a/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp +++ b/mlir/lib/Tools/lsp-server-support/CompilationDatabase.cpp @@ -8,14 +8,15 @@ #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" #include "mlir/Support/FileUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/YAMLTraits.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Logger; //===----------------------------------------------------------------------===// // YamlFileInfo diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp deleted file mode 100644 index 98287048355c..000000000000 --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ /dev/null @@ -1,1043 +0,0 @@ -//===--- Protocol.cpp - Language Server Protocol Implementation -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file contains the serialization code for the LSP structs. -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/JSON.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace mlir::lsp; - -// Helper that doesn't treat `null` and absent fields as failures. -template -static bool mapOptOrNull(const llvm::json::Value ¶ms, - llvm::StringLiteral prop, T &out, - llvm::json::Path path) { - const llvm::json::Object *o = params.getAsObject(); - assert(o); - - // Field is missing or null. - auto *v = o->get(prop); - if (!v || v->getAsNull()) - return true; - return fromJSON(*v, out, path.field(prop)); -} - -//===----------------------------------------------------------------------===// -// LSPError -//===----------------------------------------------------------------------===// - -char LSPError::ID; - -//===----------------------------------------------------------------------===// -// URIForFile -//===----------------------------------------------------------------------===// - -static bool isWindowsPath(StringRef path) { - return path.size() > 1 && llvm::isAlpha(path[0]) && path[1] == ':'; -} - -static bool isNetworkPath(StringRef path) { - return path.size() > 2 && path[0] == path[1] && - llvm::sys::path::is_separator(path[0]); -} - -static bool shouldEscapeInURI(unsigned char c) { - // Unreserved characters. - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || - (c >= '0' && c <= '9')) - return false; - - switch (c) { - case '-': - case '_': - case '.': - case '~': - // '/' is only reserved when parsing. - case '/': - // ':' is only reserved for relative URI paths, which we doesn't produce. - case ':': - return false; - } - return true; -} - -/// Encodes a string according to percent-encoding. -/// - Unreserved characters are not escaped. -/// - Reserved characters always escaped with exceptions like '/'. -/// - All other characters are escaped. -static void percentEncode(StringRef content, std::string &out) { - for (unsigned char c : content) { - if (shouldEscapeInURI(c)) { - out.push_back('%'); - out.push_back(llvm::hexdigit(c / 16)); - out.push_back(llvm::hexdigit(c % 16)); - } else { - out.push_back(c); - } - } -} - -/// Decodes a string according to percent-encoding. -static std::string percentDecode(StringRef content) { - std::string result; - for (auto i = content.begin(), e = content.end(); i != e; ++i) { - if (*i != '%') { - result += *i; - continue; - } - if (*i == '%' && i + 2 < content.end() && llvm::isHexDigit(*(i + 1)) && - llvm::isHexDigit(*(i + 2))) { - result.push_back(llvm::hexFromNibbles(*(i + 1), *(i + 2))); - i += 2; - } else { - result.push_back(*i); - } - } - return result; -} - -/// Return the set containing the supported URI schemes. -static StringSet<> &getSupportedSchemes() { - static StringSet<> schemes({"file", "test"}); - return schemes; -} - -/// Returns true if the given scheme is structurally valid, i.e. it does not -/// contain any invalid scheme characters. This does not check that the scheme -/// is actually supported. -static bool isStructurallyValidScheme(StringRef scheme) { - if (scheme.empty()) - return false; - if (!llvm::isAlpha(scheme[0])) - return false; - return llvm::all_of(llvm::drop_begin(scheme), [](char c) { - return llvm::isAlnum(c) || c == '+' || c == '.' || c == '-'; - }); -} - -static llvm::Expected uriFromAbsolutePath(StringRef absolutePath, - StringRef scheme) { - std::string body; - StringRef authority; - StringRef root = llvm::sys::path::root_name(absolutePath); - if (isNetworkPath(root)) { - // Windows UNC paths e.g. \\server\share => file://server/share - authority = root.drop_front(2); - absolutePath.consume_front(root); - } else if (isWindowsPath(root)) { - // Windows paths e.g. X:\path => file:///X:/path - body = "/"; - } - body += llvm::sys::path::convert_to_slash(absolutePath); - - std::string uri = scheme.str() + ":"; - if (authority.empty() && body.empty()) - return uri; - - // If authority if empty, we only print body if it starts with "/"; otherwise, - // the URI is invalid. - if (!authority.empty() || StringRef(body).starts_with("/")) { - uri.append("//"); - percentEncode(authority, uri); - } - percentEncode(body, uri); - return uri; -} - -static llvm::Expected getAbsolutePath(StringRef authority, - StringRef body) { - if (!body.starts_with("/")) - return llvm::createStringError( - llvm::inconvertibleErrorCode(), - "File scheme: expect body to be an absolute path starting " - "with '/': " + - body); - SmallString<128> path; - if (!authority.empty()) { - // Windows UNC paths e.g. file://server/share => \\server\share - ("//" + authority).toVector(path); - } else if (isWindowsPath(body.substr(1))) { - // Windows paths e.g. file:///X:/path => X:\path - body.consume_front("/"); - } - path.append(body); - llvm::sys::path::native(path); - return std::string(path); -} - -static llvm::Expected parseFilePathFromURI(StringRef origUri) { - StringRef uri = origUri; - - // Decode the scheme of the URI. - size_t pos = uri.find(':'); - if (pos == StringRef::npos) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Scheme must be provided in URI: " + - origUri); - StringRef schemeStr = uri.substr(0, pos); - std::string uriScheme = percentDecode(schemeStr); - if (!isStructurallyValidScheme(uriScheme)) - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "Invalid scheme: " + schemeStr + - " (decoded: " + uriScheme + ")"); - uri = uri.substr(pos + 1); - - // Decode the authority of the URI. - std::string uriAuthority; - if (uri.consume_front("//")) { - pos = uri.find('/'); - uriAuthority = percentDecode(uri.substr(0, pos)); - uri = uri.substr(pos); - } - - // Decode the body of the URI. - std::string uriBody = percentDecode(uri); - - // Compute the absolute path for this uri. - if (!getSupportedSchemes().contains(uriScheme)) { - return llvm::createStringError(llvm::inconvertibleErrorCode(), - "unsupported URI scheme `" + uriScheme + - "' for workspace files"); - } - return getAbsolutePath(uriAuthority, uriBody); -} - -llvm::Expected URIForFile::fromURI(StringRef uri) { - llvm::Expected filePath = parseFilePathFromURI(uri); - if (!filePath) - return filePath.takeError(); - return URIForFile(std::move(*filePath), uri.str()); -} - -llvm::Expected URIForFile::fromFile(StringRef absoluteFilepath, - StringRef scheme) { - llvm::Expected uri = - uriFromAbsolutePath(absoluteFilepath, scheme); - if (!uri) - return uri.takeError(); - return fromURI(*uri); -} - -StringRef URIForFile::scheme() const { return uri().split(':').first; } - -void URIForFile::registerSupportedScheme(StringRef scheme) { - getSupportedSchemes().insert(scheme); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, URIForFile &result, - llvm::json::Path path) { - if (std::optional str = value.getAsString()) { - llvm::Expected expectedURI = URIForFile::fromURI(*str); - if (!expectedURI) { - path.report("unresolvable URI"); - consumeError(expectedURI.takeError()); - return false; - } - result = std::move(*expectedURI); - return true; - } - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const URIForFile &value) { - return value.uri(); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const URIForFile &value) { - return os << value.uri(); -} - -//===----------------------------------------------------------------------===// -// ClientCapabilities -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ClientCapabilities &result, llvm::json::Path path) { - const llvm::json::Object *o = value.getAsObject(); - if (!o) { - path.report("expected object"); - return false; - } - if (const llvm::json::Object *textDocument = o->getObject("textDocument")) { - if (const llvm::json::Object *documentSymbol = - textDocument->getObject("documentSymbol")) { - if (std::optional hierarchicalSupport = - documentSymbol->getBoolean("hierarchicalDocumentSymbolSupport")) - result.hierarchicalDocumentSymbol = *hierarchicalSupport; - } - if (auto *codeAction = textDocument->getObject("codeAction")) { - if (codeAction->getObject("codeActionLiteralSupport")) - result.codeActionStructure = true; - } - } - if (auto *window = o->getObject("window")) { - if (std::optional workDoneProgressSupport = - window->getBoolean("workDoneProgress")) - result.workDoneProgress = *workDoneProgressSupport; - } - return true; -} - -//===----------------------------------------------------------------------===// -// ClientInfo -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, ClientInfo &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("name", result.name)) - return false; - - // Don't fail if we can't parse version. - o.map("version", result.version); - return true; -} - -//===----------------------------------------------------------------------===// -// InitializeParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TraceLevel &result, - llvm::json::Path path) { - if (std::optional str = value.getAsString()) { - if (*str == "off") { - result = TraceLevel::Off; - return true; - } - if (*str == "messages") { - result = TraceLevel::Messages; - return true; - } - if (*str == "verbose") { - result = TraceLevel::Verbose; - return true; - } - } - return false; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InitializeParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - // We deliberately don't fail if we can't parse individual fields. - o.map("capabilities", result.capabilities); - o.map("trace", result.trace); - mapOptOrNull(value, "clientInfo", result.clientInfo, path); - - return true; -} - -//===----------------------------------------------------------------------===// -// TextDocumentItem -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentItem &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && - o.map("languageId", result.languageId) && o.map("text", result.text) && - o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// TextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const TextDocumentIdentifier &value) { - return llvm::json::Object{{"uri", value.uri}}; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri); -} - -//===----------------------------------------------------------------------===// -// VersionedTextDocumentIdentifier -//===----------------------------------------------------------------------===// - -llvm::json::Value -mlir::lsp::toJSON(const VersionedTextDocumentIdentifier &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"version", value.version}, - }; -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - VersionedTextDocumentIdentifier &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("version", result.version); -} - -//===----------------------------------------------------------------------===// -// Position -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Position &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("line", result.line) && - o.map("character", result.character); -} - -llvm::json::Value mlir::lsp::toJSON(const Position &value) { - return llvm::json::Object{ - {"line", value.line}, - {"character", value.character}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Position &value) { - return os << value.line << ':' << value.character; -} - -//===----------------------------------------------------------------------===// -// Range -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Range &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("start", result.start) && o.map("end", result.end); -} - -llvm::json::Value mlir::lsp::toJSON(const Range &value) { - return llvm::json::Object{ - {"start", value.start}, - {"end", value.end}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Range &value) { - return os << value.start << '-' << value.end; -} - -//===----------------------------------------------------------------------===// -// Location -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Location &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("uri", result.uri) && o.map("range", result.range); -} - -llvm::json::Value mlir::lsp::toJSON(const Location &value) { - return llvm::json::Object{ - {"uri", value.uri}, - {"range", value.range}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const Location &value) { - return os << value.range << '@' << value.uri; -} - -//===----------------------------------------------------------------------===// -// TextDocumentPositionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentPositionParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("position", result.position); -} - -//===----------------------------------------------------------------------===// -// ReferenceParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.mapOptional("includeDeclaration", result.includeDeclaration); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - ReferenceParams &result, llvm::json::Path path) { - TextDocumentPositionParams &base = result; - llvm::json::ObjectMapper o(value, path); - return fromJSON(value, base, path) && o && - o.mapOptional("context", result.context); -} - -//===----------------------------------------------------------------------===// -// DidOpenTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidOpenTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidCloseTextDocumentParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidCloseTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DidChangeTextDocumentParams -//===----------------------------------------------------------------------===// - -LogicalResult -TextDocumentContentChangeEvent::applyTo(std::string &contents) const { - // If there is no range, the full document changed. - if (!range) { - contents = text; - return success(); - } - - // Try to map the replacement range to the content. - llvm::SourceMgr tmpScrMgr; - tmpScrMgr.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(contents), - SMLoc()); - SMRange rangeLoc = range->getAsSMRange(tmpScrMgr); - if (!rangeLoc.isValid()) - return failure(); - - contents.replace(rangeLoc.Start.getPointer() - contents.data(), - rangeLoc.End.getPointer() - rangeLoc.Start.getPointer(), - text); - return success(); -} - -LogicalResult TextDocumentContentChangeEvent::applyTo( - ArrayRef changes, std::string &contents) { - for (const auto &change : changes) - if (failed(change.applyTo(contents))) - return failure(); - return success(); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - TextDocumentContentChangeEvent &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && - o.map("rangeLength", result.rangeLength) && o.map("text", result.text); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DidChangeTextDocumentParams &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("contentChanges", result.contentChanges); -} - -//===----------------------------------------------------------------------===// -// MarkupContent -//===----------------------------------------------------------------------===// - -static llvm::StringRef toTextKind(MarkupKind kind) { - switch (kind) { - case MarkupKind::PlainText: - return "plaintext"; - case MarkupKind::Markdown: - return "markdown"; - } - llvm_unreachable("Invalid MarkupKind"); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, MarkupKind kind) { - return os << toTextKind(kind); -} - -llvm::json::Value mlir::lsp::toJSON(const MarkupContent &mc) { - if (mc.value.empty()) - return nullptr; - - return llvm::json::Object{ - {"kind", toTextKind(mc.kind)}, - {"value", mc.value}, - }; -} - -//===----------------------------------------------------------------------===// -// Hover -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const Hover &hover) { - llvm::json::Object result{{"contents", toJSON(hover.contents)}}; - if (hover.range) - result["range"] = toJSON(*hover.range); - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbol -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentSymbol &symbol) { - llvm::json::Object result{{"name", symbol.name}, - {"kind", static_cast(symbol.kind)}, - {"range", symbol.range}, - {"selectionRange", symbol.selectionRange}}; - - if (!symbol.detail.empty()) - result["detail"] = symbol.detail; - if (!symbol.children.empty()) - result["children"] = symbol.children; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// DocumentSymbolParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentSymbolParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DiagnosticRelatedInformation -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DiagnosticRelatedInformation &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("location", result.location) && - o.map("message", result.message); -} - -llvm::json::Value mlir::lsp::toJSON(const DiagnosticRelatedInformation &info) { - return llvm::json::Object{ - {"location", info.location}, - {"message", info.message}, - }; -} - -//===----------------------------------------------------------------------===// -// Diagnostic -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(DiagnosticTag tag) { - return static_cast(tag); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, DiagnosticTag &result, - llvm::json::Path path) { - if (std::optional i = value.getAsInteger()) { - result = (DiagnosticTag)*i; - return true; - } - - return false; -} - -llvm::json::Value mlir::lsp::toJSON(const Diagnostic &diag) { - llvm::json::Object result{ - {"range", diag.range}, - {"severity", (int)diag.severity}, - {"message", diag.message}, - }; - if (diag.category) - result["category"] = *diag.category; - if (!diag.source.empty()) - result["source"] = diag.source; - if (diag.relatedInformation) - result["relatedInformation"] = *diag.relatedInformation; - if (!diag.tags.empty()) - result["tags"] = diag.tags; - return std::move(result); -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, Diagnostic &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o) - return false; - int severity = 0; - if (!mapOptOrNull(value, "severity", severity, path)) - return false; - result.severity = (DiagnosticSeverity)severity; - - return o.map("range", result.range) && o.map("message", result.message) && - mapOptOrNull(value, "category", result.category, path) && - mapOptOrNull(value, "source", result.source, path) && - mapOptOrNull(value, "relatedInformation", result.relatedInformation, - path) && - mapOptOrNull(value, "tags", result.tags, path); -} - -//===----------------------------------------------------------------------===// -// PublishDiagnosticsParams -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const PublishDiagnosticsParams ¶ms) { - return llvm::json::Object{ - {"uri", params.uri}, - {"diagnostics", params.diagnostics}, - {"version", params.version}, - }; -} - -//===----------------------------------------------------------------------===// -// TextEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, TextEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("range", result.range) && o.map("newText", result.newText); -} - -llvm::json::Value mlir::lsp::toJSON(const TextEdit &value) { - return llvm::json::Object{ - {"range", value.range}, - {"newText", value.newText}, - }; -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, const TextEdit &value) { - os << value.range << " => \""; - llvm::printEscapedString(value.newText, os); - return os << '"'; -} - -//===----------------------------------------------------------------------===// -// CompletionItemKind -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKind &result, llvm::json::Path path) { - if (std::optional intValue = value.getAsInteger()) { - if (*intValue < static_cast(CompletionItemKind::Text) || - *intValue > static_cast(CompletionItemKind::TypeParameter)) - return false; - result = static_cast(*intValue); - return true; - } - return false; -} - -CompletionItemKind mlir::lsp::adjustKindToCapability( - CompletionItemKind kind, - CompletionItemKindBitset &supportedCompletionItemKinds) { - size_t kindVal = static_cast(kind); - if (kindVal >= kCompletionItemKindMin && - kindVal <= supportedCompletionItemKinds.size() && - supportedCompletionItemKinds[kindVal]) - return kind; - - // Provide some fall backs for common kinds that are close enough. - switch (kind) { - case CompletionItemKind::Folder: - return CompletionItemKind::File; - case CompletionItemKind::EnumMember: - return CompletionItemKind::Enum; - case CompletionItemKind::Struct: - return CompletionItemKind::Class; - default: - return CompletionItemKind::Text; - } -} - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionItemKindBitset &result, - llvm::json::Path path) { - if (const llvm::json::Array *arrayValue = value.getAsArray()) { - for (size_t i = 0, e = arrayValue->size(); i < e; ++i) { - CompletionItemKind kindOut; - if (fromJSON((*arrayValue)[i], kindOut, path.index(i))) - result.set(size_t(kindOut)); - } - return true; - } - return false; -} - -//===----------------------------------------------------------------------===// -// CompletionItem -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionItem &value) { - assert(!value.label.empty() && "completion item label is required"); - llvm::json::Object result{{"label", value.label}}; - if (value.kind != CompletionItemKind::Missing) - result["kind"] = static_cast(value.kind); - if (!value.detail.empty()) - result["detail"] = value.detail; - if (value.documentation) - result["documentation"] = value.documentation; - if (!value.sortText.empty()) - result["sortText"] = value.sortText; - if (!value.filterText.empty()) - result["filterText"] = value.filterText; - if (!value.insertText.empty()) - result["insertText"] = value.insertText; - if (value.insertTextFormat != InsertTextFormat::Missing) - result["insertTextFormat"] = static_cast(value.insertTextFormat); - if (value.textEdit) - result["textEdit"] = *value.textEdit; - if (!value.additionalTextEdits.empty()) { - result["additionalTextEdits"] = - llvm::json::Array(value.additionalTextEdits); - } - if (value.deprecated) - result["deprecated"] = value.deprecated; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const CompletionItem &value) { - return os << value.label << " - " << toJSON(value); -} - -bool mlir::lsp::operator<(const CompletionItem &lhs, - const CompletionItem &rhs) { - return (lhs.sortText.empty() ? lhs.label : lhs.sortText) < - (rhs.sortText.empty() ? rhs.label : rhs.sortText); -} - -//===----------------------------------------------------------------------===// -// CompletionList -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const CompletionList &value) { - return llvm::json::Object{ - {"isIncomplete", value.isIncomplete}, - {"items", llvm::json::Array(value.items)}, - }; -} - -//===----------------------------------------------------------------------===// -// CompletionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - int triggerKind; - if (!o || !o.map("triggerKind", triggerKind) || - !mapOptOrNull(value, "triggerCharacter", result.triggerCharacter, path)) - return false; - result.triggerKind = static_cast(triggerKind); - return true; -} - -//===----------------------------------------------------------------------===// -// CompletionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CompletionParams &result, llvm::json::Path path) { - if (!fromJSON(value, static_cast(result), path)) - return false; - if (const llvm::json::Value *context = value.getAsObject()->get("context")) - return fromJSON(*context, result.context, path.field("context")); - return true; -} - -//===----------------------------------------------------------------------===// -// ParameterInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const ParameterInformation &value) { - assert((value.labelOffsets || !value.labelString.empty()) && - "parameter information label is required"); - llvm::json::Object result; - if (value.labelOffsets) - result["label"] = llvm::json::Array( - {value.labelOffsets->first, value.labelOffsets->second}); - else - result["label"] = value.labelString; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -//===----------------------------------------------------------------------===// -// SignatureInformation -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureInformation &value) { - assert(!value.label.empty() && "signature information label is required"); - llvm::json::Object result{ - {"label", value.label}, - {"parameters", llvm::json::Array(value.parameters)}, - }; - if (!value.documentation.empty()) - result["documentation"] = value.documentation; - return std::move(result); -} - -raw_ostream &mlir::lsp::operator<<(raw_ostream &os, - const SignatureInformation &value) { - return os << value.label << " - " << toJSON(value); -} - -//===----------------------------------------------------------------------===// -// SignatureHelp -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const SignatureHelp &value) { - assert(value.activeSignature >= 0 && - "Unexpected negative value for number of active signatures."); - assert(value.activeParameter >= 0 && - "Unexpected negative value for active parameter index"); - return llvm::json::Object{ - {"activeSignature", value.activeSignature}, - {"activeParameter", value.activeParameter}, - {"signatures", llvm::json::Array(value.signatures)}, - }; -} - -//===----------------------------------------------------------------------===// -// DocumentLinkParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - DocumentLinkParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument); -} - -//===----------------------------------------------------------------------===// -// DocumentLink -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const DocumentLink &value) { - return llvm::json::Object{ - {"range", value.range}, - {"target", value.target}, - }; -} - -//===----------------------------------------------------------------------===// -// InlayHintsParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - InlayHintsParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range); -} - -//===----------------------------------------------------------------------===// -// InlayHint -//===----------------------------------------------------------------------===// - -llvm::json::Value mlir::lsp::toJSON(const InlayHint &value) { - return llvm::json::Object{{"position", value.position}, - {"kind", (int)value.kind}, - {"label", value.label}, - {"paddingLeft", value.paddingLeft}, - {"paddingRight", value.paddingRight}}; -} -bool mlir::lsp::operator==(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) == - std::tie(rhs.position, rhs.kind, rhs.label); -} -bool mlir::lsp::operator<(const InlayHint &lhs, const InlayHint &rhs) { - return std::tie(lhs.position, lhs.kind, lhs.label) < - std::tie(rhs.position, rhs.kind, rhs.label); -} - -llvm::raw_ostream &mlir::lsp::operator<<(llvm::raw_ostream &os, - InlayHintKind value) { - switch (value) { - case InlayHintKind::Parameter: - return os << "parameter"; - case InlayHintKind::Type: - return os << "type"; - } - llvm_unreachable("Unknown InlayHintKind"); -} - -//===----------------------------------------------------------------------===// -// CodeActionContext -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionContext &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - if (!o || !o.map("diagnostics", result.diagnostics)) - return false; - o.map("only", result.only); - return true; -} - -//===----------------------------------------------------------------------===// -// CodeActionParams -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, - CodeActionParams &result, llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("textDocument", result.textDocument) && - o.map("range", result.range) && o.map("context", result.context); -} - -//===----------------------------------------------------------------------===// -// WorkspaceEdit -//===----------------------------------------------------------------------===// - -bool mlir::lsp::fromJSON(const llvm::json::Value &value, WorkspaceEdit &result, - llvm::json::Path path) { - llvm::json::ObjectMapper o(value, path); - return o && o.map("changes", result.changes); -} - -llvm::json::Value mlir::lsp::toJSON(const WorkspaceEdit &value) { - llvm::json::Object fileChanges; - for (auto &change : value.changes) - fileChanges[change.first] = llvm::json::Array(change.second); - return llvm::json::Object{{"changes", std::move(fileChanges)}}; -} - -//===----------------------------------------------------------------------===// -// CodeAction -//===----------------------------------------------------------------------===// - -const llvm::StringLiteral CodeAction::kQuickFix = "quickfix"; -const llvm::StringLiteral CodeAction::kRefactor = "refactor"; -const llvm::StringLiteral CodeAction::kInfo = "info"; - -llvm::json::Value mlir::lsp::toJSON(const CodeAction &value) { - llvm::json::Object codeAction{{"title", value.title}}; - if (value.kind) - codeAction["kind"] = *value.kind; - if (value.diagnostics) - codeAction["diagnostics"] = llvm::json::Array(*value.diagnostics); - if (value.isPreferred) - codeAction["isPreferred"] = true; - if (value.edit) - codeAction["edit"] = *value.edit; - return std::move(codeAction); -} diff --git a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp index f1a362385f28..5cd1c85d054a 100644 --- a/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp +++ b/mlir/lib/Tools/lsp-server-support/SourceMgrUtils.cpp @@ -14,6 +14,10 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Hover; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // Utils //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/lsp-server-support/Transport.cpp b/mlir/lib/Tools/lsp-server-support/Transport.cpp deleted file mode 100644 index 5a098b2841f4..000000000000 --- a/mlir/lib/Tools/lsp-server-support/Transport.cpp +++ /dev/null @@ -1,369 +0,0 @@ -//===--- JSONTransport.cpp - sending and receiving LSP messages over JSON -===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/lsp-server-support/Transport.h" -#include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/Support/Error.h" -#include -#include -#include - -using namespace mlir; -using namespace mlir::lsp; - -//===----------------------------------------------------------------------===// -// Reply -//===----------------------------------------------------------------------===// - -namespace { -/// Function object to reply to an LSP call. -/// Each instance must be called exactly once, otherwise: -/// - if there was no reply, an error reply is sent -/// - if there were multiple replies, only the first is sent -class Reply { -public: - Reply(const llvm::json::Value &id, StringRef method, JSONTransport &transport, - std::mutex &transportOutputMutex); - Reply(Reply &&other); - Reply &operator=(Reply &&) = delete; - Reply(const Reply &) = delete; - Reply &operator=(const Reply &) = delete; - - void operator()(llvm::Expected reply); - -private: - std::string method; - std::atomic replied = {false}; - llvm::json::Value id; - JSONTransport *transport; - std::mutex &transportOutputMutex; -}; -} // namespace - -Reply::Reply(const llvm::json::Value &id, llvm::StringRef method, - JSONTransport &transport, std::mutex &transportOutputMutex) - : method(method), id(id), transport(&transport), - transportOutputMutex(transportOutputMutex) {} - -Reply::Reply(Reply &&other) - : method(other.method), replied(other.replied.load()), - id(std::move(other.id)), transport(other.transport), - transportOutputMutex(other.transportOutputMutex) { - other.transport = nullptr; -} - -void Reply::operator()(llvm::Expected reply) { - if (replied.exchange(true)) { - Logger::error("Replied twice to message {0}({1})", method, id); - assert(false && "must reply to each call only once!"); - return; - } - assert(transport && "expected valid transport to reply to"); - - std::lock_guard transportLock(transportOutputMutex); - if (reply) { - Logger::info("--> reply:{0}({1})", method, id); - transport->reply(std::move(id), std::move(reply)); - } else { - llvm::Error error = reply.takeError(); - Logger::info("--> reply:{0}({1}): {2}", method, id, error); - transport->reply(std::move(id), std::move(error)); - } -} - -//===----------------------------------------------------------------------===// -// MessageHandler -//===----------------------------------------------------------------------===// - -bool MessageHandler::onNotify(llvm::StringRef method, llvm::json::Value value) { - Logger::info("--> {0}", method); - - if (method == "exit") - return false; - if (method == "$cancel") { - // TODO: Add support for cancelling requests. - } else { - auto it = notificationHandlers.find(method); - if (it != notificationHandlers.end()) - it->second(std::move(value)); - } - return true; -} - -bool MessageHandler::onCall(llvm::StringRef method, llvm::json::Value params, - llvm::json::Value id) { - Logger::info("--> {0}({1})", method, id); - - Reply reply(id, method, transport, transportOutputMutex); - - auto it = methodHandlers.find(method); - if (it != methodHandlers.end()) { - it->second(std::move(params), std::move(reply)); - } else { - reply(llvm::make_error("method not found: " + method.str(), - ErrorCode::MethodNotFound)); - } - return true; -} - -bool MessageHandler::onReply(llvm::json::Value id, - llvm::Expected result) { - // Find the response handler in the mapping. If it exists, move it out of the - // mapping and erase it. - ResponseHandlerTy responseHandler; - { - std::lock_guard responseHandlersLock(responseHandlersMutex); - auto it = responseHandlers.find(debugString(id)); - if (it != responseHandlers.end()) { - responseHandler = std::move(it->second); - responseHandlers.erase(it); - } - } - - // If we found a response handler, invoke it. Otherwise, log an error. - if (responseHandler.second) { - Logger::info("--> reply:{0}({1})", responseHandler.first, id); - responseHandler.second(std::move(id), std::move(result)); - } else { - Logger::error( - "received a reply with ID {0}, but there was no such outgoing request", - id); - if (!result) - llvm::consumeError(result.takeError()); - } - return true; -} - -//===----------------------------------------------------------------------===// -// JSONTransport -//===----------------------------------------------------------------------===// - -/// Encode the given error as a JSON object. -static llvm::json::Object encodeError(llvm::Error error) { - std::string message; - ErrorCode code = ErrorCode::UnknownErrorCode; - auto handlerFn = [&](const LSPError &lspError) -> llvm::Error { - message = lspError.message; - code = lspError.code; - return llvm::Error::success(); - }; - if (llvm::Error unhandled = llvm::handleErrors(std::move(error), handlerFn)) - message = llvm::toString(std::move(unhandled)); - - return llvm::json::Object{ - {"message", std::move(message)}, - {"code", int64_t(code)}, - }; -} - -/// Decode the given JSON object into an error. -llvm::Error decodeError(const llvm::json::Object &o) { - StringRef msg = o.getString("message").value_or("Unspecified error"); - if (std::optional code = o.getInteger("code")) - return llvm::make_error(msg.str(), ErrorCode(*code)); - return llvm::make_error(llvm::inconvertibleErrorCode(), - msg.str()); -} - -void JSONTransport::notify(StringRef method, llvm::json::Value params) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::call(StringRef method, llvm::json::Value params, - llvm::json::Value id) { - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"method", method}, - {"params", std::move(params)}, - }); -} -void JSONTransport::reply(llvm::json::Value id, - llvm::Expected result) { - if (result) { - return sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"result", std::move(*result)}, - }); - } - - sendMessage(llvm::json::Object{ - {"jsonrpc", "2.0"}, - {"id", std::move(id)}, - {"error", encodeError(result.takeError())}, - }); -} - -llvm::Error JSONTransport::run(MessageHandler &handler) { - std::string json; - while (!in->isEndOfInput()) { - if (in->hasError()) { - return llvm::errorCodeToError( - std::error_code(errno, std::system_category())); - } - - if (succeeded(in->readMessage(json))) { - if (llvm::Expected doc = llvm::json::parse(json)) { - if (!handleMessage(std::move(*doc), handler)) - return llvm::Error::success(); - } else { - Logger::error("JSON parse error: {0}", llvm::toString(doc.takeError())); - } - } - } - return llvm::errorCodeToError(std::make_error_code(std::errc::io_error)); -} - -void JSONTransport::sendMessage(llvm::json::Value msg) { - outputBuffer.clear(); - llvm::raw_svector_ostream os(outputBuffer); - os << llvm::formatv(prettyOutput ? "{0:2}\n" : "{0}", msg); - out << "Content-Length: " << outputBuffer.size() << "\r\n\r\n" - << outputBuffer; - out.flush(); - Logger::debug(">>> {0}\n", outputBuffer); -} - -bool JSONTransport::handleMessage(llvm::json::Value msg, - MessageHandler &handler) { - // Message must be an object with "jsonrpc":"2.0". - llvm::json::Object *object = msg.getAsObject(); - if (!object || - object->getString("jsonrpc") != std::optional("2.0")) - return false; - - // `id` may be any JSON value. If absent, this is a notification. - std::optional id; - if (llvm::json::Value *i = object->get("id")) - id = std::move(*i); - std::optional method = object->getString("method"); - - // This is a response. - if (!method) { - if (!id) - return false; - if (auto *err = object->getObject("error")) - return handler.onReply(std::move(*id), decodeError(*err)); - // result should be given, use null if not. - llvm::json::Value result = nullptr; - if (llvm::json::Value *r = object->get("result")) - result = std::move(*r); - return handler.onReply(std::move(*id), std::move(result)); - } - - // Params should be given, use null if not. - llvm::json::Value params = nullptr; - if (llvm::json::Value *p = object->get("params")) - params = std::move(*p); - - if (id) - return handler.onCall(*method, std::move(params), std::move(*id)); - return handler.onNotify(*method, std::move(params)); -} - -/// Tries to read a line up to and including \n. -/// If failing, feof(), ferror(), or shutdownRequested() will be set. -LogicalResult readLine(std::FILE *in, SmallVectorImpl &out) { - // Big enough to hold any reasonable header line. May not fit content lines - // in delimited mode, but performance doesn't matter for that mode. - static constexpr int bufSize = 128; - size_t size = 0; - out.clear(); - for (;;) { - out.resize_for_overwrite(size + bufSize); - if (!std::fgets(&out[size], bufSize, in)) - return failure(); - - clearerr(in); - - // If the line contained null bytes, anything after it (including \n) will - // be ignored. Fortunately this is not a legal header or JSON. - size_t read = std::strlen(&out[size]); - if (read > 0 && out[size + read - 1] == '\n') { - out.resize(size + read); - return success(); - } - size += read; - } -} - -// Returns std::nullopt when: -// - ferror(), feof(), or shutdownRequested() are set. -// - Content-Length is missing or empty (protocol error) -LogicalResult -JSONTransportInputOverFile::readStandardMessage(std::string &json) { - // A Language Server Protocol message starts with a set of HTTP headers, - // delimited by \r\n, and terminated by an empty line (\r\n). - unsigned long long contentLength = 0; - llvm::SmallString<128> line; - while (true) { - if (feof(in) || hasError() || failed(readLine(in, line))) - return failure(); - - // Content-Length is a mandatory header, and the only one we handle. - StringRef lineRef = line; - if (lineRef.consume_front("Content-Length: ")) { - llvm::getAsUnsignedInteger(lineRef.trim(), 0, contentLength); - } else if (!lineRef.trim().empty()) { - // It's another header, ignore it. - continue; - } else { - // An empty line indicates the end of headers. Go ahead and read the JSON. - break; - } - } - - // The fuzzer likes crashing us by sending "Content-Length: 9999999999999999" - if (contentLength == 0 || contentLength > 1 << 30) - return failure(); - - json.resize(contentLength); - for (size_t pos = 0, read; pos < contentLength; pos += read) { - read = std::fread(&json[pos], 1, contentLength - pos, in); - if (read == 0) - return failure(); - - // If we're done, the error was transient. If we're not done, either it was - // transient or we'll see it again on retry. - clearerr(in); - pos += read; - } - return success(); -} - -/// For lit tests we support a simplified syntax: -/// - messages are delimited by '// -----' on a line by itself -/// - lines starting with // are ignored. -/// This is a testing path, so favor simplicity over performance here. -/// When returning failure: feof(), ferror(), or shutdownRequested() will be -/// set. -LogicalResult -JSONTransportInputOverFile::readDelimitedMessage(std::string &json) { - json.clear(); - llvm::SmallString<128> line; - while (succeeded(readLine(in, line))) { - StringRef lineRef = line.str().trim(); - if (lineRef.starts_with("//")) { - // Found a delimiter for the message. - if (lineRef == kDefaultSplitMarker) - break; - continue; - } - - json += line; - } - - return failure(ferror(in)); -} diff --git a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt index d04d5156fb3c..e2acba54e562 100644 --- a/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ add_mlir_library(MLIRLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRBytecodeWriter MLIRFunctionInterfaces diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp index 9b937db0c6a7..1bbbcdecb57a 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -9,8 +9,8 @@ #include "LSPServer.h" #include "MLIRServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include #define DEBUG_TYPE "mlir-lsp-server" @@ -18,6 +18,33 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionParams; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::MLIRConvertBytecodeParams; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h index 2c50c6b4ac6f..d65289963325 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class MLIRServer; /// Run the main loop of the LSP server using the given MLIR server and /// transport. llvm::LogicalResult runMlirLSPServer(MLIRServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp index 61987525a5ca..47b4328d0d9e 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -16,10 +16,10 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/ToolUtilities.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Base64.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/SourceMgr.h" #include @@ -39,9 +39,9 @@ static std::optional getLocationFromLoc(StringRef uriScheme, llvm::Expected sourceURI = lsp::URIForFile::fromFile(loc.getFilename(), uriScheme); if (!sourceURI) { - lsp::Logger::error("Failed to create URI for file `{0}`: {1}", - loc.getFilename(), - llvm::toString(sourceURI.takeError())); + llvm::lsp::Logger::error("Failed to create URI for file `{0}`: {1}", + loc.getFilename(), + llvm::toString(sourceURI.takeError())); return std::nullopt; } @@ -217,22 +217,22 @@ static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, // Convert the severity for the diagnostic. switch (diag.getSeverity()) { - case DiagnosticSeverity::Note: + case mlir::DiagnosticSeverity::Note: llvm_unreachable("expected notes to be handled separately"); - case DiagnosticSeverity::Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + case mlir::DiagnosticSeverity::Warning: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; - case DiagnosticSeverity::Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + case mlir::DiagnosticSeverity::Error: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; - case DiagnosticSeverity::Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + case mlir::DiagnosticSeverity::Remark: + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.str(); // Attach any notes to the main diagnostic as related information. - std::vector relatedDiags; + std::vector relatedDiags; for (Diagnostic ¬e : diag.getNotes()) { lsp::Location noteLoc; if (std::optional loc = @@ -317,7 +317,7 @@ struct MLIRDocument { void getCodeActionForDiagnostic(const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, StringRef message, - std::vector &edits); + std::vector &edits); //===--------------------------------------------------------------------===// // Bytecode @@ -355,7 +355,8 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri, // Try to parsed the given IR string. auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -695,8 +696,8 @@ void MLIRDocument::findDocumentSymbols( if (SymbolOpInterface symbol = dyn_cast(op)) { symbols.emplace_back(symbol.getName(), isa(op) - ? lsp::SymbolKind::Function - : lsp::SymbolKind::Class, + ? llvm::lsp::SymbolKind::Function + : llvm::lsp::SymbolKind::Class, lsp::Range(sourceMgr, def->scopeLoc), lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; @@ -704,9 +705,9 @@ void MLIRDocument::findDocumentSymbols( } else if (op->hasTrait()) { // Otherwise, if this is a symbol table push an anonymous document symbol. symbols.emplace_back("<" + op->getName().getStringRef() + ">", - lsp::SymbolKind::Namespace, - lsp::Range(sourceMgr, def->scopeLoc), - lsp::Range(sourceMgr, def->loc)); + llvm::lsp::SymbolKind::Namespace, + llvm::lsp::Range(sourceMgr, def->scopeLoc), + llvm::lsp::Range(sourceMgr, def->loc)); childSymbols = &symbols.back().children; } } @@ -734,9 +735,9 @@ public: /// Signal code completion for a dialect name, with an optional prefix. void completeDialectName(StringRef prefix) final { for (StringRef dialect : ctx->getAvailableDialects()) { - lsp::CompletionItem item(prefix + dialect, - lsp::CompletionItemKind::Module, - /*sortText=*/"3"); + llvm::lsp::CompletionItem item(prefix + dialect, + llvm::lsp::CompletionItemKind::Module, + /*sortText=*/"3"); item.detail = "dialect"; completionList.items.emplace_back(item); } @@ -753,9 +754,9 @@ public: if (&op.getDialect() != dialect) continue; - lsp::CompletionItem item( + llvm::lsp::CompletionItem item( op.getStringRef().drop_front(dialectName.size() + 1), - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); item.detail = "operation"; completionList.items.emplace_back(item); @@ -768,7 +769,8 @@ public: // Check if we need to insert the `%` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '%'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Variable); + llvm::lsp::CompletionItem item(name, + llvm::lsp::CompletionItemKind::Variable); if (stripPrefix) item.insertText = name.drop_front(1).str(); item.detail = std::move(typeData); @@ -781,7 +783,7 @@ public: // Check if we need to insert the `^` or not. bool stripPrefix = getCodeCompleteLoc().getPointer()[-1] == '^'; - lsp::CompletionItem item(name, lsp::CompletionItemKind::Field); + llvm::lsp::CompletionItem item(name, llvm::lsp::CompletionItemKind::Field); if (stripPrefix) item.insertText = name.drop_front(1).str(); completionList.items.emplace_back(item); @@ -790,8 +792,9 @@ public: /// Signal a completion for the given expected token. void completeExpectedTokens(ArrayRef tokens, bool optional) final { for (StringRef token : tokens) { - lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword, - /*sortText=*/"0"); + llvm::lsp::CompletionItem item(token, + llvm::lsp::CompletionItemKind::Keyword, + /*sortText=*/"0"); item.detail = optional ? "optional" : ""; completionList.items.emplace_back(item); } @@ -802,7 +805,7 @@ public: appendSimpleCompletions({"affine_set", "affine_map", "dense", "dense_resource", "false", "loc", "sparse", "true", "unit"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); completeDialectName("#"); @@ -820,13 +823,14 @@ public: appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector", "bf16", "f16", "f32", "f64", "f80", "f128", "index", "none"}, - lsp::CompletionItemKind::Field, + llvm::lsp::CompletionItemKind::Field, /*sortText=*/"1"); // Handle the builtin integer types. for (StringRef type : {"i", "si", "ui"}) { - lsp::CompletionItem item(type + "", lsp::CompletionItemKind::Field, - /*sortText=*/"1"); + llvm::lsp::CompletionItem item(type + "", + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"1"); item.insertText = type.str(); completionList.items.emplace_back(item); } @@ -846,9 +850,9 @@ public: void completeAliases(const llvm::StringMap &aliases, StringRef prefix = "") { for (const auto &alias : aliases) { - lsp::CompletionItem item(prefix + alias.getKey(), - lsp::CompletionItemKind::Field, - /*sortText=*/"2"); + llvm::lsp::CompletionItem item(prefix + alias.getKey(), + llvm::lsp::CompletionItemKind::Field, + /*sortText=*/"2"); llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue(); completionList.items.emplace_back(item); } @@ -856,7 +860,7 @@ public: /// Add a set of simple completions that all have the same kind. void appendSimpleCompletions(ArrayRef completions, - lsp::CompletionItemKind kind, + llvm::lsp::CompletionItemKind kind, StringRef sortText = "") { for (StringRef completion : completions) completionList.items.emplace_back(completion, kind, sortText); @@ -897,7 +901,7 @@ MLIRDocument::getCodeCompletion(const lsp::URIForFile &uri, void MLIRDocument::getCodeActionForDiagnostic( const lsp::URIForFile &uri, lsp::Position &pos, StringRef severity, - StringRef message, std::vector &edits) { + StringRef message, std::vector &edits) { // Ignore diagnostics that print the current operation. These are always // enabled for the language server, but not generally during normal // parsing/verification. @@ -913,7 +917,7 @@ void MLIRDocument::getCodeActionForDiagnostic( // Add a text edit for adding an expected-* diagnostic check for this // diagnostic. - lsp::TextEdit edit; + llvm::lsp::TextEdit edit; edit.range = lsp::Range(lsp::Position(pos.line, 0)); // Use the indent of the current line for the expected-* diagnostic. @@ -937,13 +941,14 @@ MLIRDocument::convertToBytecode() { // conceptually be relaxed. if (!llvm::hasSingleElement(parsedIR)) { if (parsedIR.empty()) { - return llvm::make_error( + return llvm::make_error( "expected a single and valid top-level operation, please ensure " "there are no errors", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } - return llvm::make_error( - "expected a single top-level operation", lsp::ErrorCode::RequestFailed); + return llvm::make_error( + "expected a single top-level operation", + llvm::lsp::ErrorCode::RequestFailed); } lsp::MLIRConvertBytecodeResult result; @@ -1134,7 +1139,7 @@ void MLIRTextFile::findDocumentSymbols( lsp::Position endPos((i == e - 1) ? totalNumLines - 1 : chunks[i + 1]->lineOffset); lsp::DocumentSymbol symbol("", - lsp::SymbolKind::Namespace, + llvm::lsp::SymbolKind::Namespace, /*range=*/lsp::Range(startPos, endPos), /*selectionRange=*/lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); @@ -1167,10 +1172,10 @@ lsp::CompletionList MLIRTextFile::getCodeCompletion(const lsp::URIForFile &uri, uri, completePos, context.getDialectRegistry()); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; @@ -1194,10 +1199,10 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, StringRef severity; switch (diag.severity) { - case lsp::DiagnosticSeverity::Error: + case llvm::lsp::DiagnosticSeverity::Error: severity = "error"; break; - case lsp::DiagnosticSeverity::Warning: + case llvm::lsp::DiagnosticSeverity::Warning: severity = "warning"; break; default: @@ -1205,7 +1210,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } // Get edits for the diagnostic. - std::vector edits; + std::vector edits; chunk.document.getCodeActionForDiagnostic(uri, diagPos, severity, diag.message, edits); @@ -1221,7 +1226,7 @@ void MLIRTextFile::getCodeActions(const lsp::URIForFile &uri, } } // Fixup the locations for any edits. - for (lsp::TextEdit &edit : edits) + for (llvm::lsp::TextEdit &edit : edits) chunk.adjustLocForChunkOffset(edit.range); action.edit.emplace(); @@ -1236,9 +1241,9 @@ llvm::Expected MLIRTextFile::convertToBytecode() { // Bail out if there is more than one chunk, bytecode wants a single module. if (chunks.size() != 1) { - return llvm::make_error( + return llvm::make_error( "unexpected split file, please remove all `// -----`", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return chunks.front()->document.convertToBytecode(); } @@ -1283,7 +1288,7 @@ lsp::MLIRServer::~MLIRServer() = default; void lsp::MLIRServer::addOrUpdateDocument( const URIForFile &uri, StringRef contents, int64_t version, - std::vector &diagnostics) { + std::vector &diagnostics) { impl->files[uri.file()] = std::make_unique( uri, contents, version, impl->registry_fn, diagnostics); } @@ -1298,17 +1303,17 @@ std::optional lsp::MLIRServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::MLIRServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector &locations) { +void lsp::MLIRServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::MLIRServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector &references) { +void lsp::MLIRServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1367,17 +1372,17 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) { // Try to parse the given source file. Block parsedBlock; if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) { - return llvm::make_error( + return llvm::make_error( "failed to parse bytecode source file: " + errorMsg, - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // TODO: We currently expect a single top-level operation, but this could // conceptually be relaxed. if (!llvm::hasSingleElement(parsedBlock)) { - return llvm::make_error( + return llvm::make_error( "expected bytecode to contain a single top-level operation", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } // Print the module to a buffer. @@ -1401,9 +1406,9 @@ llvm::Expected lsp::MLIRServer::convertToBytecode(const URIForFile &uri) { auto fileIt = impl->files.find(uri.file()); if (fileIt == impl->files.end()) { - return llvm::make_error( + return llvm::make_error( "language server does not contain an entry for this source file", - lsp::ErrorCode::RequestFailed); + llvm::lsp::ErrorCode::RequestFailed); } return fileIt->second->convertToBytecode(); } diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h index 85e69e69f663..31a01fec8bbc 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.h @@ -9,6 +9,7 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_SERVER_H_ +#include "Protocol.h" #include "mlir/Support/LLVM.h" #include "mlir/Tools/mlir-lsp-server/MlirLspRegistryFunction.h" #include "llvm/Support/Error.h" @@ -19,16 +20,17 @@ namespace mlir { class DialectRegistry; namespace lsp { -struct CodeAction; -struct CodeActionContext; -struct CompletionList; -struct Diagnostic; -struct DocumentSymbol; -struct Hover; -struct Location; -struct MLIRConvertBytecodeResult; -struct Position; -struct Range; +using llvm::lsp::CodeAction; +using llvm::lsp::CodeActionContext; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::MLIRConvertBytecodeResult; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::URIForFile; /// This class implements all of the MLIR related functionality necessary for a /// language server. This class allows for keeping the MLIR specific logic diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp index f1dc32615c6a..d4589b240e39 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" #include "LSPServer.h" #include "MLIRServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::MlirLspServerMain(int argc, char **argv, DialectRegistryFn registry_fn) { llvm::cl::opt inputStyle{ diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp index a56e9a10f03f..28aded304d38 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp @@ -13,14 +13,11 @@ #include "Protocol.h" #include "llvm/Support/JSON.h" -using namespace mlir; -using namespace mlir::lsp; - //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams //===----------------------------------------------------------------------===// -bool mlir::lsp::fromJSON(const llvm::json::Value &value, +bool llvm::lsp::fromJSON(const llvm::json::Value &value, MLIRConvertBytecodeParams &result, llvm::json::Path path) { llvm::json::ObjectMapper o(value, path); @@ -31,6 +28,6 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value, // MLIRConvertBytecodeResult //===----------------------------------------------------------------------===// -llvm::json::Value mlir::lsp::toJSON(const MLIRConvertBytecodeResult &value) { +llvm::json::Value llvm::lsp::toJSON(const MLIRConvertBytecodeResult &value) { return llvm::json::Object{{"output", value.output}}; } diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-lsp-server/Protocol.h index d910780e1ee9..ed0db4e591d8 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.h @@ -20,9 +20,9 @@ #ifndef LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" -namespace mlir { +namespace llvm { namespace lsp { //===----------------------------------------------------------------------===// // MLIRConvertBytecodeParams @@ -54,6 +54,6 @@ struct MLIRConvertBytecodeResult { llvm::json::Value toJSON(const MLIRConvertBytecodeResult &value); } // namespace lsp -} // namespace mlir +} // namespace llvm #endif diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt index bf25b7e0a64f..b41603fb67eb 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/CMakeLists.txt @@ -7,6 +7,9 @@ llvm_add_library(MLIRPdllLspServerLib ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-pdll-lsp-server + LINK_COMPONENTS + SupportLSP + LINK_LIBS PUBLIC MLIRPDLLCodeGen MLIRPDLLParser diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp index 82542a12a180..7b23adcc7e2e 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.cpp @@ -10,8 +10,9 @@ #include "PDLLServer.h" #include "Protocol.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include #define DEBUG_TYPE "pdll-lsp-server" @@ -19,6 +20,30 @@ using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::CompletionList; +using llvm::lsp::CompletionParams; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::DocumentSymbolParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::InlayHintsParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h index 78c4c31100cb..42c0a5d7b6d2 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class PDLLServer; /// Run the main loop of the LSP server using the given PDLL server and /// transport. llvm::LogicalResult runPdllLSPServer(PDLLServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp index 287a131ecd17..5dea130675cd 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.cpp @@ -9,14 +9,17 @@ #include "mlir/Tools/mlir-pdll-lsp-server/MlirPdllLspServerMain.h" #include "LSPServer.h" #include "PDLLServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::Logger; + LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { llvm::cl::opt inputStyle{ "input-style", @@ -72,7 +75,8 @@ LogicalResult mlir::MlirPdllLspServerMain(int argc, char **argv) { // Configure the transport used for communication. llvm::sys::ChangeStdinToBinary(); - JSONTransport transport(stdin, llvm::outs(), inputStyle, prettyPrint); + llvm::lsp::JSONTransport transport(stdin, llvm::outs(), inputStyle, + prettyPrint); // Configure the servers and start the main language server. PDLLServer::Options options(compilationDatabases, extraIncludeDirs); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp index 84f529ae1640..60b9567ff780 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.cpp @@ -23,13 +23,13 @@ #include "mlir/Tools/PDLL/Parser/CodeComplete.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/LSP/Logging.h" #include "llvm/Support/Path.h" #include @@ -38,17 +38,19 @@ using namespace mlir::pdll; /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(llvm::SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc.Start); if (bufferId == 0 || bufferId == static_cast(mgr.getMainFileID())) return mainFileURI; - llvm::Expected fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } @@ -59,16 +61,18 @@ static bool isMainFileLoc(llvm::SourceMgr &mgr, SMRange loc) { } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, range, uri), lsp::Range(mgr, range)); +static llvm::lsp::Location +getLocationFromLoc(llvm::SourceMgr &mgr, SMRange range, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, range, uri), + llvm::lsp::Range(mgr, range)); } /// Convert the given MLIR diagnostic to the LSP form. -static std::optional +static std::optional getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, - const lsp::URIForFile &uri) { - lsp::Diagnostic lspDiag; + const llvm::lsp::URIForFile &uri) { + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "pdll"; // FIXME: Right now all of the diagnostics are treated as parser issues, but @@ -76,7 +80,8 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(sourceMgr, diag.getLocation(), uri); + llvm::lsp::Location loc = + getLocationFromLoc(sourceMgr, diag.getLocation(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -88,19 +93,19 @@ getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr, const ast::Diagnostic &diag, case ast::Diagnostic::Severity::DK_Note: llvm_unreachable("expected notes to be handled separately"); case ast::Diagnostic::Severity::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case ast::Diagnostic::Severity::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case ast::Diagnostic::Severity::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); // Attach any notes to the main diagnostic as related information. - std::vector relatedDiags; + std::vector relatedDiags; for (const ast::Diagnostic ¬e : diag.getNotes()) { relatedDiags.emplace_back( getLocationFromLoc(sourceMgr, note.getLocation(), uri), @@ -259,9 +264,9 @@ namespace { /// This class represents all of the information pertaining to a specific PDL /// document. struct PDLDocument { - PDLDocument(const lsp::URIForFile &uri, StringRef contents, + PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics); + std::vector &diagnostics); PDLDocument(const PDLDocument &) = delete; PDLDocument &operator=(const PDLDocument &) = delete; @@ -269,76 +274,83 @@ struct PDLDocument { // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - std::optional findHover(const ast::Decl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange); - lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange); - lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange); - lsp::Hover buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + std::optional + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + std::optional findHover(const ast::Decl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForVariable(const ast::VariableDecl *varDecl, const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, + const SMRange &hoverRange); template - lsp::Hover buildHoverForUserConstraintOrRewrite(StringRef typeName, - const T *decl, - const SMRange &hoverRange); + llvm::lsp::Hover + buildHoverForUserConstraintOrRewrite(StringRef typeName, const T *decl, + const SMRange &hoverRange); //===--------------------------------------------------------------------===// // Document Symbols //===--------------------------------------------------------------------===// - void findDocumentSymbols(std::vector &symbols); + void findDocumentSymbols(std::vector &symbols); //===--------------------------------------------------------------------===// // Code Completion //===--------------------------------------------------------------------===// - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos); + llvm::lsp::CompletionList + getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos); //===--------------------------------------------------------------------===// // Signature Help //===--------------------------------------------------------------------===// - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos); //===--------------------------------------------------------------------===// // Inlay Hints //===--------------------------------------------------------------------===// - void getInlayHints(const lsp::URIForFile &uri, const lsp::Range &range, - std::vector &inlayHints); + void getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector &inlayHints); void getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector &inlayHints); - void getInlayHintsFor(const ast::CallExpr *expr, const lsp::URIForFile &uri, - std::vector &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); + void getInlayHintsFor(const ast::CallExpr *expr, + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); void getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints); + const llvm::lsp::URIForFile &uri, + std::vector &inlayHints); /// Add a parameter hint for the given expression using `label`. - void addParameterHintFor(std::vector &inlayHints, + void addParameterHintFor(std::vector &inlayHints, const ast::Expr *expr, StringRef label); //===--------------------------------------------------------------------===// @@ -372,13 +384,14 @@ struct PDLDocument { }; } // namespace -PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, +PDLDocument::PDLDocument(const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : astContext(odsContext) { auto memBuffer = llvm::MemoryBuffer::getMemBufferCopy(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } @@ -412,9 +425,9 @@ PDLDocument::PDLDocument(const lsp::URIForFile &uri, StringRef contents, // PDLDocument: Definitions and References //===----------------------------------------------------------------------===// -void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector &locations) { +void PDLDocument::getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -423,9 +436,9 @@ void PDLDocument::getLocationsOf(const lsp::URIForFile &uri, locations.push_back(getLocationFromLoc(sourceMgr, symbol->getDefLoc(), uri)); } -void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, - const lsp::Position &pos, - std::vector &references) { +void PDLDocument::findReferencesOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const PDLIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -440,8 +453,9 @@ void PDLDocument::findReferencesOf(const lsp::URIForFile &uri, // PDLDocument: Document Links //===--------------------------------------------------------------------===// -void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void PDLDocument::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -450,9 +464,9 @@ void PDLDocument::getDocumentLinks(const lsp::URIForFile &uri, // PDLDocument: Hover //===----------------------------------------------------------------------===// -std::optional -PDLDocument::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional +PDLDocument::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { SMLoc posLoc = hoverPos.getAsSMLoc(sourceMgr); // Check for a reference to an include. @@ -474,8 +488,8 @@ PDLDocument::findHover(const lsp::URIForFile &uri, return findHover(decl, hoverRange); } -std::optional PDLDocument::findHover(const ast::Decl *decl, - const SMRange &hoverRange) { +std::optional +PDLDocument::findHover(const ast::Decl *decl, const SMRange &hoverRange) { // Add hover for variables. if (const auto *varDecl = dyn_cast(decl)) return buildHoverForVariable(varDecl, hoverRange); @@ -499,9 +513,9 @@ std::optional PDLDocument::findHover(const ast::Decl *decl, return std::nullopt; } -lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**OpName**: `" << op->getName() << "`\n***\n" @@ -511,9 +525,10 @@ lsp::Hover PDLDocument::buildHoverForOpName(const ods::Operation *op, return hover; } -lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Variable**: `" << varDecl->getName().getName() << "`\n***\n" @@ -522,9 +537,9 @@ lsp::Hover PDLDocument::buildHoverForVariable(const ast::VariableDecl *varDecl, return hover; } -lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Pattern**"; @@ -545,10 +560,10 @@ lsp::Hover PDLDocument::buildHoverForPattern(const ast::PatternDecl *decl, return hover; } -lsp::Hover +llvm::lsp::Hover PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**Constraint**: `"; @@ -573,9 +588,9 @@ PDLDocument::buildHoverForCoreConstraint(const ast::CoreConstraintDecl *decl, } template -lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( +llvm::lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( StringRef typeName, const T *decl, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**" << typeName << "**: `" << decl->getName().getName() @@ -617,7 +632,7 @@ lsp::Hover PDLDocument::buildHoverForUserConstraintOrRewrite( //===----------------------------------------------------------------------===// void PDLDocument::findDocumentSymbols( - std::vector &symbols) { + std::vector &symbols) { if (failed(astModule)) return; @@ -631,25 +646,28 @@ void PDLDocument::findDocumentSymbols( SMRange nameLoc = name ? name->getLoc() : patternDecl->getLoc(); SMRange bodyLoc(nameLoc.Start, patternDecl->getBody()->getLoc().End); - symbols.emplace_back( - name ? name->getName() : "", lsp::SymbolKind::Class, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(name ? name->getName() : "", + llvm::lsp::SymbolKind::Class, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } else if (const auto *cDecl = dyn_cast(decl)) { // TODO: Add source information for the code block body. SMRange nameLoc = cDecl->getName().getLoc(); SMRange bodyLoc = nameLoc; - symbols.emplace_back( - cDecl->getName().getName(), lsp::SymbolKind::Function, - lsp::Range(sourceMgr, bodyLoc), lsp::Range(sourceMgr, nameLoc)); + symbols.emplace_back(cDecl->getName().getName(), + llvm::lsp::SymbolKind::Function, + llvm::lsp::Range(sourceMgr, bodyLoc), + llvm::lsp::Range(sourceMgr, nameLoc)); } } } @@ -662,7 +680,7 @@ namespace { class LSPCodeCompleteContext : public CodeCompleteContext { public: LSPCodeCompleteContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::CompletionList &completionList, + llvm::lsp::CompletionList &completionList, ods::Context &odsContext, ArrayRef includeDirs) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), @@ -674,13 +692,13 @@ public: ArrayRef elementNames = tupleType.getElementNames(); for (unsigned i = 0, e = tupleType.size(); i < e; ++i) { // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", i).str(); item.insertText = Twine(i).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = llvm::formatv("{0}: {1}", i, elementTypes[i]); - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the element has a name, push back a completion item with that name. @@ -705,11 +723,11 @@ public: const ods::TypeConstraint &constraint = result.getConstraint(); // Push back a completion item that uses the result index. - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = llvm::formatv("{0} (field #{0})", it.index()).str(); item.insertText = Twine(it.index()).str(); item.filterText = item.sortText = item.insertText; - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; switch (result.getVariableLengthKind()) { case ods::VariableLengthKind::Single: item.detail = llvm::formatv("{0}: Value", it.index()).str(); @@ -721,12 +739,12 @@ public: item.detail = llvm::formatv("{0}: ValueRange", it.index()).str(); break; } - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); // If the result has a name, push back a completion item with the result @@ -750,16 +768,16 @@ public: for (const ods::Attribute &attr : odsOp->getAttributes()) { const ods::AttributeConstraint &constraint = attr.getConstraint(); - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = attr.getName().str(); - item.kind = lsp::CompletionItemKind::Field; + item.kind = llvm::lsp::CompletionItemKind::Field; item.detail = attr.isOptional() ? "optional" : ""; - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, llvm::formatv("{0}\n\n```c++\n{1}\n```\n", constraint.getSummary(), constraint.getCppClass()) .str()}; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -769,18 +787,18 @@ public: const ast::DeclScope *scope) final { auto addCoreConstraint = [&](StringRef constraint, StringRef mlirType, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = (constraint + " constraint").str(); - item.documentation = lsp::MarkupContent{ - lsp::MarkupKind::Markdown, + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, ("A single entity core constraint of type `" + mlirType + "`").str()}; item.sortText = "0"; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -812,9 +830,9 @@ public: while (scope) { for (const ast::Decl *decl : scope->getDecls()) { if (const auto *cst = dyn_cast(decl)) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = cst->getName().getName().str(); - item.kind = lsp::CompletionItemKind::Interface; + item.kind = llvm::lsp::CompletionItemKind::Interface; item.sortText = "2_" + item.label; // Skip constraints that are not single-arg. We currently only @@ -841,8 +859,8 @@ public: // Format the documentation for the constraint. if (std::optional doc = getDocumentationFor(sourceMgr, cst)) { - item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, std::move(*doc)}; + item.documentation = llvm::lsp::MarkupContent{ + llvm::lsp::MarkupKind::Markdown, std::move(*doc)}; } completionList.items.emplace_back(item); @@ -856,10 +874,10 @@ public: void codeCompleteDialectName() final { // Code complete known dialects. for (const ods::Dialect &dialect : odsContext.getDialects()) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = dialect.getName().str(); - item.kind = lsp::CompletionItemKind::Class; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Class; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -872,10 +890,10 @@ public: for (const auto &it : dialect->getOperations()) { const ods::Operation &op = *it.second; - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = op.getName().drop_front(dialectName.size() + 1).str(); - item.kind = lsp::CompletionItemKind::Field; - item.insertTextFormat = lsp::InsertTextFormat::PlainText; + item.kind = llvm::lsp::CompletionItemKind::Field; + item.insertTextFormat = llvm::lsp::InsertTextFormat::PlainText; completionList.items.emplace_back(item); } } @@ -883,16 +901,16 @@ public: void codeCompletePatternMetadata() final { auto addSimpleConstraint = [&](StringRef constraint, StringRef desc, StringRef snippetText = "") { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = constraint.str(); - item.kind = lsp::CompletionItemKind::Class; + item.kind = llvm::lsp::CompletionItemKind::Class; item.detail = "pattern metadata"; item.documentation = - lsp::MarkupContent{lsp::MarkupKind::Markdown, desc.str()}; + llvm::lsp::MarkupContent{llvm::lsp::MarkupKind::Markdown, desc.str()}; item.insertText = snippetText.str(); item.insertTextFormat = snippetText.empty() - ? lsp::InsertTextFormat::PlainText - : lsp::InsertTextFormat::Snippet; + ? llvm::lsp::InsertTextFormat::PlainText + : llvm::lsp::InsertTextFormat::Snippet; completionList.items.emplace_back(item); }; @@ -913,10 +931,10 @@ public: // Functor used to add a single include completion item. auto addIncludeCompletion = [&](StringRef path, bool isDirectory) { - lsp::CompletionItem item; + llvm::lsp::CompletionItem item; item.label = path.str(); - item.kind = isDirectory ? lsp::CompletionItemKind::Folder - : lsp::CompletionItemKind::File; + item.kind = isDirectory ? llvm::lsp::CompletionItemKind::Folder + : llvm::lsp::CompletionItemKind::File; if (seenResults.insert(item.label).second) completionList.items.emplace_back(item); }; @@ -961,31 +979,31 @@ public: // Sort the completion results to make sure the output is deterministic in // the face of different iteration schemes for different platforms. - llvm::sort(completionList.items, [](const lsp::CompletionItem &lhs, - const lsp::CompletionItem &rhs) { + llvm::sort(completionList.items, [](const llvm::lsp::CompletionItem &lhs, + const llvm::lsp::CompletionItem &rhs) { return lhs.label < rhs.label; }); } private: llvm::SourceMgr &sourceMgr; - lsp::CompletionList &completionList; + llvm::lsp::CompletionList &completionList; ods::Context &odsContext; ArrayRef includeDirs; }; } // namespace -lsp::CompletionList -PDLDocument::getCodeCompletion(const lsp::URIForFile &uri, - const lsp::Position &completePos) { +llvm::lsp::CompletionList +PDLDocument::getCodeCompletion(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &completePos) { SMLoc posLoc = completePos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::CompletionList(); + return llvm::lsp::CompletionList(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::CompletionList completionList; + llvm::lsp::CompletionList completionList; LSPCodeCompleteContext lspCompleteContext(posLoc, sourceMgr, completionList, tmpODSContext, sourceMgr.getIncludeDirs()); @@ -1005,7 +1023,7 @@ namespace { class LSPSignatureHelpContext : public CodeCompleteContext { public: LSPSignatureHelpContext(SMLoc completeLoc, llvm::SourceMgr &sourceMgr, - lsp::SignatureHelp &signatureHelp, + llvm::lsp::SignatureHelp &signatureHelp, ods::Context &odsContext) : CodeCompleteContext(completeLoc), sourceMgr(sourceMgr), signatureHelp(signatureHelp), odsContext(odsContext) {} @@ -1014,7 +1032,7 @@ public: unsigned currentNumArgs) final { signatureHelp.activeParameter = currentNumArgs; - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; { llvm::raw_string_ostream strOS(signatureInfo.label); strOS << callable->getName()->getName() << "("; @@ -1022,7 +1040,7 @@ public: unsigned paramStart = strOS.str().size(); strOS << var->getName().getName() << ": " << var->getType(); unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), /*paramDoc*/ std::string()}); }; @@ -1070,7 +1088,7 @@ public: // not more than what is defined in ODS, as this will result in an error // anyways. if (odsOp && currentValue < values.size()) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; // Build the signature label. { @@ -1099,7 +1117,7 @@ public: } unsigned paramEnd = strOS.str().size(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(strOS.str()).slice(paramStart, paramEnd).str(), std::make_pair(paramStart, paramEnd), paramDoc}); }; @@ -1114,12 +1132,12 @@ public: // If there aren't any arguments yet, we also add the generic signature. if (currentValue == 0 && (!odsOp || !values.empty())) { - lsp::SignatureInformation signatureInfo; + llvm::lsp::SignatureInformation signatureInfo; signatureInfo.label = llvm::formatv("(<{0}s>: {1}Range)", label, dataType).str(); signatureInfo.documentation = ("Generic operation " + label + " specification").str(); - signatureInfo.parameters.emplace_back(lsp::ParameterInformation{ + signatureInfo.parameters.emplace_back(llvm::lsp::ParameterInformation{ StringRef(signatureInfo.label).drop_front().drop_back().str(), std::pair(1, signatureInfo.label.size() - 1), ("All of the " + label + "s of the operation.").str()}); @@ -1129,21 +1147,22 @@ public: private: llvm::SourceMgr &sourceMgr; - lsp::SignatureHelp &signatureHelp; + llvm::lsp::SignatureHelp &signatureHelp; ods::Context &odsContext; }; } // namespace -lsp::SignatureHelp PDLDocument::getSignatureHelp(const lsp::URIForFile &uri, - const lsp::Position &helpPos) { +llvm::lsp::SignatureHelp +PDLDocument::getSignatureHelp(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &helpPos) { SMLoc posLoc = helpPos.getAsSMLoc(sourceMgr); if (!posLoc.isValid()) - return lsp::SignatureHelp(); + return llvm::lsp::SignatureHelp(); // To perform code completion, we run another parse of the module with the // code completion context provided. ods::Context tmpODSContext; - lsp::SignatureHelp signatureHelp; + llvm::lsp::SignatureHelp signatureHelp; LSPSignatureHelpContext completeContext(posLoc, sourceMgr, signatureHelp, tmpODSContext); @@ -1173,9 +1192,9 @@ static bool shouldAddHintFor(const ast::Expr *expr, StringRef name) { return true; } -void PDLDocument::getInlayHints(const lsp::URIForFile &uri, - const lsp::Range &range, - std::vector &inlayHints) { +void PDLDocument::getInlayHints(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Range &range, + std::vector &inlayHints) { if (failed(astModule)) return; SMRange rangeLoc = range.getAsSMRange(sourceMgr); @@ -1198,9 +1217,9 @@ void PDLDocument::getInlayHints(const lsp::URIForFile &uri, }); } -void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::VariableDecl *decl, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Check to see if the variable has a constraint list, if it does we don't // provide initializer hints. if (!decl->getConstraints().empty()) @@ -1215,8 +1234,8 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, return; } - lsp::InlayHint hint(lsp::InlayHintKind::Type, - lsp::Position(sourceMgr, decl->getLoc().End)); + llvm::lsp::InlayHint hint(llvm::lsp::InlayHintKind::Type, + llvm::lsp::Position(sourceMgr, decl->getLoc().End)); { llvm::raw_string_ostream labelOS(hint.label); labelOS << ": " << decl->getType(); @@ -1225,9 +1244,9 @@ void PDLDocument::getInlayHintsFor(const ast::VariableDecl *decl, inlayHints.emplace_back(std::move(hint)); } -void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::CallExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Try to extract the callable of this call. const auto *callableRef = dyn_cast(expr->getCallableExpr()); const auto *callable = @@ -1242,9 +1261,9 @@ void PDLDocument::getInlayHintsFor(const ast::CallExpr *expr, std::get<1>(it)->getName().getName()); } -void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, - const lsp::URIForFile &uri, - std::vector &inlayHints) { +void PDLDocument::getInlayHintsFor( + const ast::OperationExpr *expr, const llvm::lsp::URIForFile &uri, + std::vector &inlayHints) { // Check for ODS information. ast::OperationType opType = dyn_cast(expr->getType()); const auto *odsOp = opType ? opType.getODSOperation() : nullptr; @@ -1290,13 +1309,15 @@ void PDLDocument::getInlayHintsFor(const ast::OperationExpr *expr, "results"); } -void PDLDocument::addParameterHintFor(std::vector &inlayHints, - const ast::Expr *expr, StringRef label) { +void PDLDocument::addParameterHintFor( + std::vector &inlayHints, const ast::Expr *expr, + StringRef label) { if (!shouldAddHintFor(expr, label)) return; - lsp::InlayHint hint(lsp::InlayHintKind::Parameter, - lsp::Position(sourceMgr, expr->getLoc().Start)); + llvm::lsp::InlayHint hint( + llvm::lsp::InlayHintKind::Parameter, + llvm::lsp::Position(sourceMgr, expr->getLoc().Start)); hint.label = (label + ":").str(); hint.paddingRight = true; inlayHints.emplace_back(std::move(hint)); @@ -1342,22 +1363,24 @@ void PDLDocument::getPDLLViewOutput(raw_ostream &os, namespace { /// This class represents a single chunk of an PDL text file. struct PDLTextFileChunk { - PDLTextFileChunk(uint64_t lineOffset, const lsp::URIForFile &uri, + PDLTextFileChunk(uint64_t lineOffset, const llvm::lsp::URIForFile &uri, StringRef contents, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : lineOffset(lineOffset), document(uri, contents, extraDirs, diagnostics) {} /// Adjust the line number of the given range to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Range &range) { + void adjustLocForChunkOffset(llvm::lsp::Range &range) { adjustLocForChunkOffset(range.start); adjustLocForChunkOffset(range.end); } /// Adjust the line number of the given position to anchor at the beginning of /// the file, instead of the beginning of this chunk. - void adjustLocForChunkOffset(lsp::Position &pos) { pos.line += lineOffset; } + void adjustLocForChunkOffset(llvm::lsp::Position &pos) { + pos.line += lineOffset; + } /// The line offset of this chunk from the beginning of the file. uint64_t lineOffset; @@ -1374,38 +1397,41 @@ namespace { /// This class represents a text file containing one or more PDL documents. class PDLTextFile { public: - PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, + PDLTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraDirs, - std::vector &diagnostics); + std::vector &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics); //===--------------------------------------------------------------------===// // LSP Queries //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, lsp::Position defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, lsp::Position pos, - std::vector &references); - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); - std::optional findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos); - void findDocumentSymbols(std::vector &symbols); - lsp::CompletionList getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos); - lsp::SignatureHelp getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos); - void getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector &inlayHints); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position pos, + std::vector &references); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); + std::optional findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos); + void findDocumentSymbols(std::vector &symbols); + llvm::lsp::CompletionList getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos); + llvm::lsp::SignatureHelp getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos); + void getInlayHints(const llvm::lsp::URIForFile &uri, llvm::lsp::Range range, + std::vector &inlayHints); lsp::PDLLViewOutputResult getPDLLViewOutput(lsp::PDLLViewOutputKind kind); private: @@ -1413,14 +1439,14 @@ private: std::vector>::iterator>; /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics); /// Find the PDL document that contains the given position, and update the /// position to be anchored at the start of the found chunk instead of the /// beginning of the file. - ChunkIterator getChunkItFor(lsp::Position &pos); - PDLTextFileChunk &getChunkFor(lsp::Position &pos) { + ChunkIterator getChunkItFor(llvm::lsp::Position &pos); + PDLTextFileChunk &getChunkFor(llvm::lsp::Position &pos) { return *getChunkItFor(pos); } @@ -1442,20 +1468,21 @@ private: }; } // namespace -PDLTextFile::PDLTextFile(const lsp::URIForFile &uri, StringRef fileContents, - int64_t version, +PDLTextFile::PDLTextFile(const llvm::lsp::URIForFile &uri, + StringRef fileContents, int64_t version, const std::vector &extraDirs, - std::vector &diagnostics) + std::vector &diagnostics) : contents(fileContents.str()), extraIncludeDirs(extraDirs) { initialize(uri, version, diagnostics); } LogicalResult -PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +PDLTextFile::update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -1464,36 +1491,37 @@ PDLTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void PDLTextFile::getLocationsOf(const lsp::URIForFile &uri, - lsp::Position defPos, - std::vector &locations) { +void PDLTextFile::getLocationsOf(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position defPos, + std::vector &locations) { PDLTextFileChunk &chunk = getChunkFor(defPos); chunk.document.getLocationsOf(uri, defPos, locations); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : locations) + for (llvm::lsp::Location &loc : locations) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::findReferencesOf(const lsp::URIForFile &uri, - lsp::Position pos, - std::vector &references) { +void PDLTextFile::findReferencesOf( + const llvm::lsp::URIForFile &uri, llvm::lsp::Position pos, + std::vector &references) { PDLTextFileChunk &chunk = getChunkFor(pos); chunk.document.findReferencesOf(uri, pos, references); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset == 0) return; - for (lsp::Location &loc : references) + for (llvm::lsp::Location &loc : references) if (loc.uri == uri) chunk.adjustLocForChunkOffset(loc.range); } -void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void PDLTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { chunks.front()->document.getDocumentLinks(uri, links); for (const auto &it : llvm::drop_begin(chunks)) { size_t currentNumLinks = links.size(); @@ -1506,10 +1534,12 @@ void PDLTextFile::getDocumentLinks(const lsp::URIForFile &uri, } } -std::optional PDLTextFile::findHover(const lsp::URIForFile &uri, - lsp::Position hoverPos) { +std::optional +PDLTextFile::findHover(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position hoverPos) { PDLTextFileChunk &chunk = getChunkFor(hoverPos); - std::optional hoverInfo = chunk.document.findHover(uri, hoverPos); + std::optional hoverInfo = + chunk.document.findHover(uri, hoverPos); // Adjust any locations within this file for the offset of this chunk. if (chunk.lineOffset != 0 && hoverInfo && hoverInfo->range) @@ -1518,7 +1548,7 @@ std::optional PDLTextFile::findHover(const lsp::URIForFile &uri, } void PDLTextFile::findDocumentSymbols( - std::vector &symbols) { + std::vector &symbols) { if (chunks.size() == 1) return chunks.front()->document.findDocumentSymbols(symbols); @@ -1526,27 +1556,27 @@ void PDLTextFile::findDocumentSymbols( // each chunk. for (unsigned i = 0, e = chunks.size(); i < e; ++i) { PDLTextFileChunk &chunk = *chunks[i]; - lsp::Position startPos(chunk.lineOffset); - lsp::Position endPos((i == e - 1) ? totalNumLines - 1 - : chunks[i + 1]->lineOffset); - lsp::DocumentSymbol symbol("", - lsp::SymbolKind::Namespace, - /*range=*/lsp::Range(startPos, endPos), - /*selectionRange=*/lsp::Range(startPos)); + llvm::lsp::Position startPos(chunk.lineOffset); + llvm::lsp::Position endPos((i == e - 1) ? totalNumLines - 1 + : chunks[i + 1]->lineOffset); + llvm::lsp::DocumentSymbol symbol( + "", llvm::lsp::SymbolKind::Namespace, + /*range=*/llvm::lsp::Range(startPos, endPos), + /*selectionRange=*/llvm::lsp::Range(startPos)); chunk.document.findDocumentSymbols(symbol.children); // Fixup the locations of document symbols within this chunk. if (i != 0) { - SmallVector symbolsToFix; - for (lsp::DocumentSymbol &childSymbol : symbol.children) + SmallVector symbolsToFix; + for (llvm::lsp::DocumentSymbol &childSymbol : symbol.children) symbolsToFix.push_back(&childSymbol); while (!symbolsToFix.empty()) { - lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); + llvm::lsp::DocumentSymbol *symbol = symbolsToFix.pop_back_val(); chunk.adjustLocForChunkOffset(symbol->range); chunk.adjustLocForChunkOffset(symbol->selectionRange); - for (lsp::DocumentSymbol &childSymbol : symbol->children) + for (llvm::lsp::DocumentSymbol &childSymbol : symbol->children) symbolsToFix.push_back(&childSymbol); } } @@ -1556,34 +1586,37 @@ void PDLTextFile::findDocumentSymbols( } } -lsp::CompletionList PDLTextFile::getCodeCompletion(const lsp::URIForFile &uri, - lsp::Position completePos) { +llvm::lsp::CompletionList +PDLTextFile::getCodeCompletion(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position completePos) { PDLTextFileChunk &chunk = getChunkFor(completePos); - lsp::CompletionList completionList = + llvm::lsp::CompletionList completionList = chunk.document.getCodeCompletion(uri, completePos); // Adjust any completion locations. - for (lsp::CompletionItem &item : completionList.items) { + for (llvm::lsp::CompletionItem &item : completionList.items) { if (item.textEdit) chunk.adjustLocForChunkOffset(item.textEdit->range); - for (lsp::TextEdit &edit : item.additionalTextEdits) + for (llvm::lsp::TextEdit &edit : item.additionalTextEdits) chunk.adjustLocForChunkOffset(edit.range); } return completionList; } -lsp::SignatureHelp PDLTextFile::getSignatureHelp(const lsp::URIForFile &uri, - lsp::Position helpPos) { +llvm::lsp::SignatureHelp +PDLTextFile::getSignatureHelp(const llvm::lsp::URIForFile &uri, + llvm::lsp::Position helpPos) { return getChunkFor(helpPos).document.getSignatureHelp(uri, helpPos); } -void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, - std::vector &inlayHints) { +void PDLTextFile::getInlayHints(const llvm::lsp::URIForFile &uri, + llvm::lsp::Range range, + std::vector &inlayHints) { auto startIt = getChunkItFor(range.start); auto endIt = getChunkItFor(range.end); // Functor used to get the chunks for a given file, and fixup any locations - auto getHintsForChunk = [&](ChunkIterator chunkIt, lsp::Range range) { + auto getHintsForChunk = [&](ChunkIterator chunkIt, llvm::lsp::Range range) { size_t currentNumHints = inlayHints.size(); chunkIt->document.getInlayHints(uri, range, inlayHints); @@ -1605,15 +1638,16 @@ void PDLTextFile::getInlayHints(const lsp::URIForFile &uri, lsp::Range range, // Otherwise, the range is split between multiple chunks. The first chunk // has the correct range start, but covers the total document. - getHintsForChunk(startIt, lsp::Range(range.start, getNumLines(startIt))); + getHintsForChunk(startIt, + llvm::lsp::Range(range.start, getNumLines(startIt))); // Every chunk in between uses the full document. for (++startIt; startIt != endIt; ++startIt) - getHintsForChunk(startIt, lsp::Range(0, getNumLines(startIt))); + getHintsForChunk(startIt, llvm::lsp::Range(0, getNumLines(startIt))); // The range for the last chunk starts at the beginning of the document, up // through the end of the input range. - getHintsForChunk(startIt, lsp::Range(0, range.end)); + getHintsForChunk(startIt, llvm::lsp::Range(0, range.end)); } lsp::PDLLViewOutputResult @@ -1632,8 +1666,9 @@ PDLTextFile::getPDLLViewOutput(lsp::PDLLViewOutputKind kind) { return result; } -void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics) { +void PDLTextFile::initialize(const llvm::lsp::URIForFile &uri, + int64_t newVersion, + std::vector &diagnostics) { version = newVersion; chunks.clear(); @@ -1653,7 +1688,7 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, // Adjust locations used in diagnostics to account for the offset from the // beginning of the file. - for (lsp::Diagnostic &diag : + for (llvm::lsp::Diagnostic &diag : llvm::drop_begin(diagnostics, currentNumDiags)) { chunk->adjustLocForChunkOffset(diag.range); @@ -1668,14 +1703,15 @@ void PDLTextFile::initialize(const lsp::URIForFile &uri, int64_t newVersion, totalNumLines = lineOffset; } -PDLTextFile::ChunkIterator PDLTextFile::getChunkItFor(lsp::Position &pos) { +PDLTextFile::ChunkIterator +PDLTextFile::getChunkItFor(llvm::lsp::Position &pos) { if (chunks.size() == 1) return chunks.begin(); // Search for the first chunk with a greater line offset, the previous chunk // is the one that contains `pos`. auto it = llvm::upper_bound( - chunks, pos, [](const lsp::Position &pos, const auto &chunk) { + chunks, pos, [](const llvm::lsp::Position &pos, const auto &chunk) { return static_cast(pos.line) < chunk->lineOffset; }); ChunkIterator chunkIt(it == chunks.end() ? (chunks.end() - 1) : --it); @@ -1710,9 +1746,9 @@ lsp::PDLLServer::PDLLServer(const Options &options) : impl(std::make_unique(options)) {} lsp::PDLLServer::~PDLLServer() = default; -void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, - int64_t version, - std::vector &diagnostics) { +void lsp::PDLLServer::addDocument( + const URIForFile &uri, StringRef contents, int64_t version, + std::vector &diagnostics) { // Build the set of additional include directories. std::vector additionalIncludeDirs = impl->options.extraDirs; const auto &fileInfo = impl->compilationDatabase.getFileInfo(uri.file()); @@ -1724,7 +1760,7 @@ void lsp::PDLLServer::addDocument(const URIForFile &uri, StringRef contents, void lsp::PDLLServer::updateDocument( const URIForFile &uri, ArrayRef changes, - int64_t version, std::vector &diagnostics) { + int64_t version, std::vector &diagnostics) { // Check that we actually have a document for this uri. auto it = impl->files.find(uri.file()); if (it == impl->files.end()) @@ -1746,17 +1782,17 @@ std::optional lsp::PDLLServer::removeDocument(const URIForFile &uri) { return version; } -void lsp::PDLLServer::getLocationsOf(const URIForFile &uri, - const Position &defPos, - std::vector &locations) { +void lsp::PDLLServer::getLocationsOf( + const URIForFile &uri, const Position &defPos, + std::vector &locations) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->getLocationsOf(uri, defPos, locations); } -void lsp::PDLLServer::findReferencesOf(const URIForFile &uri, - const Position &pos, - std::vector &references) { +void lsp::PDLLServer::findReferencesOf( + const URIForFile &uri, const Position &pos, + std::vector &references) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) fileIt->second->findReferencesOf(uri, pos, references); @@ -1769,8 +1805,8 @@ void lsp::PDLLServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional lsp::PDLLServer::findHover(const URIForFile &uri, - const Position &hoverPos) { +std::optional +lsp::PDLLServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->findHover(uri, hoverPos); @@ -1793,8 +1829,9 @@ lsp::PDLLServer::getCodeCompletion(const URIForFile &uri, return CompletionList(); } -lsp::SignatureHelp lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, - const Position &helpPos) { +llvm::lsp::SignatureHelp +lsp::PDLLServer::getSignatureHelp(const URIForFile &uri, + const Position &helpPos) { auto fileIt = impl->files.find(uri.file()); if (fileIt != impl->files.end()) return fileIt->second->getSignatureHelp(uri, helpPos); diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h index 134431fa63bf..d82014d6b068 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/PDLLServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include #include #include @@ -18,21 +19,22 @@ namespace mlir { namespace lsp { -struct Diagnostic; +using llvm::lsp::CompletionList; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::DocumentSymbol; +using llvm::lsp::Hover; +using llvm::lsp::InlayHint; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::Range; +using llvm::lsp::SignatureHelp; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; + class CompilationDatabase; struct PDLLViewOutputResult; enum class PDLLViewOutputKind; -struct CompletionList; -struct DocumentLink; -struct DocumentSymbol; -struct Hover; -struct InlayHint; -struct Location; -struct Position; -struct Range; -struct SignatureHelp; -struct TextDocumentContentChangeEvent; -class URIForFile; /// This class implements all of the PDLL related functionality necessary for a /// language server. This class allows for keeping the PDLL specific logic diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp index 0c9896e3ec1b..ace460536aa1 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "Protocol.h" +#include "mlir/Support/LLVM.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" diff --git a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h index 070631663185..a2775f8cbadc 100644 --- a/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h +++ b/mlir/lib/Tools/mlir-pdll-lsp-server/Protocol.h @@ -20,10 +20,12 @@ #ifndef LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ #define LIB_MLIR_TOOLS_MLIRPDLLLSPSERVER_PROTOCOL_H_ -#include "mlir/Tools/lsp-server-support/Protocol.h" +#include "llvm/Support/LSP/Protocol.h" namespace mlir { namespace lsp { +using llvm::lsp::URIForFile; + //===----------------------------------------------------------------------===// // PDLLViewOutputParams //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt index 80fc1ffe4029..b21650ed03b6 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt +++ b/mlir/lib/Tools/tblgen-lsp-server/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_LINK_COMPONENTS Demangle Support TableGen + SupportLSP ) llvm_add_library(TableGenLspServerLib diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp index bb3c0a77747a..95a457f3144c 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.cpp @@ -9,14 +9,33 @@ #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" -#include "mlir/Tools/lsp-server-support/Transport.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" +#include "llvm/Support/LSP/Transport.h" #include using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::Callback; +using llvm::lsp::DidChangeTextDocumentParams; +using llvm::lsp::DidCloseTextDocumentParams; +using llvm::lsp::DidOpenTextDocumentParams; +using llvm::lsp::DocumentLinkParams; +using llvm::lsp::Hover; +using llvm::lsp::InitializedParams; +using llvm::lsp::InitializeParams; +using llvm::lsp::JSONTransport; +using llvm::lsp::Location; +using llvm::lsp::Logger; +using llvm::lsp::MessageHandler; +using llvm::lsp::NoParams; +using llvm::lsp::OutgoingNotification; +using llvm::lsp::PublishDiagnosticsParams; +using llvm::lsp::ReferenceParams; +using llvm::lsp::TextDocumentPositionParams; +using llvm::lsp::TextDocumentSyncKind; + //===----------------------------------------------------------------------===// // LSPServer //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h index 501a9dada8aa..596688b62f8d 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/LSPServer.h @@ -13,17 +13,19 @@ namespace llvm { struct LogicalResult; +namespace lsp { +class JSONTransport; +} // namespace lsp } // namespace llvm namespace mlir { namespace lsp { -class JSONTransport; class TableGenServer; /// Run the main loop of the LSP server using the given TableGen server and /// transport. llvm::LogicalResult runTableGenLSPServer(TableGenServer &server, - JSONTransport &transport); + llvm::lsp::JSONTransport &transport); } // namespace lsp } // namespace mlir diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp index 21af78c9a506..8014b8d6dba4 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenLspServerMain.cpp @@ -9,14 +9,18 @@ #include "mlir/Tools/tblgen-lsp-server/TableGenLspServerMain.h" #include "LSPServer.h" #include "TableGenServer.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Transport.h" #include "llvm/Support/Program.h" using namespace mlir; using namespace mlir::lsp; +using llvm::lsp::JSONStreamStyle; +using llvm::lsp::JSONTransport; +using llvm::lsp::Logger; + LogicalResult mlir::TableGenLspServerMain(int argc, char **argv) { llvm::cl::opt inputStyle{ "input-style", diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp index 5faeeae839f4..3080b78f187b 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.cpp @@ -10,12 +10,12 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/Tools/lsp-server-support/CompilationDatabase.h" -#include "mlir/Tools/lsp-server-support/Logging.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/lsp-server-support/SourceMgrUtils.h" #include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/LSP/Logging.h" +#include "llvm/Support/LSP/Protocol.h" #include "llvm/Support/Path.h" #include "llvm/TableGen/Parser.h" #include "llvm/TableGen/Record.h" @@ -36,45 +36,49 @@ static SMRange convertTokenLocToRange(SMLoc loc) { /// Returns a language server uri for the given source location. `mainFileURI` /// corresponds to the uri for the main file of the source manager. -static lsp::URIForFile getURIFromLoc(const SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &mainFileURI) { +static llvm::lsp::URIForFile +getURIFromLoc(const SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &mainFileURI) { int bufferId = mgr.FindBufferContainingLoc(loc); if (bufferId == 0 || bufferId == static_cast(mgr.getMainFileID())) return mainFileURI; - llvm::Expected fileForLoc = lsp::URIForFile::fromFile( - mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); + llvm::Expected fileForLoc = + llvm::lsp::URIForFile::fromFile( + mgr.getBufferInfo(bufferId).Buffer->getBufferIdentifier()); if (fileForLoc) return *fileForLoc; - lsp::Logger::error("Failed to create URI for include file: {0}", - llvm::toString(fileForLoc.takeError())); + llvm::lsp::Logger::error("Failed to create URI for include file: {0}", + llvm::toString(fileForLoc.takeError())); return mainFileURI; } /// Returns a language server location from the given source range. -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMRange loc, - const lsp::URIForFile &uri) { - return lsp::Location(getURIFromLoc(mgr, loc.Start, uri), - lsp::Range(mgr, loc)); +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMRange loc, + const llvm::lsp::URIForFile &uri) { + return llvm::lsp::Location(getURIFromLoc(mgr, loc.Start, uri), + llvm::lsp::Range(mgr, loc)); } -static lsp::Location getLocationFromLoc(SourceMgr &mgr, SMLoc loc, - const lsp::URIForFile &uri) { +static llvm::lsp::Location +getLocationFromLoc(SourceMgr &mgr, SMLoc loc, + const llvm::lsp::URIForFile &uri) { return getLocationFromLoc(mgr, convertTokenLocToRange(loc), uri); } /// Convert the given TableGen diagnostic to the LSP form. -static std::optional +static std::optional getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, - const lsp::URIForFile &uri) { + const llvm::lsp::URIForFile &uri) { auto *sourceMgr = const_cast(diag.getSourceMgr()); if (!sourceMgr || !diag.getLoc().isValid()) return std::nullopt; - lsp::Diagnostic lspDiag; + llvm::lsp::Diagnostic lspDiag; lspDiag.source = "tablegen"; lspDiag.category = "Parse Error"; // Try to grab a file location for this diagnostic. - lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); + llvm::lsp::Location loc = getLocationFromLoc(*sourceMgr, diag.getLoc(), uri); lspDiag.range = loc.range; // Skip diagnostics that weren't emitted within the main file. @@ -84,17 +88,17 @@ getLspDiagnoticFromDiag(const llvm::SMDiagnostic &diag, // Convert the severity for the diagnostic. switch (diag.getKind()) { case SourceMgr::DK_Warning: - lspDiag.severity = lsp::DiagnosticSeverity::Warning; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Warning; break; case SourceMgr::DK_Error: - lspDiag.severity = lsp::DiagnosticSeverity::Error; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Error; break; case SourceMgr::DK_Note: // Notes are emitted separately from the main diagnostic, so we just treat // them as remarks given that we can't determine the diagnostic to relate // them to. case SourceMgr::DK_Remark: - lspDiag.severity = lsp::DiagnosticSeverity::Information; + lspDiag.severity = llvm::lsp::DiagnosticSeverity::Information; break; } lspDiag.message = diag.getMessage().str(); @@ -322,54 +326,59 @@ namespace { /// This class represents a text file containing one or more TableGen documents. class TableGenTextFile { public: - TableGenTextFile(const lsp::URIForFile &uri, StringRef fileContents, + TableGenTextFile(const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraIncludeDirs, - std::vector &diagnostics); + std::vector &diagnostics); /// Return the current version of this text file. int64_t getVersion() const { return version; } /// Update the file to the new version using the provided set of content /// changes. Returns failure if the update was unsuccessful. - LogicalResult update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics); + LogicalResult + update(const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics); //===--------------------------------------------------------------------===// // Definitions and References //===--------------------------------------------------------------------===// - void getLocationsOf(const lsp::URIForFile &uri, const lsp::Position &defPos, - std::vector &locations); - void findReferencesOf(const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references); + void getLocationsOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &defPos, + std::vector &locations); + void findReferencesOf(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &pos, + std::vector &references); //===--------------------------------------------------------------------===// // Document Links //===--------------------------------------------------------------------===// - void getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links); + void getDocumentLinks(const llvm::lsp::URIForFile &uri, + std::vector &links); //===--------------------------------------------------------------------===// // Hover //===--------------------------------------------------------------------===// - std::optional findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos); - lsp::Hover buildHoverForRecord(const Record *record, - const SMRange &hoverRange); - lsp::Hover buildHoverForTemplateArg(const Record *record, + std::optional + findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos); + llvm::lsp::Hover buildHoverForRecord(const Record *record, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForTemplateArg(const Record *record, + const RecordVal *value, + const SMRange &hoverRange); + llvm::lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, const SMRange &hoverRange); - lsp::Hover buildHoverForField(const Record *record, const RecordVal *value, - const SMRange &hoverRange); private: /// Initialize the text file from the given file contents. - void initialize(const lsp::URIForFile &uri, int64_t newVersion, - std::vector &diagnostics); + void initialize(const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics); /// The full string contents of the file. std::string contents; @@ -395,9 +404,9 @@ private: } // namespace TableGenTextFile::TableGenTextFile( - const lsp::URIForFile &uri, StringRef fileContents, int64_t version, + const llvm::lsp::URIForFile &uri, StringRef fileContents, int64_t version, const std::vector &extraIncludeDirs, - std::vector &diagnostics) + std::vector &diagnostics) : contents(fileContents.str()), version(version) { // Build the set of include directories for this file. llvm::SmallString<32> uriDirectory(uri.file()); @@ -409,12 +418,13 @@ TableGenTextFile::TableGenTextFile( initialize(uri, version, diagnostics); } -LogicalResult -TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, - ArrayRef changes, - std::vector &diagnostics) { - if (failed(lsp::TextDocumentContentChangeEvent::applyTo(changes, contents))) { - lsp::Logger::error("Failed to update contents of {0}", uri.file()); +LogicalResult TableGenTextFile::update( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + ArrayRef changes, + std::vector &diagnostics) { + if (failed(llvm::lsp::TextDocumentContentChangeEvent::applyTo(changes, + contents))) { + llvm::lsp::Logger::error("Failed to update contents of {0}", uri.file()); return failure(); } @@ -423,9 +433,9 @@ TableGenTextFile::update(const lsp::URIForFile &uri, int64_t newVersion, return success(); } -void TableGenTextFile::initialize(const lsp::URIForFile &uri, - int64_t newVersion, - std::vector &diagnostics) { +void TableGenTextFile::initialize( + const llvm::lsp::URIForFile &uri, int64_t newVersion, + std::vector &diagnostics) { version = newVersion; sourceMgr = SourceMgr(); recordKeeper = std::make_unique(); @@ -433,7 +443,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // Build a buffer for this file. auto memBuffer = llvm::MemoryBuffer::getMemBuffer(contents, uri.file()); if (!memBuffer) { - lsp::Logger::error("Failed to create memory buffer for file", uri.file()); + llvm::lsp::Logger::error("Failed to create memory buffer for file", + uri.file()); return; } sourceMgr.setIncludeDirs(includeDirs); @@ -442,8 +453,8 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // This class provides a context argument for the SourceMgr diagnostic // handler. struct DiagHandlerContext { - std::vector &diagnostics; - const lsp::URIForFile &uri; + std::vector &diagnostics; + const llvm::lsp::URIForFile &uri; } handlerContext{diagnostics, uri}; // Set the diagnostic handler for the tablegen source manager. @@ -469,9 +480,9 @@ void TableGenTextFile::initialize(const lsp::URIForFile &uri, // TableGenTextFile: Definitions and References //===----------------------------------------------------------------------===// -void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, - const lsp::Position &defPos, - std::vector &locations) { +void TableGenTextFile::getLocationsOf( + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &defPos, + std::vector &locations) { SMLoc posLoc = defPos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -492,8 +503,8 @@ void TableGenTextFile::getLocationsOf(const lsp::URIForFile &uri, } void TableGenTextFile::findReferencesOf( - const lsp::URIForFile &uri, const lsp::Position &pos, - std::vector &references) { + const llvm::lsp::URIForFile &uri, const llvm::lsp::Position &pos, + std::vector &references) { SMLoc posLoc = pos.getAsSMLoc(sourceMgr); const TableGenIndexSymbol *symbol = index.lookup(posLoc); if (!symbol) @@ -508,8 +519,9 @@ void TableGenTextFile::findReferencesOf( // TableGenTextFile: Document Links //===--------------------------------------------------------------------===// -void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, - std::vector &links) { +void TableGenTextFile::getDocumentLinks( + const llvm::lsp::URIForFile &uri, + std::vector &links) { for (const lsp::SourceMgrInclude &include : parsedIncludes) links.emplace_back(include.range, include.uri); } @@ -518,9 +530,9 @@ void TableGenTextFile::getDocumentLinks(const lsp::URIForFile &uri, // TableGenTextFile: Hover //===----------------------------------------------------------------------===// -std::optional -TableGenTextFile::findHover(const lsp::URIForFile &uri, - const lsp::Position &hoverPos) { +std::optional +TableGenTextFile::findHover(const llvm::lsp::URIForFile &uri, + const llvm::lsp::Position &hoverPos) { // Check for a reference to an include. for (const lsp::SourceMgrInclude &include : parsedIncludes) if (include.range.contains(hoverPos)) @@ -546,9 +558,10 @@ TableGenTextFile::findHover(const lsp::URIForFile &uri, return buildHoverForField(recordVal->record, value, hoverRange); } -lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover +TableGenTextFile::buildHoverForRecord(const Record *record, + const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); @@ -590,9 +603,9 @@ lsp::Hover TableGenTextFile::buildHoverForRecord(const Record *record, return hover; } -lsp::Hover TableGenTextFile::buildHoverForTemplateArg( +llvm::lsp::Hover TableGenTextFile::buildHoverForTemplateArg( const Record *record, const RecordVal *value, const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); StringRef name = value->getName().rsplit(':').second; @@ -604,10 +617,9 @@ lsp::Hover TableGenTextFile::buildHoverForTemplateArg( return hover; } -lsp::Hover TableGenTextFile::buildHoverForField(const Record *record, - const RecordVal *value, - const SMRange &hoverRange) { - lsp::Hover hover(lsp::Range(sourceMgr, hoverRange)); +llvm::lsp::Hover TableGenTextFile::buildHoverForField( + const Record *record, const RecordVal *value, const SMRange &hoverRange) { + llvm::lsp::Hover hover(llvm::lsp::Range(sourceMgr, hoverRange)); { llvm::raw_string_ostream hoverOS(hover.contents.value); hoverOS << "**field** `" << value->getName() << "`\n***\nType: `"; @@ -722,7 +734,7 @@ void lsp::TableGenServer::getDocumentLinks( return fileIt->second->getDocumentLinks(uri, documentLinks); } -std::optional +std::optional lsp::TableGenServer::findHover(const URIForFile &uri, const Position &hoverPos) { auto fileIt = impl->files.find(uri.file()); diff --git a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h index bdc851024a81..e54b8bcf35e2 100644 --- a/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h +++ b/mlir/lib/Tools/tblgen-lsp-server/TableGenServer.h @@ -11,6 +11,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LSP/Protocol.h" #include #include #include @@ -18,13 +19,13 @@ namespace mlir { namespace lsp { -struct Diagnostic; -struct DocumentLink; -struct Hover; -struct Location; -struct Position; -struct TextDocumentContentChangeEvent; -class URIForFile; +using llvm::lsp::Diagnostic; +using llvm::lsp::DocumentLink; +using llvm::lsp::Hover; +using llvm::lsp::Location; +using llvm::lsp::Position; +using llvm::lsp::TextDocumentContentChangeEvent; +using llvm::lsp::URIForFile; /// This class implements all of the TableGen related functionality necessary /// for a language server. This class allows for keeping the TableGen specific diff --git a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp index 10d602fdfe72..712237bbbbca 100644 --- a/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp +++ b/mlir/tools/mlir-lsp-server/mlir-lsp-server.cpp @@ -10,8 +10,8 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" -#include "mlir/Tools/lsp-server-support/Protocol.h" #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" +#include "llvm/Support/LSP/Protocol.h" using namespace mlir; @@ -37,8 +37,8 @@ int main(int argc, char **argv) { // Returns the registry, except in testing mode when the URI contains // "-disable-lsp-registration". Testing for/example of registering dialects // based on URI. - auto registryFn = [®istry, - &empty](const lsp::URIForFile &uri) -> DialectRegistry & { + auto registryFn = [®istry, &empty]( + const llvm::lsp::URIForFile &uri) -> DialectRegistry & { (void)empty; #ifdef MLIR_INCLUDE_TESTS if (uri.uri().contains("-disable-lsp-registration")) diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index c5f0d7e384d0..89332bce5fe0 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -18,7 +18,6 @@ add_subdirectory(Support) add_subdirectory(Rewrite) add_subdirectory(TableGen) add_subdirectory(Target) -add_subdirectory(Tools) add_subdirectory(Transforms) if(MLIR_ENABLE_EXECUTION_ENGINE) diff --git a/mlir/unittests/Tools/CMakeLists.txt b/mlir/unittests/Tools/CMakeLists.txt deleted file mode 100644 index a97588d92866..000000000000 --- a/mlir/unittests/Tools/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(lsp-server-support) diff --git a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt b/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt deleted file mode 100644 index c539c9bc5101..000000000000 --- a/mlir/unittests/Tools/lsp-server-support/CMakeLists.txt +++ /dev/null @@ -1,7 +0,0 @@ -add_mlir_unittest(MLIRLspServerSupportTests - Protocol.cpp - Transport.cpp -) -mlir_target_link_libraries(MLIRLspServerSupportTests - PRIVATE - MLIRLspServerSupportLib)