diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 2ac05880de86b..2b9e919329752 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,6 +9,8 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" #include "lldb/Protocol/MCP/Tool.h" @@ -18,31 +20,57 @@ namespace lldb_protocol::mcp { -class Server { +class MCPTransport + : public lldb_private::JSONRPCTransport { public: - Server(std::string name, std::string version); - virtual ~Server() = default; + using LogCallback = std::function; + + MCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out, + std::string client_name, LogCallback log_callback = {}) + : JSONRPCTransport(in, out), m_client_name(std::move(client_name)), + m_log_callback(log_callback) {} + virtual ~MCPTransport() = default; + + void Log(llvm::StringRef message) override { + if (m_log_callback) + m_log_callback(llvm::formatv("{0}: {1}", m_client_name, message).str()); + } + +private: + std::string m_client_name; + LogCallback m_log_callback; +}; + +class Server : public MCPTransport::MessageHandler { +public: + Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop); + ~Server() = default; + + using NotificationHandler = std::function; void AddTool(std::unique_ptr tool); void AddResourceProvider(std::unique_ptr resource_provider); + void AddNotificationHandler(llvm::StringRef method, + NotificationHandler handler); + + llvm::Error Run(); protected: - virtual Capabilities GetCapabilities() = 0; + Capabilities GetCapabilities(); using RequestHandler = std::function(const Request &)>; - using NotificationHandler = std::function; void AddRequestHandlers(); void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); llvm::Expected> HandleData(llvm::StringRef data); - llvm::Expected Handle(Request request); - void Handle(Notification notification); + llvm::Expected Handle(const Request &request); + void Handle(const Notification ¬ification); llvm::Expected InitializeHandler(const Request &); @@ -52,12 +80,21 @@ class Server { llvm::Expected ResourcesListHandler(const Request &); llvm::Expected ResourcesReadHandler(const Request &); - std::mutex m_mutex; + void Received(const Request &) override; + void Received(const Response &) override; + void Received(const Notification &) override; + void OnError(llvm::Error) override; + void OnClosed() override; + + void TerminateLoop(); private: const std::string m_name; const std::string m_version; + std::unique_ptr m_transport_up; + lldb_private::MainLoop &m_loop; + llvm::StringMap> m_tools; std::vector> m_resource_providers; diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c359663239dcc..57132534cf680 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -26,24 +26,10 @@ using namespace llvm; LLDB_PLUGIN_DEFINE(ProtocolServerMCP) -static constexpr size_t kChunkSize = 1024; static constexpr llvm::StringLiteral kName = "lldb-mcp"; static constexpr llvm::StringLiteral kVersion = "0.1.0"; -ProtocolServerMCP::ProtocolServerMCP() - : ProtocolServer(), - lldb_protocol::mcp::Server(std::string(kName), std::string(kVersion)) { - AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); - - AddTool( - std::make_unique("lldb_command", "Run an lldb command.")); - - AddResourceProvider(std::make_unique()); -} +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {} ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -64,57 +50,37 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { return "MCP Server."; } +void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { + server.AddNotificationHandler("notifications/initialized", + [](const lldb_protocol::mcp::Notification &) { + LLDB_LOG(GetLog(LLDBLog::Host), + "MCP initialization complete"); + }); + server.AddTool( + std::make_unique("lldb_command", "Run an lldb command.")); + server.AddResourceProvider(std::make_unique()); +} + void ProtocolServerMCP::AcceptCallback(std::unique_ptr socket) { - LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", - m_clients.size() + 1); + Log *log = GetLog(LLDBLog::Host); + std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); - auto client_up = std::make_unique(); - client_up->io_sp = io_sp; - Client *client = client_up.get(); - - Status status; - auto read_handle_up = m_loop.RegisterReadObject( - io_sp, - [this, client](MainLoopBase &loop) { - if (llvm::Error error = ReadCallback(*client)) { - LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); - client->read_handle_up.reset(); - } - }, - status); - if (status.Fail()) + auto transport_up = std::make_unique( + io_sp, io_sp, std::move(client_name), [&](llvm::StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); + }); + auto instance_up = std::make_unique( + std::string(kName), std::string(kVersion), std::move(transport_up), + m_loop); + Extend(*instance_up); + llvm::Error error = instance_up->Run(); + if (error) { + LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); return; - - client_up->read_handle_up = std::move(read_handle_up); - m_clients.emplace_back(std::move(client_up)); -} - -llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { - char chunk[kChunkSize]; - size_t bytes_read = sizeof(chunk); - if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) - return status.takeError(); - client.buffer.append(chunk, bytes_read); - - for (std::string::size_type pos; - (pos = client.buffer.find('\n')) != std::string::npos;) { - llvm::Expected> message = - HandleData(StringRef(client.buffer.data(), pos)); - client.buffer = client.buffer.erase(0, pos + 1); - if (!message) - return message.takeError(); - - if (*message) { - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; - size_t num_bytes = Output.size(); - return client.io_sp->Write(Output.data(), num_bytes).takeError(); - } } - - return llvm::Error::success(); + m_instances.push_back(std::move(instance_up)); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -158,27 +124,11 @@ llvm::Error ProtocolServerMCP::Stop() { // Stop the main loop. m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); // Wait for the main loop to exit. if (m_loop_thread.joinable()) m_loop_thread.join(); - { - std::lock_guard guard(m_mutex); - m_listener.reset(); - m_listen_handlers.clear(); - m_clients.clear(); - } - return llvm::Error::success(); } - -lldb_protocol::mcp::Capabilities ProtocolServerMCP::GetCapabilities() { - lldb_protocol::mcp::Capabilities capabilities; - capabilities.tools.listChanged = true; - // FIXME: Support sending notifications when a debugger/target are - // added/removed. - capabilities.resources.listChanged = false; - return capabilities; -} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 7fe909a728b85..fc650ffe0dfa7 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -18,8 +18,7 @@ namespace lldb_private::mcp { -class ProtocolServerMCP : public ProtocolServer, - public lldb_protocol::mcp::Server { +class ProtocolServerMCP : public ProtocolServer { public: ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; @@ -39,26 +38,24 @@ class ProtocolServerMCP : public ProtocolServer, Socket *GetSocket() const override { return m_listener.get(); } +protected: + // This adds tools and resource providers that + // are specific to this server. Overridable by the unit tests. + virtual void Extend(lldb_protocol::mcp::Server &server) const; + private: void AcceptCallback(std::unique_ptr socket); - lldb_protocol::mcp::Capabilities GetCapabilities() override; - bool m_running = false; - MainLoop m_loop; + lldb_private::MainLoop m_loop; std::thread m_loop_thread; + std::mutex m_mutex; std::unique_ptr m_listener; - std::vector m_listen_handlers; - struct Client { - lldb::IOObjectSP io_sp; - MainLoopBase::ReadHandleUP read_handle_up; - std::string buffer; - }; - llvm::Error ReadCallback(Client &client); - std::vector> m_clients; + std::vector m_listen_handlers; + std::vector> m_instances; }; } // namespace lldb_private::mcp diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a9c1482e3e378..c1a6026b11090 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -12,8 +12,11 @@ using namespace lldb_protocol::mcp; using namespace llvm; -Server::Server(std::string name, std::string version) - : m_name(std::move(name)), m_version(std::move(version)) { +Server::Server(std::string name, std::string version, + std::unique_ptr transport_up, + lldb_private::MainLoop &loop) + : m_name(std::move(name)), m_version(std::move(version)), + m_transport_up(std::move(transport_up)), m_loop(loop) { AddRequestHandlers(); } @@ -30,7 +33,7 @@ void Server::AddRequestHandlers() { this, std::placeholders::_1)); } -llvm::Expected Server::Handle(Request request) { +llvm::Expected Server::Handle(const Request &request) { auto it = m_request_handlers.find(request.method); if (it != m_request_handlers.end()) { llvm::Expected response = it->second(request); @@ -44,7 +47,7 @@ llvm::Expected Server::Handle(Request request) { llvm::formatv("no handler for request: {0}", request.method).str()); } -void Server::Handle(Notification notification) { +void Server::Handle(const Notification ¬ification) { auto it = m_notification_handlers.find(notification.method); if (it != m_notification_handlers.end()) { it->second(notification); @@ -52,49 +55,7 @@ void Server::Handle(Notification notification) { } } -llvm::Expected> -Server::HandleData(llvm::StringRef data) { - auto message = llvm::json::parse(/*JSON=*/data); - if (!message) - return message.takeError(); - - if (const Request *request = std::get_if(&(*message))) { - llvm::Expected response = Handle(*request); - - // Handle failures by converting them into an Error message. - if (!response) { - Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request->id; - error_response.result = std::move(protocol_error); - return error_response; - } - - return *response; - } - - if (const Notification *notification = - std::get_if(&(*message))) { - Handle(*notification); - return std::nullopt; - } - - if (std::get_if(&(*message))) - return llvm::createStringError("unexpected MCP message: response"); - - llvm_unreachable("all message types handled"); -} - void Server::AddTool(std::unique_ptr tool) { - std::lock_guard guard(m_mutex); - if (!tool) return; m_tools[tool->GetName()] = std::move(tool); @@ -102,21 +63,17 @@ void Server::AddTool(std::unique_ptr tool) { void Server::AddResourceProvider( std::unique_ptr resource_provider) { - std::lock_guard guard(m_mutex); - if (!resource_provider) return; m_resource_providers.push_back(std::move(resource_provider)); } void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - std::lock_guard guard(m_mutex); m_request_handlers[method] = std::move(handler); } void Server::AddNotificationHandler(llvm::StringRef method, NotificationHandler handler) { - std::lock_guard guard(m_mutex); m_notification_handlers[method] = std::move(handler); } @@ -182,7 +139,6 @@ llvm::Expected Server::ResourcesListHandler(const Request &request) { llvm::json::Array resources; - std::lock_guard guard(m_mutex); for (std::unique_ptr &resource_provider_up : m_resource_providers) { for (const Resource &resource : resource_provider_up->GetResources()) @@ -211,7 +167,6 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { if (uri_str.empty()) return llvm::createStringError("no resource uri"); - std::lock_guard guard(m_mutex); for (std::unique_ptr &resource_provider_up : m_resource_providers) { llvm::Expected result = @@ -232,3 +187,71 @@ llvm::Expected Server::ResourcesReadHandler(const Request &request) { llvm::formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } + +Capabilities Server::GetCapabilities() { + lldb_protocol::mcp::Capabilities capabilities; + capabilities.tools.listChanged = true; + // FIXME: Support sending notifications when a debugger/target are + // added/removed. + capabilities.resources.listChanged = false; + return capabilities; +} + +llvm::Error Server::Run() { + auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); + if (!handle) + return handle.takeError(); + + lldb_private::Status status = m_loop.Run(); + if (status.Fail()) + return status.takeError(); + + return llvm::Error::success(); +} + +void Server::Received(const Request &request) { + auto SendResponse = [this](const Response &response) { + if (llvm::Error error = m_transport_up->Send(response)) + m_transport_up->Log(llvm::toString(std::move(error))); + }; + + llvm::Expected response = Handle(request); + if (response) + return SendResponse(*response); + + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + response.takeError(), + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); + }); + Response error_response; + error_response.id = request.id; + error_response.result = std::move(protocol_error); + SendResponse(error_response); +} + +void Server::Received(const Response &response) { + m_transport_up->Log("unexpected MCP message: response"); +} + +void Server::Received(const Notification ¬ification) { + Handle(notification); +} + +void Server::OnError(llvm::Error error) { + m_transport_up->Log(llvm::toString(std::move(error))); + TerminateLoop(); +} + +void Server::OnClosed() { + m_transport_up->Log("EOF"); + TerminateLoop(); +} + +void Server::TerminateLoop() { + m_loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); +} diff --git a/lldb/unittests/CMakeLists.txt b/lldb/unittests/CMakeLists.txt index 5533c73c3de87..4c5267ae25b74 100644 --- a/lldb/unittests/CMakeLists.txt +++ b/lldb/unittests/CMakeLists.txt @@ -79,10 +79,6 @@ add_subdirectory(Utility) add_subdirectory(ValueObject) add_subdirectory(tools) -if(LLDB_ENABLE_PROTOCOL_SERVERS) - add_subdirectory(ProtocolServer) -endif() - if(LLDB_CAN_USE_DEBUGSERVER AND LLDB_TOOL_DEBUGSERVER_BUILD AND NOT LLDB_USE_SYSTEM_DEBUGSERVER) add_subdirectory(debugserver) endif() diff --git a/lldb/unittests/Protocol/CMakeLists.txt b/lldb/unittests/Protocol/CMakeLists.txt index bbac69611e011..f877517ea233d 100644 --- a/lldb/unittests/Protocol/CMakeLists.txt +++ b/lldb/unittests/Protocol/CMakeLists.txt @@ -1,5 +1,6 @@ add_lldb_unittest(ProtocolTests ProtocolMCPTest.cpp + ProtocolMCPServerTest.cpp LINK_LIBS lldbHost diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp similarity index 65% rename from lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp rename to lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 18112428950ce..b3fe22dbd38e5 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -1,4 +1,4 @@ -//===-- ProtocolServerMCPTest.cpp -----------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,21 +6,20 @@ // //===----------------------------------------------------------------------===// -#include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" -#include "Plugins/Protocol/MCP/ProtocolServerMCP.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" +#include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" -#include "lldb/Core/Debugger.h" -#include "lldb/Core/ProtocolServer.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" #include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/Host/Socket.h" -#include "lldb/Host/common/TCPSocket.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Resource.h" +#include "lldb/Protocol/MCP/Server.h" +#include "lldb/Protocol/MCP/Tool.h" #include "llvm/Support/Error.h" #include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" @@ -37,22 +36,12 @@ using namespace lldb_protocol::mcp; using testing::_; namespace { -class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { +class TestMCPTransport final : public MCPTransport { public: - using ProtocolServerMCP::AddNotificationHandler; - using ProtocolServerMCP::AddRequestHandler; - using ProtocolServerMCP::AddResourceProvider; - using ProtocolServerMCP::AddTool; - using ProtocolServerMCP::GetSocket; - using ProtocolServerMCP::ProtocolServerMCP; -}; + TestMCPTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : lldb_protocol::mcp::MCPTransport(in, out, "unittest") {} -using Message = typename Transport::Message; - -class TestJSONTransport final - : public lldb_private::JSONRPCTransport { -public: - using JSONRPCTransport::JSONRPCTransport; + using MCPTransport::Write; void Log(llvm::StringRef message) override { log_messages.emplace_back(message); @@ -61,6 +50,11 @@ class TestJSONTransport final std::vector log_messages; }; +class TestServer : public Server { +public: + using Server::Server; +}; + /// Test tool that returns it argument as text. class TestTool : public Tool { public: @@ -136,27 +130,20 @@ class FailTool : public Tool { } }; -class ProtocolServerMCPTest : public ::testing::Test { +class ProtocolServerMCPTest : public PipePairTest { public: - SubsystemRAII subsystems; - DebuggerSP m_debugger_sp; + SubsystemRAII subsystems; - lldb::IOObjectSP m_io_sp; - std::unique_ptr m_transport_up; - std::unique_ptr m_server_up; + std::unique_ptr transport_up; + std::unique_ptr server_up; MainLoop loop; MockMessageHandler message_handler; - static constexpr llvm::StringLiteral k_localhost = "localhost"; - llvm::Error Write(llvm::StringRef message) { - std::string output = llvm::formatv("{0}\n", message).str(); - size_t bytes_written = output.size(); - return m_io_sp->Write(output.data(), bytes_written).takeError(); - } - - void CloseInput() { - EXPECT_THAT_ERROR(m_io_sp->Close().takeError(), Succeeded()); + llvm::Expected value = json::parse(message); + if (!value) + return value.takeError(); + return transport_up->Write(*value); } /// Run the transport MainLoop and return any messages received. @@ -164,48 +151,34 @@ class ProtocolServerMCPTest : public ::testing::Test { Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, timeout); - auto handle = m_transport_up->RegisterMessageHandler(loop, message_handler); + auto handle = transport_up->RegisterMessageHandler(loop, message_handler); if (!handle) return handle.takeError(); - return loop.Run().takeError(); + return server_up->Run(); } void SetUp() override { - // Create a debugger. - ArchSpec arch("arm64-apple-macosx-"); - Platform::SetHostPlatform( - PlatformRemoteMacOSX::CreateInstance(true, &arch)); - m_debugger_sp = Debugger::CreateInstance(); - - // Create & start the server. - ProtocolServer::Connection connection; - connection.protocol = Socket::SocketProtocol::ProtocolTcp; - connection.name = llvm::formatv("{0}:0", k_localhost).str(); - m_server_up = std::make_unique(); - m_server_up->AddTool(std::make_unique("test", "test tool")); - m_server_up->AddResourceProvider(std::make_unique()); - ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); - - // Connect to the server over a TCP socket. - auto connect_socket_up = std::make_unique(true); - ASSERT_THAT_ERROR(connect_socket_up - ->Connect(llvm::formatv("{0}:{1}", k_localhost, - static_cast( - m_server_up->GetSocket()) - ->GetLocalPortNumber()) - .str()) - .ToError(), - llvm::Succeeded()); - - // Set up JSON transport for the client. - m_io_sp = std::move(connect_socket_up); - m_transport_up = std::make_unique(m_io_sp, m_io_sp); - } - - void TearDown() override { - // Stop the server. - ASSERT_THAT_ERROR(m_server_up->Stop(), llvm::Succeeded()); + PipePairTest::SetUp(); + + transport_up = std::make_unique( + std::make_shared(input.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(output.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)); + + server_up = std::make_unique( + "lldb-mcp", "0.1.0", + std::make_unique( + std::make_shared(output.GetReadFileDescriptor(), + File::eOpenOptionReadOnly, + NativeFile::Unowned), + std::make_shared(input.GetWriteFileDescriptor(), + File::eOpenOptionWriteOnly, + NativeFile::Unowned)), + loop); } }; @@ -225,6 +198,8 @@ TEST_F(ProtocolServerMCPTest, Initialization) { } TEST_F(ProtocolServerMCPTest, ToolsList) { + server_up->AddTool(std::make_unique("test", "test tool")); + llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":"one"})json"; @@ -233,20 +208,10 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - ToolDefinition lldb_command_tool; - lldb_command_tool.description = "Run an lldb command."; - lldb_command_tool.name = "lldb_command"; - lldb_command_tool.inputSchema = json::Object{ - {"type", "object"}, - {"properties", - json::Object{{"arguments", json::Object{{"type", "string"}}}, - {"debugger_id", json::Object{{"type", "number"}}}}}, - {"required", json::Array{"debugger_id"}}}; Response response; response.id = "one"; response.result = json::Object{ - {"tools", - json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, + {"tools", json::Array{std::move(test_tool)}}, }; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -255,6 +220,8 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { } TEST_F(ProtocolServerMCPTest, ResourcesList) { + server_up->AddResourceProvider(std::make_unique()); + llvm::StringLiteral request = R"json({"method":"resources/list","params":{},"jsonrpc":"2.0","id":2})json"; llvm::StringLiteral response = @@ -268,6 +235,8 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { } TEST_F(ProtocolServerMCPTest, ToolsCall) { + server_up->AddTool(std::make_unique("test", "test tool")); + llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = @@ -281,7 +250,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { } TEST_F(ProtocolServerMCPTest, ToolsCallError) { - m_server_up->AddTool(std::make_unique("error", "error tool")); + server_up->AddTool(std::make_unique("error", "error tool")); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -296,7 +265,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { - m_server_up->AddTool(std::make_unique("fail", "fail tool")); + server_up->AddTool(std::make_unique("fail", "fail tool")); llvm::StringLiteral request = R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; @@ -315,19 +284,13 @@ TEST_F(ProtocolServerMCPTest, NotificationInitialized) { std::condition_variable cv; std::mutex mutex; - m_server_up->AddNotificationHandler( - "notifications/initialized", [&](const Notification ¬ification) { - { - std::lock_guard lock(mutex); - handler_called = true; - } - cv.notify_all(); - }); + server_up->AddNotificationHandler( + "notifications/initialized", + [&](const Notification ¬ification) { handler_called = true; }); llvm::StringLiteral request = R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - - std::unique_lock lock(mutex); - cv.wait(lock, [&] { return handler_called; }); + EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_TRUE(handler_called); } diff --git a/lldb/unittests/ProtocolServer/CMakeLists.txt b/lldb/unittests/ProtocolServer/CMakeLists.txt deleted file mode 100644 index 6117430b35bf0..0000000000000 --- a/lldb/unittests/ProtocolServer/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_lldb_unittest(ProtocolServerTests - ProtocolMCPServerTest.cpp - - LINK_LIBS - lldbCore - lldbUtility - lldbHost - lldbPluginPlatformMacOSX - lldbPluginProtocolServerMCP - LLVMTestingSupport - )