From 413ac7324e50f16311e1d21b108d639d089c8368 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Sat, 15 Mar 2025 22:59:53 -0400 Subject: [PATCH] Overhaul McpClientFactory/McpServerFactory Handlers are now specified before connecting the client/server. Otherwise, race conditions exist. Handlers move into being a part of the McpClientOptions/McpServerOptions, so that they're provided to the factories. Where relevant, the handlers are specified as part of the capability descriptors, so that they go hand in hand. The factories are no longer stateful. Instead of allocating a factory and then calling a create method on it, you just call a static factory method. --- README.MD | 84 ++-- .../anthropic/tools/ToolsConsole/Program.cs | 23 +- .../tools/ToolsConsole/Program.cs | 16 +- .../CallerArgumentExpressionAttribute.cs | 15 + .../CancellationTokenSourceExtensions.cs | 2 - .../McpSessionScope.cs | 21 +- src/mcpdotnet/Client/IMcpClient.cs | 21 - src/mcpdotnet/Client/McpClient.cs | 90 +--- src/mcpdotnet/Client/McpClientExtensions.cs | 98 +---- src/mcpdotnet/Client/McpClientFactory.cs | 297 +++++-------- .../Configuration/DefaultMcpServerBuilder.cs | 7 +- .../McpServerBuilderExtensions.Handler.cs | 66 +-- .../McpServerBuilderExtensions.Tools.cs | 12 +- .../McpServerBuilderExtensions.Transports.cs | 6 +- .../McpServerServiceCollectionExtension.cs | 20 +- .../Hosting/McpServerHostedService.cs | 5 +- src/mcpdotnet/Logging/Log.cs | 127 +----- .../Protocol/Messages/OperationNames.cs | 6 +- .../Protocol/Messages/RequestIdConverter.cs | 12 +- .../Protocol/Transport/PasteArguments.cs | 102 ----- .../Protocol/Transport/SseClientTransport.cs | 20 +- .../Transport/StdioClientTransport.cs | 42 +- .../Transport/StdioClientTransportOptions.cs | 2 +- .../Transport/StdioServerTransport.cs | 25 +- .../Protocol/Transport/TransportBase.cs | 2 +- src/mcpdotnet/Protocol/Types/Capabilities.cs | 60 ++- src/mcpdotnet/Server/IMcpServer.cs | 14 - src/mcpdotnet/Server/IMcpServerFactory.cs | 13 - src/mcpdotnet/Server/McpServer.cs | 266 ++++-------- src/mcpdotnet/Server/McpServerDelegates.cs | 93 ----- src/mcpdotnet/Server/McpServerExtensions.cs | 182 +------- src/mcpdotnet/Server/McpServerFactory.cs | 56 +-- src/mcpdotnet/Server/McpServerHandlers.cs | 135 ++++++ src/mcpdotnet/Server/McpServerOptions.cs | 7 + src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs | 31 +- src/mcpdotnet/Utils/Throw.cs | 41 ++ .../IntegrationTests.cs | 11 +- tests/mcpdotnet.TestServer/Program.cs | 386 +++++++++-------- .../Client/McpClientFactoryTests.cs | 355 +++++----------- .../ClientIntegrationTestFixture.cs | 30 +- .../mcpdotnet.Tests/ClientIntegrationTests.cs | 85 ++-- .../McpServerBuilderExtensionsHandlerTests.cs | 18 +- .../McpServerBuilderExtensionsToolsTests.cs | 32 +- .../Server/McpServerDelegatesTests.cs | 79 ++-- .../Server/McpServerFactoryTests.cs | 45 +- .../mcpdotnet.Tests/Server/McpServerTests.cs | 390 +++++++++--------- tests/mcpdotnet.Tests/SseIntegrationTests.cs | 99 ++--- 47 files changed, 1306 insertions(+), 2243 deletions(-) create mode 100644 src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs delete mode 100644 src/mcpdotnet/Protocol/Transport/PasteArguments.cs delete mode 100644 src/mcpdotnet/Server/IMcpServerFactory.cs delete mode 100644 src/mcpdotnet/Server/McpServerDelegates.cs create mode 100644 src/mcpdotnet/Server/McpServerHandlers.cs create mode 100644 src/mcpdotnet/Utils/Throw.cs diff --git a/README.MD b/README.MD index 4cb6e37..8cc534d 100644 --- a/README.MD +++ b/README.MD @@ -71,9 +71,7 @@ McpServerConfig config = new() } }; -var factory = new McpClientFactory([config], options, NullLoggerFactory.Instance); - -var client = await factory.GetClientAsync("everything"); +var client = await McpClientFactory.CreateAsync(config, options); // Print the list of tools available from the server. await foreach (var tool in client.ListToolsAsync()) @@ -136,55 +134,58 @@ using McpDotNet.Protocol.Types; using McpDotNet.Server; using Microsoft.Extensions.Logging.Abstractions; -var loggerFactory = NullLoggerFactory.Instance; McpServerOptions options = new() { ServerInfo = new() { Name = "MyServer", Version = "1.0.0" }, - Capabilities = new() { Tools = new() }, -}; -McpServerFactory factory = new(new StdioServerTransport("MyServer", loggerFactory), options, loggerFactory); -IMcpServer server = factory.CreateServer(); - -server.SetListToolsHandler(async (request, cancellationToken) => -{ - return new ListToolsResult() + Capabilities = new() { - Tools = - [ - new Tool() + Tools = new() + { + ListToolsHandler = async (request, cancellationToken) => { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() + return new ListToolsResult() { - Type = "object", - Properties = new Dictionary() + Tools = + [ + new Tool() + { + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = new JsonSchema() + { + Type = "object", + Properties = new Dictionary() + { + ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + } + }, + } + ] + }; + }, + + CallToolHandler = async (request, cancellationToken) => + { + if (request.Params?.Name == "echo") + { + if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + throw new McpServerException("Missing required argument 'message'"); } - }, - } - ] - }; -}); -server.SetCallToolHandler(async (request, cancellationToken) => -{ - if (request.Params?.Name == "echo") - { - if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) - { - throw new McpServerException("Missing required argument 'message'"); - } + return new CallToolResponse() + { + Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] + }; + } - return new CallToolResponse() - { - Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] - }; - } + throw new McpServerException($"Unknown tool: '{request.Params?.Name}'"); + }, + } + }, +}; - throw new McpServerException($"Unknown tool: '{request.Params?.Name}'"); -}); +await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); await server.StartAsync(); @@ -192,7 +193,6 @@ await server.StartAsync(); await Task.Delay(Timeout.Infinite); ``` - ## Roadmap - Expand documentation with detailed guides for: diff --git a/samples/anthropic/tools/ToolsConsole/Program.cs b/samples/anthropic/tools/ToolsConsole/Program.cs index 667fdf2..ba7cbc3 100644 --- a/samples/anthropic/tools/ToolsConsole/Program.cs +++ b/samples/anthropic/tools/ToolsConsole/Program.cs @@ -1,24 +1,21 @@ using Anthropic.SDK; using Anthropic.SDK.Constants; using Anthropic.SDK.Messaging; -using System.Linq; using McpDotNet; using McpDotNet.Client; using McpDotNet.Configuration; using McpDotNet.Protocol.Transport; -using Microsoft.Extensions.Logging.Abstractions; internal class Program { private static async Task GetMcpClientAsync() { - - McpClientOptions options = new() + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "SimpleToolsConsole", Version = "1.0.0" } }; - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "everything", Name = "Everything", @@ -30,13 +27,7 @@ private static async Task GetMcpClientAsync() } }; - var factory = new McpClientFactory( - [config], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(serverConfig, clientOptions); } private static async Task Main(string[] args) @@ -44,7 +35,7 @@ private static async Task Main(string[] args) try { Console.WriteLine("Initializing MCP 'everything' server"); - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); Console.WriteLine("MCP 'everything' server initialized"); Console.WriteLine("Listing tools..."); var tools = await client.ListToolsAsync().ToListAsync(); @@ -60,10 +51,10 @@ private static async Task Main(string[] args) Console.WriteLine("Asking Claude to call the Echo Tool..."); - var messages = new List - { + List messages = + [ new Message(RoleType.User, "Please call the echo tool with the string 'Hello MCP!' and show me the echoed response.") - }; + ]; var parameters = new MessageParameters() { diff --git a/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs b/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs index 73894b0..2ed9f75 100644 --- a/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs +++ b/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs @@ -3,20 +3,18 @@ using McpDotNet.Extensions.AI; using McpDotNet.Protocol.Transport; using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging.Abstractions; using OpenAI; internal class Program { private static async Task GetMcpClientAsync() { - - McpClientOptions options = new() + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "SimpleToolsConsole", Version = "1.0.0" } }; - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "everything", Name = "Everything", @@ -28,13 +26,7 @@ private static async Task GetMcpClientAsync() } }; - var factory = new McpClientFactory( - [config], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(serverConfig, clientOptions); } private static async Task Main(string[] args) @@ -42,7 +34,7 @@ private static async Task Main(string[] args) try { Console.WriteLine("Initializing MCP 'everything' server"); - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); Console.WriteLine("MCP 'everything' server initialized"); Console.WriteLine("Listing tools..."); var mappedTools = await client.ListToolsAsync().Select(t => t.ToAITool(client)).ToListAsync(); diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs new file mode 100644 index 0000000..968c31e --- /dev/null +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.CompilerServices; + +[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] +internal sealed class CallerArgumentExpressionAttribute : Attribute +{ + public CallerArgumentExpressionAttribute(string parameterName) + { + ParameterName = parameterName; + } + + public string ParameterName { get; } +} diff --git a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs index be96fd1..93e34fd 100644 --- a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs +++ b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs @@ -1,5 +1,3 @@ -using System.Text; - namespace System.Threading.Tasks; internal static class CancellationTokenSourceExtensions diff --git a/src/McpDotNet.Extensions.AI/McpSessionScope.cs b/src/McpDotNet.Extensions.AI/McpSessionScope.cs index 0a528e2..5f03de0 100644 --- a/src/McpDotNet.Extensions.AI/McpSessionScope.cs +++ b/src/McpDotNet.Extensions.AI/McpSessionScope.cs @@ -2,7 +2,6 @@ using McpDotNet.Configuration; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Extensions.AI; @@ -45,7 +44,7 @@ public static async Task CreateAsync(McpServerConfig serverConf } var scope = new McpSessionScope(); - var client = await scope.AddClientAsync(serverConfig, options, loggerFactory).ConfigureAwait(false); + var client = await AddClientAsync(serverConfig, options, loggerFactory).ConfigureAwait(false); scope.Tools = []; await foreach (var tool in client.ListToolsAsync().ConfigureAwait(false)) @@ -81,7 +80,7 @@ public static async Task CreateAsync(IEnumerable CreateAsync(IEnumerable AddClientAsync(McpServerConfig config, - McpClientOptions? options, + private static Task AddClientAsync( + McpServerConfig serverConfig, + McpClientOptions? clientOptions, ILoggerFactory? loggerFactory = null) { - using var factory = new McpClientFactory([config], - options ?? new() { ClientInfo = new() { Name = "AnonymousClient", Version = "1.0.0.0" } }, - loggerFactory ?? NullLoggerFactory.Instance); - factory.DisposeClientsOnDispose = false; - var client = await factory.GetClientAsync(config.Id).ConfigureAwait(false); - _clients.Add(client); - return client; + return McpClientFactory.CreateAsync( + serverConfig, + clientOptions ?? new() { ClientInfo = new() { Name = "AnonymousClient", Version = "1.0.0.0" } }, + loggerFactory: loggerFactory); } /// diff --git a/src/mcpdotnet/Client/IMcpClient.cs b/src/mcpdotnet/Client/IMcpClient.cs index 0d7442f..0a49572 100644 --- a/src/mcpdotnet/Client/IMcpClient.cs +++ b/src/mcpdotnet/Client/IMcpClient.cs @@ -30,20 +30,6 @@ public interface IMcpClient : IAsyncDisposable /// string? ServerInstructions { get; } - /// Sets a handler for the named operation. - /// The name of the operation. - /// The handler. Each operation requires a specific delegate signature. - /// - /// - /// Each operation may have only a single handler. Setting a handler for an operation that already has one - /// will replace the existing handler. - /// - /// - /// provides constants for common operations. - /// - /// - void SetOperationHandler(string operationName, Delegate handler); - /// /// Adds a handler for server notifications of a specific method. /// @@ -60,13 +46,6 @@ public interface IMcpClient : IAsyncDisposable /// void AddNotificationHandler(string method, Func handler); - /// - /// Establishes a connection to the server. - /// - /// A token to cancel the operation. - /// A task representing the asynchronous operation. - Task ConnectAsync(CancellationToken cancellationToken = default); - /// /// Sends a generic JSON-RPC request to the server. /// diff --git a/src/mcpdotnet/Client/McpClient.cs b/src/mcpdotnet/Client/McpClient.cs index c7e4e24..4195f5d 100644 --- a/src/mcpdotnet/Client/McpClient.cs +++ b/src/mcpdotnet/Client/McpClient.cs @@ -6,8 +6,7 @@ using McpDotNet.Protocol.Types; using McpDotNet.Shared; using Microsoft.Extensions.Logging; - -#pragma warning disable CA1508 // Avoid dead conditional code +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Client; @@ -15,15 +14,11 @@ namespace McpDotNet.Client; internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient { private readonly McpClientOptions _options; - private readonly McpServerConfig _serverConfig; - private readonly ILogger _logger; + private readonly ILogger _logger; private readonly IClientTransport _clientTransport; private volatile bool _isInitializing; - private Func>? _samplingHandler; - private Func>? _rootsHandler; - /// /// Initializes a new instance of the class. /// @@ -31,26 +26,37 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient /// Options for the client, defining protocol version and capabilities. /// The server configuration. /// The logger factory. - public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory loggerFactory) + public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) : base(transport, loggerFactory) { _options = options; - _serverConfig = serverConfig; - _logger = loggerFactory.CreateLogger(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _clientTransport = transport; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - if (options.Capabilities?.Sampling is not null) + if (options.Capabilities?.Sampling is { } samplingCapability) { + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + } + SetRequestHandler( - "sampling/createMessage", request => HandleRequest("Sampling", _samplingHandler, request)); + "sampling/createMessage", + request => samplingHandler(request, CancellationTokenSource?.Token ?? default)); } - if (options.Capabilities?.Roots is not null) + if (options.Capabilities?.Roots is { } rootsCapability) { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + } + SetRequestHandler( - "roots/list", request => HandleRequest("Roots", _rootsHandler, request)); + "roots/list", + request => rootsHandler(request, CancellationTokenSource?.Token ?? default)); } } @@ -66,47 +72,6 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer /// public override string EndpointName { get; } - public void SetOperationHandler(string operationName, Delegate handler) - { - if (operationName is null) - { - throw new ArgumentNullException(nameof(operationName)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - if (!TrySetOperationHandler(OperationNames.Sampling, operationName, handler, ref _samplingHandler) && - !TrySetOperationHandler(OperationNames.Roots, operationName, handler, ref _rootsHandler)) - { - throw new ArgumentException($"Unknown operation '{operationName}'", nameof(operationName)); - } - - static bool TrySetOperationHandler( - string targetOperationName, - string operationName, - Delegate handler, - ref Func>? field) - { - if (operationName == targetOperationName) - { - if (handler is Func> typed) - { - field = typed; - return true; - } - - throw new ArgumentException( - $"Handler must be of type {typeof(Func>)}", - nameof(handler)); - } - - return false; - } - } - /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { @@ -195,19 +160,4 @@ await SendMessageAsync( throw new McpClientException("Initialization timed out"); } } - - private async Task HandleRequest( - string friendlyName, - Func>? handler, - TRequest args) - { - if (handler is not null) - { - return await handler(args, CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - } - - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.HandlerNotConfigured(friendlyName, EndpointName); - throw new McpClientException($"{friendlyName} handler not configured."); - } } diff --git a/src/mcpdotnet/Client/McpClientExtensions.cs b/src/mcpdotnet/Client/McpClientExtensions.cs index 6b8a93a..7e4703c 100644 --- a/src/mcpdotnet/Client/McpClientExtensions.cs +++ b/src/mcpdotnet/Client/McpClientExtensions.cs @@ -1,8 +1,7 @@ using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Types; -using Microsoft.Extensions.Logging; +using McpDotNet.Utils; using System.Runtime.CompilerServices; -using System.Xml.XPath; namespace McpDotNet.Client; @@ -20,10 +19,7 @@ public static class McpClientExtensions /// A token to cancel the operation. public static Task SendNotificationAsync(this IMcpClient client, string method, object? parameters = null, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendMessageAsync( new JsonRpcNotification { Method = method, Params = parameters }, @@ -38,10 +34,7 @@ public static Task SendNotificationAsync(this IMcpClient client, string method, /// A task that completes when the ping is successful. public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("ping", null), @@ -80,10 +73,7 @@ public static async IAsyncEnumerable ListToolsAsync( /// A task containing the server's response with tool information. public static Task ListToolsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("tools/list", CreateCursorDictionary(cursor)), @@ -122,10 +112,7 @@ public static async IAsyncEnumerable ListPromptsAsync( /// A task containing the server's response with prompt information. public static Task ListPromptsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("prompts/list", CreateCursorDictionary(cursor)), @@ -142,10 +129,7 @@ public static Task ListPromptsAsync(this IMcpClient client, s /// A task containing the prompt's content and messages. public static Task GetPromptAsync(this IMcpClient client, string name, Dictionary? arguments = null, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("prompts/get", CreateParametersDictionary(name, arguments)), @@ -183,10 +167,7 @@ public static async IAsyncEnumerable ListResourcesAsync( /// A token to cancel the operation. public static Task ListResourcesAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/list", CreateCursorDictionary(cursor)), @@ -201,10 +182,7 @@ public static Task ListResourcesAsync(this IMcpClient clien /// A token to cancel the operation. public static Task ReadResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/read", new() { ["uri"] = uri }), @@ -221,22 +199,9 @@ public static Task ReadResourceAsync(this IMcpClient client, /// A token to cancel the operation. public static Task GetCompletionAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - if (reference is null) - { - throw new ArgumentNullException(nameof(reference)); - } - - if (string.IsNullOrWhiteSpace(argumentName)) - { - throw argumentName is null ? - new ArgumentNullException(nameof(argumentName)) : - new ArgumentException("Argument name cannot be empty.", nameof(argumentName)); - } + Throw.IfNull(client); + Throw.IfNull(reference); + Throw.IfNullOrWhiteSpace(argumentName); if (!reference.Validate(out string? validationMessage)) { @@ -260,10 +225,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re /// A token to cancel the operation. public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/subscribe", new() { ["uri"] = uri }), @@ -278,10 +240,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// A token to cancel the operation. public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/unsubscribe", new() { ["uri"] = uri }), @@ -298,42 +257,13 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// A task containing the tool's response. public static Task CallToolAsync(this IMcpClient client, string toolName, Dictionary arguments, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("tools/call", CreateParametersDictionary(toolName, arguments)), cancellationToken); } - /// Sets the handler for server sampling requests. - /// The client. - /// The sampling request handler. - public static void SetSamplingHandler(this IMcpClient client, Func> handler) - { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - client.SetOperationHandler(OperationNames.Sampling, handler); - } - - /// Sets the handler for server roots requests. - /// The client. - /// The roots request handler. - public static void SetRootsHandler(this IMcpClient client, Func> handler) - { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - client.SetOperationHandler(OperationNames.Roots, handler); - } - private static JsonRpcRequest CreateRequest(string method, Dictionary? parameters) => new JsonRpcRequest { diff --git a/src/mcpdotnet/Client/McpClientFactory.cs b/src/mcpdotnet/Client/McpClientFactory.cs index fa1bc26..8f499fb 100644 --- a/src/mcpdotnet/Client/McpClientFactory.cs +++ b/src/mcpdotnet/Client/McpClientFactory.cs @@ -3,236 +3,129 @@ using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Client; -/// -/// Factory for creating MCP clients based on configuration. It caches clients for reuse, so it is safe to call GetClientAsync multiple times. -/// Call GetClientAsync to get a client for a specific server (by ID), which will create a new client and connect if it doesn't already exist. -/// All server configurations must be passed in the constructor. -/// Capabilities (as defined in client options) are shared across all clients, as the client host can always decide not to use certain capabilities. -/// -public class McpClientFactory : IDisposable -{ - private const string ARGUMENTS_OPTIONS_KEY = "arguments"; - private const string COMMAND_OPTIONS_KEY = "command"; - private readonly Dictionary _serverConfigs; - private readonly McpClientOptions _clientOptions; - private readonly Dictionary _clients = []; - private readonly Func _transportFactoryMethod; - private readonly Func _clientFactoryMethod; - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - private bool _isDisposed; - - /// - /// Gets or sets a value indicating whether clients should be disposed when the factory is disposed. - /// - public bool DisposeClientsOnDispose { get; set; } = true; - /// - /// Initializes a new instance of the class. - /// It is not necessary to pass factory methods for creating transports and clients, as default implementations are provided. - /// Custom factory methods can be provided for mocking or to use custom transport or client implementations. - /// - /// Configuration objects for each server the factory should support. - /// A configuration object which specifies client capabilities and protocol version. +/// Provides factory methods for creating MCP clients. +public static class McpClientFactory +{ + /// Creates an , connecting it to the specified server. + /// Configuration for the target server to which the client should connect. + /// A client configuration object which specifies client capabilities and protocol version. + /// An optional factory method which returns transport implementations based on a server configuration. /// A logger factory for creating loggers for clients. - /// An optional factory method which returns transport implementations based on a server configuration. - /// An optional factory method which creates a client based on client options and transport implementation. - public McpClientFactory( - IEnumerable serverConfigs, + /// A token to cancel the operation. + /// An that's connected to the specified server. + /// is . + /// is . + /// contains invalid information. + /// returns an invalid transport. + public static async Task CreateAsync( + McpServerConfig serverConfig, McpClientOptions clientOptions, + Func? createTransportFunc = null, ILoggerFactory? loggerFactory = null, - Func? transportFactoryMethod = null, - Func? clientFactoryMethod = null) + CancellationToken cancellationToken = default) { - if (serverConfigs is null) - { - throw new ArgumentNullException(nameof(serverConfigs)); - } + Throw.IfNull(serverConfig); + Throw.IfNull(clientOptions); - if (clientOptions is null) - { - throw new ArgumentNullException(nameof(clientOptions)); - } + createTransportFunc ??= CreateTransport; + + string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - loggerFactory ??= NullLoggerFactory.Instance; + var logger = loggerFactory?.CreateLogger(typeof(McpClientFactory)) ?? NullLogger.Instance; + logger.CreatingClient(endpointName); - _serverConfigs = serverConfigs.ToDictionary(c => c.Id); - _clientOptions = clientOptions; - _loggerFactory = loggerFactory; - _logger = loggerFactory.CreateLogger(); - _transportFactoryMethod = transportFactoryMethod ?? CreateTransport; - _clientFactoryMethod = clientFactoryMethod ?? ((transport, serverConfig, options) => new McpClient(transport, options, serverConfig, loggerFactory)); + var transport = + createTransportFunc(serverConfig, loggerFactory) ?? + throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); - // Initialize commands for stdio transport, this is to run commands in a shell even if specified directly, as otherwise - // the stdio protocol will not work correctly. - _logger.InitializingStdioCommands(); - foreach (var config in _serverConfigs.Values.Where(c => c.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase))) + try { - InitializeCommand(config); + McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); + try + { + await client.ConnectAsync(cancellationToken).ConfigureAwait(false); + logger.ClientCreated(endpointName); + return client; + } + catch + { + await client.DisposeAsync().ConfigureAwait(false); + throw; + } } - } - - /// - /// Gets or creates a client for the specified server. The first time a server is requested, a new client is created and connected. - /// Note that this will often spawn the server process during connection, so in some cases you want to call this method only when needed. - /// In other cases, you may want to call it ahead of time to ensure the server is ready when needed, as it may take some time to start up. - /// - /// The ID of the server to connect to. It must have been passed in the serverConfigs when constructing the factory. - /// A token to cancel the operation. - public async Task GetClientAsync(string serverId, CancellationToken cancellationToken = default) - { - if (!_serverConfigs.TryGetValue(serverId, out var config)) + catch { - _logger.ServerNotFound(serverId); - throw new ArgumentException($"Server with ID '{serverId}' not found.", nameof(serverId)); + await transport.DisposeAsync().ConfigureAwait(false); + throw; } - - string endpointName = $"Client ({serverId}: {config.Name})"; - - if (_clients.TryGetValue(serverId, out var existingClient)) - { - _logger.ClientExists(endpointName); - return existingClient; - } - - _logger.CreatingClient(endpointName); - - var transport = _transportFactoryMethod(config); - var client = _clientFactoryMethod(transport, config, _clientOptions); - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - - _logger.ClientCreated(endpointName); - - _clients[serverId] = client; - return client; } - internal Func TransportFactoryMethod => _transportFactoryMethod; - - private IClientTransport CreateTransport(McpServerConfig config) + private static IClientTransport CreateTransport(McpServerConfig serverConfig, ILoggerFactory? loggerFactory) { - string endpointName = $"Client ({config.Id}: {config.Name})"; - - var options = string.Join(", ", config.TransportOptions?.Select(kv => $"{kv.Key}={kv.Value}") ?? []); - _logger.CreatingTransport(endpointName, config.TransportType, options); - - if (string.Equals(config.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase)) + if (string.Equals(serverConfig.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase)) { + string? command = serverConfig.TransportOptions?.GetValueOrDefault("command"); + if (string.IsNullOrWhiteSpace(command)) + { + command = serverConfig.Location; + if (string.IsNullOrWhiteSpace(command)) + { + throw new ArgumentException("Command is required for stdio transport.", nameof(serverConfig)); + } + } + + string? arguments = serverConfig.TransportOptions?.GetValueOrDefault("arguments"); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && + serverConfig.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase) && + !string.IsNullOrEmpty(command) && + !string.Equals(Path.GetFileName(command), "cmd.exe", StringComparison.OrdinalIgnoreCase)) + { + // On Windows, for stdio, we need to wrap non-shell commands with cmd.exe /c {command} (usually npx or uvicorn). + // The stdio transport will not work correctly if the command is not run in a shell. + arguments = string.IsNullOrWhiteSpace(arguments) ? + $"/c {command}" : + $"/c {command} {arguments}"; + command = "cmd.exe"; + } + return new StdioClientTransport(new StdioClientTransportOptions { - Command = GetCommand(config), - Arguments = config.TransportOptions?.GetValueOrDefault(ARGUMENTS_OPTIONS_KEY)?.Split(' '), - WorkingDirectory = config.TransportOptions?.GetValueOrDefault("workingDirectory"), - EnvironmentVariables = config.TransportOptions? + Command = command!, + Arguments = arguments, + WorkingDirectory = serverConfig.TransportOptions?.GetValueOrDefault("workingDirectory"), + EnvironmentVariables = serverConfig.TransportOptions? .Where(kv => kv.Key.StartsWith("env:", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring(4), kv => kv.Value), - ShutdownTimeout = TimeSpan.TryParse(config.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout - }, config, _loggerFactory); + .ToDictionary(kv => kv.Key.Substring("env:".Length), kv => kv.Value), + ShutdownTimeout = TimeSpan.TryParse(serverConfig.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout + }, serverConfig, loggerFactory); } - if (string.Equals(config.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) || - string.Equals(config.TransportType, "http", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(serverConfig.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) || + string.Equals(serverConfig.TransportType, "http", StringComparison.OrdinalIgnoreCase)) { - return new SseClientTransport( - new SseClientTransportOptions - { - ConnectionTimeout = TimeSpan.FromSeconds(ParseOrDefault(config.TransportOptions, "connectionTimeout", 30)), - MaxReconnectAttempts = ParseOrDefault(config.TransportOptions, "maxReconnectAttempts", 3), - ReconnectDelay = TimeSpan.FromSeconds(ParseOrDefault(config.TransportOptions, "reconnectDelay", 5)), - AdditionalHeaders = config.TransportOptions? - .Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring(7), kv => kv.Value) - }, config, _loggerFactory); - } - - throw new ArgumentException($"Unsupported transport type '{config.TransportType}'.", nameof(config)); - } - - private static int ParseOrDefault(Dictionary? options, string key, int defaultValue) - { - if (options?.TryGetValue(key, out var value) ?? false) - { - if (!int.TryParse(value, out var result)) - throw new FormatException($"Invalid value '{value}' for option '{key}'"); - return result; - } - return defaultValue; - } - - private static string GetCommand(McpServerConfig config) - { - var command = config.TransportOptions?.GetValueOrDefault(COMMAND_OPTIONS_KEY); - - if (string.IsNullOrEmpty(command)) - return config.Location!; - - return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "cmd.exe" : command!; - } - - /// - /// Initializes a non-shell command by injecting a /c {command} argument, as the command will be run in a shell. - /// - private void InitializeCommand(McpServerConfig config) - { - string endpointName = $"Client ({config.Id}: {config.Name})"; - - // If the command is empty or already contains cmd.exe, we don't need to do anything - var command = config.TransportOptions?.GetValueOrDefault(COMMAND_OPTIONS_KEY); - - if (string.IsNullOrEmpty(command) || !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - return; - } - - // On Windows, we need to wrap non-shell commands with cmd.exe /c - if (command!.IndexOf("cmd.exe", StringComparison.OrdinalIgnoreCase) >= 0) - { - _logger.SkippingShellWrapper(endpointName); - return; - } - - // If the command is not empty and does not contain cmd.exe, we need to inject /c {command} (usually npx or uvicorn) - // This is because the stdio transport will not work correctly if the command is not run in a shell - _logger.PromotingCommandToShellArgumentForStdio(endpointName, command, config.TransportOptions!.GetValueOrDefault(ARGUMENTS_OPTIONS_KEY) ?? ""); - config.TransportOptions![ARGUMENTS_OPTIONS_KEY] = config.TransportOptions.TryGetValue(ARGUMENTS_OPTIONS_KEY, out var args) - ? $"/c {command} {args}" - : $"/c {command}"; - } - - /// - /// Disposes all clients created by the factory. - /// - /// - protected virtual void Dispose(bool disposing) - { - if (!_isDisposed) - { - if (disposing && DisposeClientsOnDispose) - DisposeClients(); - - _isDisposed = true; - } - } - - private void DisposeClients() - { - foreach (var client in _clients.Values) - { - client?.DisposeAsync().AsTask().Wait(); + return new SseClientTransport(new SseClientTransportOptions + { + ConnectionTimeout = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "connectionTimeout", 30)), + MaxReconnectAttempts = ParseInt32OrDefault(serverConfig.TransportOptions, "maxReconnectAttempts", 3), + ReconnectDelay = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "reconnectDelay", 5)), + AdditionalHeaders = serverConfig.TransportOptions? + .Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal)) + .ToDictionary(kv => kv.Key.Substring("header.".Length), kv => kv.Value) + }, serverConfig, loggerFactory); + + static int ParseInt32OrDefault(Dictionary? options, string key, int defaultValue) => + options?.TryGetValue(key, out var value) is not true ? defaultValue : + int.TryParse(value, out var result) ? result : + throw new ArgumentException($"Invalid value '{value}' for option '{key}' in transport options.", nameof(serverConfig)); } - _clients.Clear(); - } - - /// - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); + throw new ArgumentException($"Unsupported transport type '{serverConfig.TransportType}'.", nameof(serverConfig)); } } \ No newline at end of file diff --git a/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs b/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs index 5559d71..5bcea8d 100644 --- a/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs +++ b/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.DependencyInjection; +using McpDotNet.Utils; +using Microsoft.Extensions.DependencyInjection; namespace McpDotNet.Configuration; @@ -17,6 +18,8 @@ internal class DefaultMcpServerBuilder : IMcpServerBuilder /// public DefaultMcpServerBuilder(IServiceCollection services) { - Services = services ?? throw new ArgumentNullException(nameof(services)); + Throw.IfNull(services); + + Services = services; } } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs index fe0b955..9ffcf9a 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs @@ -1,6 +1,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Types; using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -17,12 +18,9 @@ public static partial class McpServerBuilderExtensions /// The handler. public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListToolsHandler = handler); + builder.Services.Configure(s => s.ListToolsHandler = handler); return builder; } @@ -33,12 +31,9 @@ public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder buil /// The handler. public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.CallToolHandler = handler); + builder.Services.Configure(s => s.CallToolHandler = handler); return builder; } @@ -49,12 +44,9 @@ public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder build /// The handler. public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListPromptsHandler = handler); + builder.Services.Configure(s => s.ListPromptsHandler = handler); return builder; } @@ -65,12 +57,9 @@ public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder bu /// The handler. public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.GetPromptHandler = handler); + builder.Services.Configure(s => s.GetPromptHandler = handler); return builder; } @@ -81,12 +70,9 @@ public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder buil /// The handler. public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListResourcesHandler = handler); + builder.Services.Configure(s => s.ListResourcesHandler = handler); return builder; } @@ -97,28 +83,22 @@ public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder /// The handler. public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ReadResourceHandler = handler); + builder.Services.Configure(s => s.ReadResourceHandler = handler); return builder; } /// - /// Sets the handler for get resources requests. + /// Sets the handler for get completion requests. /// /// The builder instance. /// The handler. public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.GetCompletionHandler = handler); + builder.Services.Configure(s => s.GetCompletionHandler = handler); return builder; } @@ -129,12 +109,9 @@ public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder /// The handler. public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); + builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); return builder; } @@ -145,12 +122,9 @@ public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerB /// The handler. public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); + builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); return builder; } } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs index bbc030f..9b4b5a0 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs @@ -6,6 +6,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Types; using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -40,15 +41,14 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder) /// Types with marked methods to add as tools to the server. public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params Type[] toolTypes) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); if (toolTypes is null || toolTypes.Length == 0) + { throw new ArgumentException("At least one tool type must be provided.", nameof(toolTypes)); + } - var tools = new List(); + List tools = []; Dictionary, CancellationToken, Task>> callbacks = []; foreach (var type in toolTypes) @@ -114,7 +114,9 @@ public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder bui } if (toolTypes.Count == 0) + { throw new ArgumentException("No types with marked methods found in the assembly.", nameof(assembly)); + } return WithTools(builder, [.. toolTypes]); } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs index 6a1a165..2eeae4e 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs @@ -1,5 +1,6 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -15,10 +16,7 @@ public static partial class McpServerBuilderExtensions /// The builder instance. public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); builder.Services.AddSingleton(); return builder; diff --git a/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs b/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs index 7f9e8a5..79b1776 100644 --- a/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs +++ b/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs @@ -1,10 +1,12 @@ -using System.Reflection; -using McpDotNet.Configuration; +using McpDotNet.Configuration; using McpDotNet.Hosting; using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Server; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using System.Reflection; namespace McpDotNet; @@ -36,11 +38,21 @@ public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, A public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, McpServerOptions serverOptions) { services.AddSingleton(serverOptions); - services.AddSingleton(); services.AddHostedService(); services.AddOptions(); + services.AddSingleton(services => + { + IServerTransport serverTransport = services.GetRequiredService(); + McpServerOptions options = services.GetRequiredService(); + ILoggerFactory? loggerFactory = services.GetService(); + + if (services.GetService>() is { } handlersOptions) + { + options = handlersOptions.Value.OverwriteWithSetHandlers(options); + } - services.AddSingleton(sp => sp.GetRequiredService().CreateServer()); + return McpServerFactory.Create(serverTransport, options, loggerFactory, services); + }); return new DefaultMcpServerBuilder(services); } diff --git a/src/mcpdotnet/Hosting/McpServerHostedService.cs b/src/mcpdotnet/Hosting/McpServerHostedService.cs index 11d4006..fcb1666 100644 --- a/src/mcpdotnet/Hosting/McpServerHostedService.cs +++ b/src/mcpdotnet/Hosting/McpServerHostedService.cs @@ -1,4 +1,5 @@ using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.Hosting; namespace McpDotNet.Hosting; @@ -17,7 +18,9 @@ public class McpServerHostedService : BackgroundService /// public McpServerHostedService(IMcpServer server) { - _server = server ?? throw new ArgumentNullException(nameof(server)); + Throw.IfNull(server); + + _server = server; } /// diff --git a/src/mcpdotnet/Logging/Log.cs b/src/mcpdotnet/Logging/Log.cs index 551c5aa..4092efe 100644 --- a/src/mcpdotnet/Logging/Log.cs +++ b/src/mcpdotnet/Logging/Log.cs @@ -8,60 +8,15 @@ namespace McpDotNet.Logging; /// internal static partial class Log { - [LoggerMessage(Level = LogLevel.Debug, Message = "Client {clientId} initializing connection to server {serverId}")] - internal static partial void ClientConnecting(this ILogger logger, string clientId, string serverId); - [LoggerMessage(Level = LogLevel.Information, Message = "Server {endpointName} capabilities received: {capabilities}, server info: {serverInfo}")] internal static partial void ServerCapabilitiesReceived(this ILogger logger, string endpointName, string capabilities, string serverInfo); - [LoggerMessage(Level = LogLevel.Warning, Message = "Request {requestId} timed out")] - internal static partial void RequestTimeout(this ILogger logger, int requestId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Command for {endpointName} already contains shell wrapper, skipping argument injection")] - internal static partial void SkippingShellWrapper(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Server connection config for Id={serverId} not found")] - internal static partial void ServerNotFound(this ILogger logger, string serverId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Client for {endpointName} already created, returning cached client")] - internal static partial void ClientExists(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Creating client for {endpointName}")] internal static partial void CreatingClient(this ILogger logger, string endpointName); [LoggerMessage(Level = LogLevel.Information, Message = "Client for {endpointName} created and connected")] internal static partial void ClientCreated(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Creating transport for {endpointName} with type {transportType} and options {options}")] - internal static partial void CreatingTransport(this ILogger logger, string endpointName, string transportType, string options); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Promoting command for {endpointName} to shell argument for stdio transport with command {command} and arguments {arguments}")] - internal static partial void PromotingCommandToShellArgumentForStdio(this ILogger logger, string endpointName, string command, string arguments); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Initializing stdio commands")] - internal static partial void InitializingStdioCommands(this ILogger logger); - - [LoggerMessage(Level = LogLevel.Error, Message = "{handlerName} handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void HandlerNotConfigured(this ILogger logger, string handlerName, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List tools handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListToolsHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Call tool handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void CallToolHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List prompts handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListPromptsHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Get prompt handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void GetPromptHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List resources handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListResourcesHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Read resource handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ReadResourceHandlerNotConfigured(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Error, Message = "Client server {endpointName} already initializing")] internal static partial void ClientAlreadyInitializing(this ILogger logger, string endpointName); @@ -77,33 +32,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Error, Message = "Client {endpointName} initialization timeout")] internal static partial void ClientInitializationTimeout(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Pinging sent ({endpointName})")] - internal static partial void PingingServer(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing tools for {endpointName} with cursor {cursor}")] - internal static partial void ListingTools(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing prompts for {endpointName} with cursor {cursor}")] - internal static partial void ListingPrompts(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Getting prompt {name} for {endpointName} with arguments {arguments}")] - internal static partial void GettingPrompt(this ILogger logger, string endpointName, string name, string arguments); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing resources for {endpointName} with cursor {cursor}")] - internal static partial void ListingResources(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Reading resource {uri} for {endpointName}")] - internal static partial void ReadingResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Subscribing to resource {uri} for {endpointName}")] - internal static partial void SubscribingToResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Unsubscribing from resource {uri} for {endpointName}")] - internal static partial void UnsubscribingFromResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Calling tool {toolName} for {endpointName} with arguments {arguments}")] - internal static partial void CallingTool(this ILogger logger, string endpointName, string toolName, string arguments); - [LoggerMessage(Level = LogLevel.Information, Message = "Endpoint message processing cancelled for {endpointName}")] internal static partial void EndpointMessageProcessingCancelled(this ILogger logger, string endpointName); @@ -152,9 +80,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Error, Message = "Request invalid response type for {endpointName} with method {method}")] internal static partial void RequestInvalidResponseType(this ILogger logger, string endpointName, string method); - [LoggerMessage(Level = LogLevel.Error, Message = "Request params type conversion error for {endpointName} with method {method}: expected {expectedType}")] - internal static partial void RequestParamsTypeConversionError(this ILogger logger, string endpointName, string method, Type expectedType); - [LoggerMessage(Level = LogLevel.Information, Message = "Cleaning up endpoint {endpointName}")] internal static partial void CleaningUpEndpoint(this ILogger logger, string endpointName); @@ -168,7 +93,7 @@ internal static partial class Log internal static partial void TransportConnecting(this ILogger logger, string endpointName); [LoggerMessage(Level = LogLevel.Information, Message = "Creating process for transport for {endpointName} with command {command}, arguments {arguments}, environment {environment}, working directory {workingDirectory}, shutdown timeout {shutdownTimeout}")] - internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string arguments, string environment, string workingDirectory, string shutdownTimeout); + internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory, string shutdownTimeout); [LoggerMessage(Level = LogLevel.Error, Message = "Transport for {endpointName} error: {data}")] internal static partial void TransportError(this ILogger logger, string endpointName, string data); @@ -239,9 +164,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Debug, Message = "Transport waiting for shutdown for {endpointName}")] internal static partial void TransportWaitingForShutdown(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Warning, Message = "Transport killing process for {endpointName}")] - internal static partial void TransportKillingProcess(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Error, Message = "Transport shutdown failed for {endpointName}")] internal static partial void TransportShutdownFailed(this ILogger logger, string endpointName, Exception exception); @@ -266,9 +188,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Debug, Message = "Sending message to {endpointName}: {message}")] internal static partial void SendingMessage(this ILogger logger, string endpointName, string message); - [LoggerMessage(Level = LogLevel.Information, Message = "Sending notification for {endpointName}: {method}")] - internal static partial void SendingNotification(this ILogger logger, string endpointName, string method); - [LoggerMessage( EventId = 7000, Level = LogLevel.Error, @@ -310,50 +229,6 @@ public static partial void TransportEndpointEventParseFailed( string data, Exception exception); - [LoggerMessage( - EventId = 7004, - Level = LogLevel.Trace, - Message = "Invalid completion reference {reference} for {endpointName}: {validationMessage}" - )] - public static partial void InvalidCompletionReference( - this ILogger logger, - string endpointName, - string reference, - string validationMessage); - - [LoggerMessage( - EventId = 7005, - Level = LogLevel.Trace, - Message = "Invalid completion argument name {argumentName} for {endpointName}" - )] - public static partial void InvalidCompletionArgumentName( - this ILogger logger, - string endpointName, - string argumentName); - - [LoggerMessage( - EventId = 7006, - Level = LogLevel.Trace, - Message = "Invalid completion argument value {argumentValue} for {argumentName} for {endpointName}" - )] - public static partial void InvalidCompletionArgumentValue( - this ILogger logger, - string endpointName, - string argumentValue, - string argumentName); - - [LoggerMessage( - EventId = 7007, - Level = LogLevel.Debug, - Message = "Getting completion for {endpointName} with reference {reference}, argument name {argumentName}, argument value {argumentValue}" - )] - public static partial void GettingCompletion( - this ILogger logger, - string endpointName, - string reference, - string argumentName, - string argumentValue); - [LoggerMessage( EventId = 7008, Level = LogLevel.Error, diff --git a/src/mcpdotnet/Protocol/Messages/OperationNames.cs b/src/mcpdotnet/Protocol/Messages/OperationNames.cs index f8f3bcc..903256c 100644 --- a/src/mcpdotnet/Protocol/Messages/OperationNames.cs +++ b/src/mcpdotnet/Protocol/Messages/OperationNames.cs @@ -1,8 +1,4 @@ -using McpDotNet.Protocol.Messages; -using McpDotNet.Protocol.Types; -using Microsoft.Extensions.Logging; - -namespace McpDotNet.Protocol.Messages; +namespace McpDotNet.Protocol.Messages; /// Provides names of standard operations for use with registering handlers. /// diff --git a/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs b/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs index 4599e95..33b6d9e 100644 --- a/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs +++ b/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs @@ -1,4 +1,5 @@ -using System.Text.Json; +using McpDotNet.Utils; +using System.Text.Json; using System.Text.Json.Serialization; namespace McpDotNet.Protocol.Messages; @@ -22,14 +23,15 @@ public override RequestId Read(ref Utf8JsonReader reader, Type typeToConvert, Js /// public override void Write(Utf8JsonWriter writer, RequestId value, JsonSerializerOptions options) { - if (writer is null) - { - throw new ArgumentNullException(nameof(writer)); - } + Throw.IfNull(writer); if (value.IsString) + { writer.WriteStringValue(value.AsString); + } else + { writer.WriteNumberValue(value.AsNumber); + } } } \ No newline at end of file diff --git a/src/mcpdotnet/Protocol/Transport/PasteArguments.cs b/src/mcpdotnet/Protocol/Transport/PasteArguments.cs deleted file mode 100644 index be4750c..0000000 --- a/src/mcpdotnet/Protocol/Transport/PasteArguments.cs +++ /dev/null @@ -1,102 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -// Copied from: -// https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/System.Private.CoreLib/src/System/PasteArguments.cs -// and changed from using ValueStringBuilder to StringBuilder. - -using System.Text; - -namespace McpDotNet.Protocol.Transport -{ - internal static partial class PasteArguments - { - internal static void AppendArgument(StringBuilder stringBuilder, string argument) - { - if (stringBuilder.Length != 0) - { - stringBuilder.Append(' '); - } - - // Parsing rules for non-argv[0] arguments: - // - Backslash is a normal character except followed by a quote. - // - 2N backslashes followed by a quote ==> N literal backslashes followed by unescaped quote - // - 2N+1 backslashes followed by a quote ==> N literal backslashes followed by a literal quote - // - Parsing stops at first whitespace outside of quoted region. - // - (post 2008 rule): A closing quote followed by another quote ==> literal quote, and parsing remains in quoting mode. - if (argument.Length != 0 && ContainsNoWhitespaceOrQuotes(argument)) - { - // Simple case - no quoting or changes needed. - stringBuilder.Append(argument); - } - else - { - stringBuilder.Append(Quote); - int idx = 0; - while (idx < argument.Length) - { - char c = argument[idx++]; - if (c == Backslash) - { - int numBackSlash = 1; - while (idx < argument.Length && argument[idx] == Backslash) - { - idx++; - numBackSlash++; - } - - if (idx == argument.Length) - { - // We'll emit an end quote after this so must double the number of backslashes. - stringBuilder.Append(Backslash, numBackSlash * 2); - } - else if (argument[idx] == Quote) - { - // Backslashes will be followed by a quote. Must double the number of backslashes. - stringBuilder.Append(Backslash, numBackSlash * 2 + 1); - stringBuilder.Append(Quote); - idx++; - } - else - { - // Backslash will not be followed by a quote, so emit as normal characters. - stringBuilder.Append(Backslash, numBackSlash); - } - - continue; - } - - if (c == Quote) - { - // Escape the quote so it appears as a literal. This also guarantees that we won't end up generating a closing quote followed - // by another quote (which parses differently pre-2008 vs. post-2008.) - stringBuilder.Append(Backslash); - stringBuilder.Append(Quote); - continue; - } - - stringBuilder.Append(c); - } - - stringBuilder.Append(Quote); - } - } - - private static bool ContainsNoWhitespaceOrQuotes(string s) - { - for (int i = 0; i < s.Length; i++) - { - char c = s[i]; - if (char.IsWhiteSpace(c) || c == Quote) - { - return false; - } - } - - return true; - } - - private const char Quote = '\"'; - private const char Backslash = '\\'; - } -} \ No newline at end of file diff --git a/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs b/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs index 9cc4c7a..94e9d28 100644 --- a/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs @@ -5,6 +5,7 @@ using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Messages; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -54,27 +55,16 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC public SseClientTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) : base(loggerFactory) { - if (transportOptions is null) - { - throw new ArgumentNullException(nameof(transportOptions)); - } - - if (serverConfig is null) - { - throw new ArgumentNullException(nameof(serverConfig)); - } - - if (httpClient is null) - { - throw new ArgumentNullException(nameof(httpClient)); - } + Throw.IfNull(transportOptions); + Throw.IfNull(serverConfig); + Throw.IfNull(httpClient); _options = transportOptions; _serverConfig = serverConfig; _sseEndpoint = new Uri(serverConfig.Location!); _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; _connectionEstablished = new TaskCompletionSource(); _ownsHttpClient = ownsHttpClient; diff --git a/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs b/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs index cb705c2..d8613d5 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs @@ -1,6 +1,4 @@ using System.Diagnostics; -using System.Runtime.InteropServices; -using System.Text; using System.Text.Json; using McpDotNet.Configuration; using McpDotNet.Logging; @@ -10,6 +8,8 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +#pragma warning disable CA2213 // Disposable fields should be disposed + namespace McpDotNet.Protocol.Transport; /// @@ -37,15 +37,8 @@ public sealed class StdioClientTransport : TransportBase, IClientTransport public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { - if (options is null) - { - throw new ArgumentNullException(nameof(options)); - } - - if (serverConfig is null) - { - throw new ArgumentNullException(nameof(serverConfig)); - } + Throw.IfNull(options); + Throw.IfNull(serverConfig); _options = options; _serverConfig = serverConfig; @@ -71,23 +64,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) var startInfo = new ProcessStartInfo { FileName = _options.Command, - UseShellExecute = false, RedirectStandardInput = true, RedirectStandardOutput = true, RedirectStandardError = true, - CreateNoWindow = RuntimeInformation.IsOSPlatform(OSPlatform.Windows), + UseShellExecute = false, + CreateNoWindow = true, WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, }; - if (_options.Arguments is { Length: > 0 }) + if (!string.IsNullOrWhiteSpace(_options.Arguments)) { - StringBuilder argsBuilder = new(); - foreach (var arg in _options.Arguments) - { - PasteArguments.AppendArgument(argsBuilder, arg); - } - - startInfo.Arguments = argsBuilder.ToString(); + startInfo.Arguments = _options.Arguments; } if (_options.EnvironmentVariables != null) @@ -105,10 +92,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) _process = new Process { StartInfo = startInfo }; // Set up error logging - _process.ErrorDataReceived += (sender, args) => - { - _logger.TransportError(EndpointName, args.Data ?? "(no data)"); - }; + _process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); if (!_process.Start()) { @@ -120,7 +104,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) _process.BeginErrorReadLine(); // Start reading messages in the background - _readTask = Task.Run(async () => await ReadMessagesAsync(_shutdownCts.Token).ConfigureAwait(false), CancellationToken.None); + _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); _logger.TransportReadingMessages(EndpointName); SetConnected(true); @@ -272,10 +256,10 @@ private async Task CleanupAsync(CancellationToken cancellationToken) _process = null; } - if (_shutdownCts != null) + if (_shutdownCts is { } shutdownCts) { - await _shutdownCts.CancelAsync().ConfigureAwait(false); - _shutdownCts.Dispose(); + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); _shutdownCts = null; } diff --git a/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs b/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs index 35f303d..1ee77e5 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs @@ -19,7 +19,7 @@ public record StdioClientTransportOptions /// /// Arguments to pass to the server process. /// - public string[]? Arguments { get; set; } = []; + public string? Arguments { get; set; } /// /// The working directory for the server process. diff --git a/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs b/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs index cc3a09f..30045ff 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs @@ -1,13 +1,14 @@ using System.Text.Json; -using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Messages; using McpDotNet.Server; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; #pragma warning disable CA2208 // Instantiate argument exceptions correctly +#pragma warning disable CA2213 // Disposable fields should be disposed namespace McpDotNet.Protocol.Transport; @@ -104,12 +105,15 @@ public StdioServerTransport(McpServerOptions serverOptions, TextReader input, Te public StdioServerTransport(string serverName, TextReader input, TextWriter output, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { - _serverName = serverName ?? throw new ArgumentNullException(nameof(serverName)); - _input = input ?? throw new ArgumentNullException(nameof(input)); - _output = output ?? throw new ArgumentNullException(nameof(output)); + Throw.IfNull(serverName); + Throw.IfNull(input); + Throw.IfNull(output); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _serverName = serverName; + _input = input; + _output = output; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; } @@ -274,15 +278,8 @@ private async Task CleanupAsync(CancellationToken cancellationToken) /// Validates the and extracts from it the server name to use. private static string GetServerName(McpServerOptions serverOptions) { - if (serverOptions is null) - { - throw new ArgumentNullException(nameof(serverOptions)); - } - - if (serverOptions.ServerInfo is null) - { - throw new ArgumentNullException($"{nameof(serverOptions)}.{nameof(serverOptions.ServerInfo)}"); - } + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); return serverOptions.ServerInfo.Name; } diff --git a/src/mcpdotnet/Protocol/Transport/TransportBase.cs b/src/mcpdotnet/Protocol/Transport/TransportBase.cs index d55807b..b57733c 100644 --- a/src/mcpdotnet/Protocol/Transport/TransportBase.cs +++ b/src/mcpdotnet/Protocol/Transport/TransportBase.cs @@ -26,7 +26,7 @@ protected TransportBase(ILoggerFactory? loggerFactory) SingleReader = true, SingleWriter = true, }); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; } /// diff --git a/src/mcpdotnet/Protocol/Types/Capabilities.cs b/src/mcpdotnet/Protocol/Types/Capabilities.cs index 20a5e13..b623f29 100644 --- a/src/mcpdotnet/Protocol/Types/Capabilities.cs +++ b/src/mcpdotnet/Protocol/Types/Capabilities.cs @@ -1,6 +1,8 @@ -using System.Text.Json.Serialization; +using McpDotNet.Server; +using System.Text.Json.Serialization; namespace McpDotNet.Protocol.Types; + /// /// Represents the capabilities that a client may support. /// See the schema for details @@ -37,6 +39,10 @@ public record RootsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// Gets or sets the handler for sampling requests. + [JsonIgnore] + public Func>? RootsHandler { get; init; } } /// @@ -46,6 +52,10 @@ public record RootsCapability public record SamplingCapability { // Currently empty in the spec, but may be extended in the future + + /// Gets or sets the handler for sampling requests. + [JsonIgnore] + public Func>? SamplingHandler { get; init; } } /// @@ -68,6 +78,18 @@ public record PromptsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list prompts requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListPromptsHandler { get; init; } + + /// + /// Gets or sets the handler for get prompt requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? GetPromptHandler { get; init; } } /// @@ -87,6 +109,30 @@ public record ResourcesCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list resources requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListResourcesHandler { get; init; } + + /// + /// Gets or sets the handler for read resources requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ReadResourceHandler { get; init; } + + /// + /// Gets or sets the handler for subscribe to resources messages. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; init; } + + /// + /// Gets or sets the handler for unsubscribe from resources messages. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; init; } } /// @@ -100,4 +146,16 @@ public record ToolsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list tools requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListToolsHandler { get; init; } + + /// + /// Gets or sets the handler for call tool requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? CallToolHandler { get; init; } } \ No newline at end of file diff --git a/src/mcpdotnet/Server/IMcpServer.cs b/src/mcpdotnet/Server/IMcpServer.cs index 8672701..024b883 100644 --- a/src/mcpdotnet/Server/IMcpServer.cs +++ b/src/mcpdotnet/Server/IMcpServer.cs @@ -28,20 +28,6 @@ public interface IMcpServer : IAsyncDisposable /// IServiceProvider? ServiceProvider { get; } - /// Sets a handler for the named operation. - /// The name of the operation. - /// The handler. Each operation requires a specific delegate signature. - /// - /// - /// Each operation may have only a single handler. Setting a handler for an operation that already has one - /// will replace the existing handler. - /// - /// - /// provides constants for common operations. - /// - /// - void SetOperationHandler(string operationName, Delegate handler); - /// /// Adds a handler for client notifications of a specific method. /// diff --git a/src/mcpdotnet/Server/IMcpServerFactory.cs b/src/mcpdotnet/Server/IMcpServerFactory.cs deleted file mode 100644 index a8c1e70..0000000 --- a/src/mcpdotnet/Server/IMcpServerFactory.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace McpDotNet.Server; - -/// -/// Factory for creating instances. -/// -public interface IMcpServerFactory -{ - /// - /// Creates a new server instance. - /// - /// - IMcpServer CreateServer(); -} diff --git a/src/mcpdotnet/Server/McpServer.cs b/src/mcpdotnet/Server/McpServer.cs index 2c525c4..ddd2be6 100644 --- a/src/mcpdotnet/Server/McpServer.cs +++ b/src/mcpdotnet/Server/McpServer.cs @@ -1,11 +1,11 @@ - -using System.Text.Json.Nodes; +using System.Text.Json.Nodes; using McpDotNet.Logging; -using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Shared; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Server; @@ -15,17 +15,7 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer private readonly IServerTransport _serverTransport; private readonly McpServerOptions _options; private volatile bool _isInitializing; - private readonly ILogger _logger; - - private Func, CancellationToken, Task>? _listToolsHandler; - private Func, CancellationToken, Task>? _callToolHandler; - private Func, CancellationToken, Task>? _listPromptsHandler; - private Func, CancellationToken, Task>? _getPromptHandler; - private Func, CancellationToken, Task>? _listResourcesHandler; - private Func, CancellationToken, Task>? _readResourceHandler; - private Func, CancellationToken, Task>? _getCompletionHandler; - private Func, CancellationToken, Task>? _subscribeToResourcesHandler; - private Func, CancellationToken, Task>? _unsubscribeFromResourcesHandler; + private readonly ILogger _logger; /// /// Creates a new instance of . @@ -36,22 +26,29 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer /// Logger factory to use for logging /// Optional service provider to use for dependency injection /// - public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFactory loggerFactory, IServiceProvider? serviceProvider) + public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) : base(transport, loggerFactory) { + Throw.IfNull(options); + _serverTransport = transport; - _options = options ?? throw new ArgumentNullException(nameof(options)); - _logger = loggerFactory.CreateLogger(); + _options = options; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; ServerInstructions = options.ServerInstructions; ServiceProvider = serviceProvider; + AddNotificationHandler("notifications/initialized", _ => + { + IsInitialized = true; + return Task.CompletedTask; + }); + + SetInitializeHandler(options); + SetCompletionHandler(options); + SetPingHandler(); SetToolsHandler(options); SetPromptsHandler(options); SetResourcesHandler(options); - SetCompletionHandler(); - SetInitializeHandler(options); - SetPingHandler(); - AddNotificationHandler(); } public ClientCapabilities? ClientCapabilities { get; set; } @@ -69,54 +66,6 @@ public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFa public override string EndpointName => $"Server ({_options.ServerInfo.Name} {_options.ServerInfo.Version}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; - public void SetOperationHandler(string operationName, Delegate handler) - { - if (operationName is null) - { - throw new ArgumentNullException(nameof(operationName)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - if (!TrySetOperationHandler(OperationNames.ListTools, operationName, handler, ref _listToolsHandler) && - !TrySetOperationHandler(OperationNames.CallTool, operationName, handler, ref _callToolHandler) && - !TrySetOperationHandler(OperationNames.ListPrompts, operationName, handler, ref _listPromptsHandler) && - !TrySetOperationHandler(OperationNames.GetPrompt, operationName, handler, ref _getPromptHandler) && - !TrySetOperationHandler(OperationNames.ListResources, operationName, handler, ref _listResourcesHandler) && - !TrySetOperationHandler(OperationNames.ReadResource, operationName, handler, ref _readResourceHandler) && - !TrySetOperationHandler(OperationNames.GetCompletion, operationName, handler, ref _getCompletionHandler) && - !TrySetOperationHandler(OperationNames.SubscribeToResources, operationName, handler, ref _subscribeToResourcesHandler) && - !TrySetOperationHandler(OperationNames.UnsubscribeFromResources, operationName, handler, ref _unsubscribeFromResourcesHandler)) - { - throw new ArgumentException($"Unknown operation '{operationName}'", nameof(operationName)); - } - - static bool TrySetOperationHandler( - string targetOperationName, - string operationName, - Delegate handler, - ref Func? field) - { - if (operationName == targetOperationName) - { - if (handler is Func typed) - { - field = typed; - return true; - } - - throw new ArgumentException( - $"Handler must be of type {typeof(Func)}", - nameof(handler)); - } - - return false; - } - } - /// public async Task StartAsync(CancellationToken cancellationToken = default) { @@ -153,15 +102,6 @@ public async Task StartAsync(CancellationToken cancellationToken = default) } } - private void AddNotificationHandler() - { - AddNotificationHandler("notifications/initialized", (notification) => - { - IsInitialized = true; - return Task.CompletedTask; - }); - } - private void SetPingHandler() { SetRequestHandler("ping", @@ -171,136 +111,94 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { SetRequestHandler("initialize", - (request) => - { - ClientCapabilities = request?.Capabilities ?? new(); - ClientInfo = request?.ClientInfo; - return Task.FromResult(new InitializeResult() - { - ProtocolVersion = options.ProtocolVersion, - Instructions = ServerInstructions, - ServerInfo = _options.ServerInfo, - Capabilities = options.Capabilities ?? new ServerCapabilities(), - }); - }); + request => + { + ClientCapabilities = request?.Capabilities ?? new(); + ClientInfo = request?.ClientInfo; + return Task.FromResult(new InitializeResult() + { + ProtocolVersion = options.ProtocolVersion, + Instructions = ServerInstructions, + ServerInfo = _options.ServerInfo, + Capabilities = options.Capabilities ?? new ServerCapabilities(), + }); + }); } - private void SetCompletionHandler() + private void SetCompletionHandler(McpServerOptions options) { + // This capability is not optional, so return an empty result if there is no handler. SetRequestHandler("completion/complete", - async (request) => - { - if (_getCompletionHandler is null) - { - // This capability is not optional, so return an empty result if there is no handler - return new CompleteResult() - { - Completion = new() - { - Values = [], - Total = 0, - HasMore = false - } - }; - } - - return await _getCompletionHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + options.GetCompletionHandler is { } handler ? + request => handler(new(this, request), CancellationTokenSource?.Token ?? default) : + request => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } })); } private void SetResourcesHandler(McpServerOptions options) { - if (options.Capabilities?.Resources is not null) + if (options.Capabilities?.Resources is not { } resourcesCapability) { - SetRequestHandler("resources/list", - async (request) => - { - if (_listResourcesHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListResourcesHandlerNotConfigured(EndpointName); - throw new McpServerException("ListResources handler not configured"); - } - - return await _listResourcesHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("resources/read", - async (request) => - { - if (_readResourceHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ReadResourceHandlerNotConfigured(EndpointName); - throw new McpServerException("ReadResource handler not configured"); - } - - return await _readResourceHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (resourcesCapability.ListResourcesHandler is not { } listResourcesHandler || + resourcesCapability.ReadResourceHandler is not { } readResourceHandler) + { + throw new McpServerException("Resources capability was enabled, but ListResources and/or ReadResource handlers were not specified."); + } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("resources/list", request => listResourcesHandler(new(this, request), cancellationToken)); + SetRequestHandler("resources/read", request => readResourceHandler(new(this, request), cancellationToken)); + + if (resourcesCapability.Subscribe is not true) + { + return; + } + + var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler; + var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler; + if (subscribeHandler is null || unsubscribeHandler is null) + { + throw new McpServerException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } + + // TODO: Implement Subscribe support } private void SetPromptsHandler(McpServerOptions options) { - if (options.Capabilities?.Prompts is not null) + if (options.Capabilities?.Prompts is not { } promptsCapability) { - SetRequestHandler("prompts/list", - async (request) => - { - if (_listPromptsHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListPromptsHandlerNotConfigured(EndpointName); - throw new McpServerException("ListPrompts handler not configured"); - } - - return await _listPromptsHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("prompts/get", - async (request) => - { - if (_getPromptHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.GetPromptHandlerNotConfigured(EndpointName); - throw new McpServerException("GetPrompt handler not configured"); - } - - return await _getPromptHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (promptsCapability.ListPromptsHandler is not { } listPromptsHandler || + promptsCapability.GetPromptHandler is not { } getPromptHandler) + { + throw new McpServerException("Prompts capability was enabled, but ListPrompts and/or GetPrompt handlers were not specified."); } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("prompts/list", request => listPromptsHandler(new(this, request), cancellationToken)); + SetRequestHandler("prompts/get", request => getPromptHandler(new(this, request), cancellationToken)); } private void SetToolsHandler(McpServerOptions options) { - if (options.Capabilities?.Tools is not null) + if (options.Capabilities?.Tools is not { } toolsCapability) { - SetRequestHandler("tools/list", - async (request) => - { - if (_listToolsHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListToolsHandlerNotConfigured(EndpointName); - throw new McpServerException("ListTools handler not configured"); - } - - return await _listToolsHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("tools/call", - async (request) => - { - if (_callToolHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.CallToolHandlerNotConfigured(EndpointName); - throw new McpServerException("CallTool handler not configured"); - } - - return await _callToolHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (toolsCapability.ListToolsHandler is not { } listToolsHandler || + toolsCapability.CallToolHandler is not { } callToolHandler) + { + throw new McpServerException("ListTools and/or CallTool handlers were specified but the Tools capability was not enabled."); } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("tools/list", request => listToolsHandler(new(this, request), cancellationToken)); + SetRequestHandler("tools/call", request => callToolHandler(new(this, request), cancellationToken)); } } diff --git a/src/mcpdotnet/Server/McpServerDelegates.cs b/src/mcpdotnet/Server/McpServerDelegates.cs deleted file mode 100644 index 419eabd..0000000 --- a/src/mcpdotnet/Server/McpServerDelegates.cs +++ /dev/null @@ -1,93 +0,0 @@ -using McpDotNet.Protocol.Types; - -namespace McpDotNet.Server; - -/// -/// Container for delegates that can be applied to an MCP server. -/// -public class McpServerDelegates -{ - /// - /// Gets or sets the handler for list tools requests. - /// - public Func, CancellationToken, Task>? ListToolsHandler { get; set; } - - /// - /// Gets or sets the handler for call tool requests. - /// - public Func, CancellationToken, Task>? CallToolHandler { get; set; } - - /// - /// Gets or sets the handler for list prompts requests. - /// - public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } - - /// - /// Gets or sets the handler for get prompt requests. - /// - public Func, CancellationToken, Task>? GetPromptHandler { get; set; } - - /// - /// Gets or sets the handler for list resources requests. - /// - public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } - - /// - /// Gets or sets the handler for read resources requests. - /// - public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } - - /// - /// Gets or sets the handler for get resources requests. - /// - public Func, CancellationToken, Task>? GetCompletionHandler { get; set; } - - /// - /// Gets or sets the handler for subscribe to resources messages. - /// - public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } - - /// - /// Gets or sets the handler for subscribe to resources messages. - /// - public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } - - /// - /// Applies the delegates to the server. - /// - /// - public void Apply(IMcpServer server) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (ListToolsHandler != null) - server.SetListToolsHandler(ListToolsHandler); - - if (CallToolHandler != null) - server.SetCallToolHandler(CallToolHandler); - - if (ListPromptsHandler != null) - server.SetListPromptsHandler(ListPromptsHandler); - - if (GetPromptHandler != null) - server.SetGetPromptHandler(GetPromptHandler); - - if (ListResourcesHandler != null) - server.SetListResourcesHandler(ListResourcesHandler); - - if (ReadResourceHandler != null) - server.SetReadResourceHandler(ReadResourceHandler); - - if (GetCompletionHandler != null) - server.SetGetCompletionHandler(GetCompletionHandler); - - if (SubscribeToResourcesHandler != null) - server.SetSubscribeToResourcesHandler(SubscribeToResourcesHandler); - - if (UnsubscribeFromResourcesHandler != null) - server.SetUnsubscribeFromResourcesHandler(UnsubscribeFromResourcesHandler); - } -} diff --git a/src/mcpdotnet/Server/McpServerExtensions.cs b/src/mcpdotnet/Server/McpServerExtensions.cs index 5ed5f4c..06e9ded 100644 --- a/src/mcpdotnet/Server/McpServerExtensions.cs +++ b/src/mcpdotnet/Server/McpServerExtensions.cs @@ -1,5 +1,6 @@ using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Types; +using McpDotNet.Utils; namespace McpDotNet.Server; @@ -12,10 +13,7 @@ public static class McpServerExtensions public static Task RequestSamplingAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken) { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } + Throw.IfNull(server); if (server.ClientCapabilities?.Sampling is null) { @@ -33,10 +31,7 @@ public static Task RequestSamplingAsync( public static Task RequestRootsAsync( this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken) { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } + Throw.IfNull(server); if (server.ClientCapabilities?.Roots is null) { @@ -47,175 +42,4 @@ public static Task RequestRootsAsync( new JsonRpcRequest { Method = "roots/list", Params = request }, cancellationToken); } - - /// - /// Sets the handler for list tools requests. - /// - public static void SetListToolsHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListTools, handler); - } - - /// - /// Sets the handler for call tool requests. - /// - public static void SetCallToolHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.CallTool, handler); - } - - /// - /// Sets the handler for list prompts requests. - /// - public static void SetListPromptsHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListPrompts, handler); - } - - /// - /// Sets the handler for get prompt requests. - /// - public static void SetGetPromptHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.GetPrompt, handler); - } - - /// - /// Sets the handler for list resources requests. - /// - public static void SetListResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListResources, handler); - } - - /// - /// Sets the handler for read resource requests. - /// - public static void SetReadResourceHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ReadResource, handler); - } - - /// - /// Sets the handler for get completion requests. - /// - public static void SetGetCompletionHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.GetCompletion, handler); - } - - /// - /// Sets the handler for subscribe to resources requests. - /// - public static void SetSubscribeToResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.SubscribeToResources, handler); - } - - /// - /// Sets the handler for unsubscribe from resources requests. - /// - public static void SetUnsubscribeFromResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.UnsubscribeFromResources, handler); - } } diff --git a/src/mcpdotnet/Server/McpServerFactory.cs b/src/mcpdotnet/Server/McpServerFactory.cs index 26e1e45..a8e37d2 100644 --- a/src/mcpdotnet/Server/McpServerFactory.cs +++ b/src/mcpdotnet/Server/McpServerFactory.cs @@ -1,62 +1,36 @@ using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Options; namespace McpDotNet.Server; /// -/// Factory for creating instances. -/// This is the main entry point for creating a server. -/// Pass the server transport, options, and logger factory to the constructor. Server instructions are optional. -/// -/// Then call CreateServer to create a new server instance. -/// You can create multiple servers with the same factory, but the transport must be able to handle multiple connections. -/// -/// You must register handlers for all supported capabilities on the server instance, before calling BeginListeningAsync. +/// Provides a factory for creating instances. /// -public class McpServerFactory : IMcpServerFactory +public static class McpServerFactory { - private readonly IServerTransport _serverTransport; - private readonly McpServerOptions _options; - private readonly ILoggerFactory _loggerFactory; - private readonly McpServerDelegates? _serverDelegates; - private readonly IServiceProvider? _serviceProvider; - /// /// Initializes a new instance of the class. /// /// Transport to use for the server - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// /// Optional service provider to create new instances. /// Logger factory to use for logging - /// - public McpServerFactory( + /// An that's started and ready to receive connections. + /// is . + /// is . + public static IMcpServer Create( IServerTransport serverTransport, - McpServerOptions options, + McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, - IOptions? serverDelegates = null, IServiceProvider? serviceProvider = null) { - _serverTransport = serverTransport ?? throw new ArgumentNullException(nameof(serverTransport)); - _options = options ?? throw new ArgumentNullException(nameof(options)); - _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; - _serverDelegates = serverDelegates?.Value; - _serviceProvider = serviceProvider; - } - - /// - /// Creates a new server instance. - /// - /// NB! You must register handlers for all supported capabilities on the server instance, before calling BeginListeningAsync. - /// - public IMcpServer CreateServer() - { - var server = new McpServer(_serverTransport, _options, _loggerFactory, _serviceProvider); - - _serverDelegates?.Apply(server); + Throw.IfNull(serverTransport); + Throw.IfNull(serverOptions); - return server; + return new McpServer(serverTransport, serverOptions, loggerFactory, serviceProvider); } } diff --git a/src/mcpdotnet/Server/McpServerHandlers.cs b/src/mcpdotnet/Server/McpServerHandlers.cs new file mode 100644 index 0000000..507b9b4 --- /dev/null +++ b/src/mcpdotnet/Server/McpServerHandlers.cs @@ -0,0 +1,135 @@ +using McpDotNet.Protocol.Types; + +namespace McpDotNet.Server; + +/// +/// Container for handlers used in the creation of an MCP server. +/// +public sealed class McpServerHandlers +{ + /// + /// Gets or sets the handler for list tools requests. + /// + public Func, CancellationToken, Task>? ListToolsHandler { get; set; } + + /// + /// Gets or sets the handler for call tool requests. + /// + public Func, CancellationToken, Task>? CallToolHandler { get; set; } + + /// + /// Gets or sets the handler for list prompts requests. + /// + public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } + + /// + /// Gets or sets the handler for get prompt requests. + /// + public Func, CancellationToken, Task>? GetPromptHandler { get; set; } + + /// + /// Gets or sets the handler for list resources requests. + /// + public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } + + /// + /// Gets or sets the handler for read resources requests. + /// + public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } + + /// + /// Gets or sets the handler for get completion requests. + /// + public Func, CancellationToken, Task>? GetCompletionHandler { get; set; } + + /// + /// Gets or sets the handler for subscribe to resources messages. + /// + public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } + + /// + /// Gets or sets the handler for unsubscribe from resources messages. + /// + public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } + + /// + /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. + /// + /// + /// + internal McpServerOptions OverwriteWithSetHandlers(McpServerOptions options) + { + PromptsCapability? promptsCapability = options.Capabilities?.Prompts; + if (ListPromptsHandler is not null || GetPromptHandler is not null) + { + promptsCapability = promptsCapability is null ? + new() + { + ListPromptsHandler = ListPromptsHandler, + GetPromptHandler = GetPromptHandler, + } : + promptsCapability with + { + ListPromptsHandler = ListPromptsHandler ?? promptsCapability.ListPromptsHandler, + GetPromptHandler = GetPromptHandler ?? promptsCapability.GetPromptHandler, + }; + } + + ResourcesCapability? resourcesCapability = options.Capabilities?.Resources; + if (ListResourcesHandler is not null || + ReadResourceHandler is not null || + SubscribeToResourcesHandler is not null || + UnsubscribeFromResourcesHandler is not null) + { + resourcesCapability = resourcesCapability is null ? + new() + { + ListResourcesHandler = ListResourcesHandler, + ReadResourceHandler = ReadResourceHandler, + SubscribeToResourcesHandler = SubscribeToResourcesHandler, + UnsubscribeFromResourcesHandler = UnsubscribeFromResourcesHandler, + } : + resourcesCapability with + { + ListResourcesHandler = ListResourcesHandler ?? resourcesCapability.ListResourcesHandler, + ReadResourceHandler = ReadResourceHandler ?? resourcesCapability.ReadResourceHandler, + SubscribeToResourcesHandler = SubscribeToResourcesHandler ?? resourcesCapability.SubscribeToResourcesHandler, + UnsubscribeFromResourcesHandler = UnsubscribeFromResourcesHandler ?? resourcesCapability.UnsubscribeFromResourcesHandler, + }; + } + + ToolsCapability? toolsCapability = options.Capabilities?.Tools; + if (ListToolsHandler is not null || CallToolHandler is not null) + { + toolsCapability = toolsCapability is null ? + new() + { + ListToolsHandler = ListToolsHandler, + CallToolHandler = CallToolHandler, + } : + toolsCapability with + { + ListToolsHandler = ListToolsHandler ?? toolsCapability.ListToolsHandler, + CallToolHandler = CallToolHandler ?? toolsCapability.CallToolHandler, + }; + } + + return options with + { + GetCompletionHandler = GetCompletionHandler ?? options.GetCompletionHandler, + Capabilities = options.Capabilities is null ? + new() + { + Prompts = promptsCapability, + Resources = resourcesCapability, + Tools = toolsCapability, + } : + options.Capabilities with + { + Prompts = promptsCapability, + Resources = resourcesCapability, + Tools = toolsCapability, + }, + }; + } +} diff --git a/src/mcpdotnet/Server/McpServerOptions.cs b/src/mcpdotnet/Server/McpServerOptions.cs index 76f5e1d..40bfa4f 100644 --- a/src/mcpdotnet/Server/McpServerOptions.cs +++ b/src/mcpdotnet/Server/McpServerOptions.cs @@ -1,5 +1,6 @@  using McpDotNet.Protocol.Types; +using System.Text.Json.Serialization; namespace McpDotNet.Server; @@ -34,4 +35,10 @@ public record McpServerOptions /// Optional server instructions to send to clients /// public string ServerInstructions { get; init; } = string.Empty; + + /// + /// Gets or sets the handler for get completion requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? GetCompletionHandler { get; init; } } diff --git a/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs b/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs index 85eb384..880668b 100644 --- a/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs +++ b/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs @@ -4,8 +4,10 @@ using McpDotNet.Logging; using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Shared; @@ -24,7 +26,7 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private readonly Dictionary>> _requestHandlers = []; private int _nextRequestId; private readonly JsonSerializerOptions _jsonOptions; - private readonly ILogger _logger; + private readonly ILogger _logger; private bool _isDisposed; /// @@ -32,19 +34,18 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable /// /// An MCP transport implementation. /// The logger factory. - protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory loggerFactory) + protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory? loggerFactory = null) { - if (transport is null) - { - throw new ArgumentNullException(nameof(transport)); - } + Throw.IfNull(transport); + + loggerFactory ??= NullLoggerFactory.Instance; _transport = transport; _pendingRequests = new(); _notificationHandlers = new(); _nextRequestId = 1; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; - _logger = loggerFactory.CreateLogger(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; } /// @@ -282,10 +283,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - if (message is null) - { - throw new ArgumentNullException(nameof(message)); - } + Throw.IfNull(message); if (!_transport.IsConnected) { @@ -333,15 +331,8 @@ public async ValueTask DisposeAsync() /// Handler to be called when a request with specified method identifier is received protected void SetRequestHandler(string method, Func> handler) { - if (method is null) - { - throw new ArgumentNullException(nameof(method)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } + Throw.IfNull(method); + Throw.IfNull(handler); _requestHandlers[method] = async (request) => { diff --git a/src/mcpdotnet/Utils/Throw.cs b/src/mcpdotnet/Utils/Throw.cs new file mode 100644 index 0000000..5671759 --- /dev/null +++ b/src/mcpdotnet/Utils/Throw.cs @@ -0,0 +1,41 @@ +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace McpDotNet.Utils; + +/// Provides helper methods for throwing exceptions. +internal static class Throw +{ + // NOTE: Most of these should be replaced with extension statics for the relevant extension + // type as downlevel polyfills once the C# 14 extension everything feature is available. + + public static void IfNull([NotNull] object? arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg is null) + { + ThrowArgumentNullException(parameterName); + } + } + + public static void IfNullOrWhiteSpace([NotNull] string? arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg is null || arg.AsSpan().IsWhiteSpace()) + { + ThrowArgumentNullOrWhiteSpaceException(arg); + } + } + + [DoesNotReturn] + private static void ThrowArgumentNullOrWhiteSpaceException(string? parameterName) + { + if (parameterName is null) + { + ThrowArgumentNullException(parameterName); + } + + throw new ArgumentException("Value cannot be empty or composed entirely of whitespace.", parameterName); + } + + [DoesNotReturn] + private static void ThrowArgumentNullException(string? parameterName) => throw new ArgumentNullException(parameterName); +} diff --git a/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs b/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs index b88d6a8..8931b03 100644 --- a/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs +++ b/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs @@ -1,7 +1,6 @@ using McpDotNet.Client; using McpDotNet.Configuration; using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging.Abstractions; using OpenAI; namespace McpDotNet.Extensions.AI.Tests; @@ -35,19 +34,13 @@ private static async Task GetMcpClientAsync() ClientInfo = new() { Name = "McpDotNet.Extensions.AI.Tests", Version = "1.0.0" } }; - var factory = new McpClientFactory( - [GetEverythingServerConfig()], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(GetEverythingServerConfig(), options); } [Fact] public async Task IntegrateWithMeai_UsingEverythingServer_ToolsAreProperlyCalled() { - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); var mappedTools = await client.ListToolsAsync().Select(t => t.ToAITool(client)).ToListAsync(); IChatClient openaiClient = new OpenAIClient(_openAIKey) diff --git a/tests/mcpdotnet.TestServer/Program.cs b/tests/mcpdotnet.TestServer/Program.cs index 7d77bb2..bc6bd1b 100644 --- a/tests/mcpdotnet.TestServer/Program.cs +++ b/tests/mcpdotnet.TestServer/Program.cs @@ -34,24 +34,17 @@ private static async Task Main(string[] args) ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }, Capabilities = new ServerCapabilities() { - Tools = new(), - Resources = new(), - Prompts = new(), + Tools = ConfigureTools(), + Resources = ConfigureResources(), + Prompts = ConfigurePrompts(), }, ProtocolVersion = "2024-11-05", ServerInstructions = "This is a test server with only stub functionality", + GetCompletionHandler = ConfigureCompletion(), }; var loggerFactory = CreateLoggerFactory(); - McpServerFactory factory = new(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); - IMcpServer server = factory.CreateServer(); - - Log.Logger.Information("Server object created, registering handlers."); - - ConfigureTools(server); - ConfigureResources(server); - ConfigurePrompts(server); - ConfigureCompletion(server); + await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); Log.Logger.Information("Server initialized."); @@ -60,184 +53,187 @@ private static async Task Main(string[] args) Log.Logger.Information("Server started."); // Run until process is stopped by the client (parent process) - while (true) // NOSONAR - { - await Task.Delay(1000); - } + await Task.Delay(Timeout.Infinite); } - private static void ConfigureTools(IMcpServer server) + private static ToolsCapability ConfigureTools() { - server.SetListToolsHandler((request, cancellationToken) => + return new() { - return Task.FromResult(new ListToolsResult() + ListToolsHandler = (request, cancellationToken) => { - Tools = - [ - new Tool() - { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() + return Task.FromResult(new ListToolsResult() + { + Tools = + [ + new Tool() { - Type = "object", - Properties = new Dictionary() + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = new JsonSchema() { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } - } + Type = "object", + Properties = new Dictionary() + { + ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + } + }, }, - }, - new Tool() - { - Name = "sampleLLM", - Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = new JsonSchema() + new Tool() { - Type = "object", - Properties = new Dictionary() + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = new JsonSchema() { - ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, - ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } - } - }, - } - ] - }); - }); + Type = "object", + Properties = new Dictionary() + { + ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, + ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } + } + }, + } + ] + }); + }, - server.SetCallToolHandler(async (request, cancellationToken) => - { - if (request.Params?.Name == "echo") + CallToolHandler = async (request, cancellationToken) => { - if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + if (request.Params?.Name == "echo") { - throw new McpServerException("Missing required argument 'message'"); + if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + { + throw new McpServerException("Missing required argument 'message'"); + } + return new CallToolResponse() + { + Content = [new Content() { Text = "Echo: " + message?.ToString(), Type = "text" }] + }; } - return new CallToolResponse() - { - Content = [new Content() { Text = "Echo: " + message?.ToString(), Type = "text" }] - }; - } - else if (request.Params?.Name == "sampleLLM") - { - if (request.Params?.Arguments is null || - !request.Params.Arguments.TryGetValue("prompt", out var prompt) || - !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + else if (request.Params?.Name == "sampleLLM") { - throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); - } - var sampleResult = await server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), - cancellationToken); + if (request.Params?.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + { + throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); + } + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + cancellationToken); - return new CallToolResponse() + return new CallToolResponse() + { + Content = [new Content() { Text = $"LLM sampling result: {sampleResult.Content.Text}", Type = "text" }] + }; + } + else { - Content = [new Content() { Text = $"LLM sampling result: {sampleResult.Content.Text}", Type = "text" }] - }; - } - else - { - throw new McpServerException($"Unknown tool: {request.Params?.Name}"); + throw new McpServerException($"Unknown tool: {request.Params?.Name}"); + } } - }); + }; } - private static void ConfigurePrompts(IMcpServer server) + private static PromptsCapability ConfigurePrompts() { - server.SetListPromptsHandler((request, cancellationToken) => - { - return Task.FromResult(new ListPromptsResult() - { - Prompts = [ - new Prompt() - { - Name = "simple_prompt", - Description = "A prompt without arguments" - }, - new Prompt() - { - Name = "complex_prompt", - Description = "A prompt with arguments", - Arguments = - [ - new PromptArgument() - { - Name = "temperature", - Description = "Temperature setting", - Required = true - }, - new PromptArgument() - { - Name = "style", - Description = "Output style", - Required = false - } - ] - } - ] - }); - }); - - server.SetGetPromptHandler((request, cancellationToken) => + return new() { - List messages = new(); - if (request.Params?.Name == "simple_prompt") + ListPromptsHandler = (request, cancellationToken) => { - messages.Add(new PromptMessage() + return Task.FromResult(new ListPromptsResult() { - Role = Role.User, - Content = new Content() - { - Type = "text", - Text = "This is a simple prompt without arguments." - } + Prompts = [ + new Prompt() + { + Name = "simple_prompt", + Description = "A prompt without arguments" + }, + new Prompt() + { + Name = "complex_prompt", + Description = "A prompt with arguments", + Arguments = + [ + new PromptArgument() + { + Name = "temperature", + Description = "Temperature setting", + Required = true + }, + new PromptArgument() + { + Name = "style", + Description = "Output style", + Required = false + } + ] + } + ] }); - } - else if (request.Params?.Name == "complex_prompt") + }, + + GetPromptHandler = (request, cancellationToken) => { - string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; - messages.Add(new PromptMessage() + List messages = []; + if (request.Params?.Name == "simple_prompt") { - Role = Role.User, - Content = new Content() + messages.Add(new PromptMessage() { - Type = "text", - Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" - } - }); - messages.Add(new PromptMessage() + Role = Role.User, + Content = new Content() + { + Type = "text", + Text = "This is a simple prompt without arguments." + } + }); + } + else if (request.Params?.Name == "complex_prompt") { - Role = Role.Assistant, - Content = new Content() + string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; + messages.Add(new PromptMessage() { - Type = "text", - Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" - } - }); - messages.Add(new PromptMessage() - { - Role = Role.User, - Content = new Content() + Role = Role.User, + Content = new Content() + { + Type = "text", + Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" + } + }); + messages.Add(new PromptMessage() { - Type = "image", - Data = MCP_TINY_IMAGE, - MimeType = "image/png" - } + Role = Role.Assistant, + Content = new Content() + { + Type = "text", + Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" + } + }); + messages.Add(new PromptMessage() + { + Role = Role.User, + Content = new Content() + { + Type = "image", + Data = MCP_TINY_IMAGE, + MimeType = "image/png" + } + }); + } + else + { + throw new McpServerException($"Unknown prompt: {request.Params?.Name}"); + } + + return Task.FromResult(new GetPromptResult() + { + Messages = messages }); } - else - { - throw new McpServerException($"Unknown prompt: {request.Params?.Name}"); - } - - return Task.FromResult(new GetPromptResult() - { - Messages = messages - }); - }); + }; } - private static void ConfigureResources(IMcpServer server) + private static ResourcesCapability ConfigureResources() { List resources = []; List resourceContents = []; @@ -279,54 +275,56 @@ private static void ConfigureResources(IMcpServer server) const int pageSize = 10; - server.SetListResourcesHandler((request, cancellationToken) => + return new() { - int startIndex = 0; - request ??= new(server, new()); - if (request.Params?.Cursor is not null) + ListResourcesHandler = (request, cancellationToken) => { - try + int startIndex = 0; + if (request.Params?.Cursor is not null) { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); - startIndex = Convert.ToInt32(startIndexAsString); + try + { + var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); + startIndex = Convert.ToInt32(startIndexAsString); + } + catch + { + throw new McpServerException("Invalid cursor"); + } } - catch + + int endIndex = Math.Min(startIndex + pageSize, resources.Count); + string? nextCursor = null; + + if (endIndex < resources.Count) { - throw new McpServerException("Invalid cursor"); + nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - } - - int endIndex = Math.Min(startIndex + pageSize, resources.Count); - string? nextCursor = null; + return Task.FromResult(new ListResourcesResult() + { + NextCursor = nextCursor, + Resources = resources.GetRange(startIndex, endIndex - startIndex) + }); + }, - if (endIndex < resources.Count) - { - nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); - } - return Task.FromResult(new ListResourcesResult() + ReadResourceHandler = (request, cancellationToken) => { - NextCursor = nextCursor, - Resources = resources.GetRange(startIndex, endIndex - startIndex) - }); - }); + if (request.Params?.Uri is null) + { + throw new McpServerException("Missing required argument 'uri'"); + } + ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) + ?? throw new McpServerException("Resource not found"); - server.SetReadResourceHandler((request, cancellationToken) => - { - if (request.Params?.Uri is null) - { - throw new McpServerException("Missing required argument 'uri'"); + return Task.FromResult(new ReadResourceResult() + { + Contents = [contents] + }); } - ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) - ?? throw new McpServerException("Resource not found"); - - return Task.FromResult(new ReadResourceResult() - { - Contents = [contents] - }); - }); + }; } - private static void ConfigureCompletion(IMcpServer server) + private static Func, CancellationToken, Task> ConfigureCompletion() { List sampleResourceIds = ["1", "2", "3", "4", "5"]; Dictionary> exampleCompletions = new() @@ -335,7 +333,7 @@ private static void ConfigureCompletion(IMcpServer server) {"temperature", ["0", "0.5", "0.7", "1.0"]}, }; - server.SetGetCompletionHandler((request, cancellationToken) => + return (request, cancellationToken) => { if (request.Params?.Ref?.Type == "ref/resource") { @@ -344,7 +342,7 @@ private static void ConfigureCompletion(IMcpServer server) return Task.FromResult(new CompleteResult() { Completion = new() { Values = [] } }); // Filter resource IDs that start with the input value - var values = sampleResourceIds.Where(id => id.StartsWith(request.Params.Argument.Value)).ToArray(); + var values = sampleResourceIds.Where(id => id.StartsWith(request.Params!.Argument.Value)).ToArray(); return Task.FromResult(new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }); } @@ -360,7 +358,7 @@ private static void ConfigureCompletion(IMcpServer server) } throw new McpServerException($"Unknown reference type: {request.Params?.Ref.Type}"); - }); + }; } static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) diff --git a/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs b/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs index 3c15020..e76f90a 100644 --- a/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs +++ b/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs @@ -3,8 +3,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; -using Microsoft.Extensions.Logging.Abstractions; -using Moq; +using McpDotNet.Protocol.Types; namespace McpDotNet.Tests.Client; @@ -16,10 +15,43 @@ public class McpClientFactoryTests }; [Fact] - public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() + public async Task CreateAsync_WithInvalidArgs_Throws() + { + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions)); + + await Assert.ThrowsAsync("clientOptions", () => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = TransportTypes.StdIo, + }, (McpClientOptions)null!)); + + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = "somethingunsupported", + }, + _defaultOptions)); + + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = TransportTypes.StdIo, + }, + _defaultOptions, + (_, __) => null!)); + } + + [Fact] + public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -32,30 +64,11 @@ public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() } }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - // Inject the mock transport into the factory - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object - ); - // Act - var client = await factory.GetClientAsync("test-server"); + var client = await McpClientFactory.CreateAsync( + serverConfig, + _defaultOptions, + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); @@ -63,137 +76,33 @@ public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() } [Fact] - public async Task GetClientAsync_CalledTwice_ReturnsSameInstance() + public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", TransportType = TransportTypes.StdIo, - Location = "/path/to/server" + Location = "/path/to/server", }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - using var factory = new McpClientFactory([config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object); - // Act - var client1 = await factory.GetClientAsync("test-server"); - var client2 = await factory.GetClientAsync("test-server"); - - // Assert - Assert.Same(client1, client2); - } - - [Fact] - public async Task GetClientAsync_WithInvalidServerId_ThrowsArgumentException() - { - // Arrange - using var factory = new McpClientFactory([], + var client = await McpClientFactory.CreateAsync( + serverConfig, _defaultOptions, - NullLoggerFactory.Instance); - - // Act & Assert - await Assert.ThrowsAsync( - () => factory.GetClientAsync("non-existent-server") - ); - } - - [Fact] - public async Task GetClientAsync_WithUnsupportedTransport_ThrowsArgumentException() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = "unsupported", - Location = "/path/to/server" - }; - - using var factory = new McpClientFactory([config], _defaultOptions, - NullLoggerFactory.Instance); - - // Act & Assert - await Assert.ThrowsAsync( - () => factory.GetClientAsync("test-server") - ); - } - - [Fact] - public async Task GetClientAsync_WithNoTransportOptions_CreatesClientWithDefaults() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.StdIo, - Location = "/path/to/server" - }; - - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - using var factory = new McpClientFactory([config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object); - - // Act - var client = await factory.GetClientAsync("test-server"); + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); + // We could add more assertions here about the client's configuration } [Fact] - public void Constructor_WithDuplicateServerIds_ThrowsArgumentException() - { - // Arrange - McpServerConfig[] configs = - [ - new McpServerConfig { Id = "duplicate", Name = "duplicate", TransportType = TransportTypes.StdIo, Location = "/path1" }, - new McpServerConfig { Id = "duplicate", Name = "duplicate", TransportType = TransportTypes.StdIo, Location = "/path2" } - ]; - - // Act & Assert - Assert.Throws(() => new McpClientFactory(configs, _defaultOptions, - NullLoggerFactory.Instance)); - } - - [Fact] - public async Task GetClientAsync_WithSseTransport_CanCreateClient() + public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -201,40 +110,22 @@ public async Task GetClientAsync_WithSseTransport_CanCreateClient() Location = "http://localhost:8080" }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - // Inject the mock transport into the factory - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object - ); - // Act - var client = await factory.GetClientAsync("test-server"); + var client = await McpClientFactory.CreateAsync( + serverConfig, + _defaultOptions, + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); + // We could add more assertions here about the client's configuration } [Fact] - public void McpFactory_WithSse_CreatesCorrectTransportOptions() + public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -250,58 +141,23 @@ public void McpFactory_WithSse_CreatesCorrectTransportOptions() }; // Act - using var factory = new McpClientFactory( - [config], + var client = await McpClientFactory.CreateAsync( + serverConfig, _defaultOptions, - NullLoggerFactory.Instance - ); - - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; + (_, __) => new NopTransport()); // Assert - Assert.NotNull(transport); - Assert.Equal(TimeSpan.FromSeconds(10), transport.Options.ConnectionTimeout); - Assert.Equal(2, transport.Options.MaxReconnectAttempts); - Assert.Equal(TimeSpan.FromSeconds(5), transport.Options.ReconnectDelay); - Assert.NotNull(transport.Options.AdditionalHeaders); - Assert.Equal("the_header_value", transport.Options.AdditionalHeaders["test"]); - } - - [Fact] - public void McpFactory_WithSseAndNoOptions_CreatesDefaultTransportOptions() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080" - }; - - var defaultOptions = new SseClientTransportOptions(); - - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); - - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; - - // Assert - Assert.NotNull(transport); - Assert.True(transport.Options.ConnectionTimeout == defaultOptions.ConnectionTimeout); - Assert.True(transport.Options.MaxReconnectAttempts == defaultOptions.MaxReconnectAttempts); - Assert.True(transport.Options.ReconnectDelay == defaultOptions.ReconnectDelay); - Assert.True(transport.Options.AdditionalHeaders == null && defaultOptions.AdditionalHeaders == null); + Assert.NotNull(client); + // We could add more assertions here about the client's configuration } - [Fact] - public void McpFactory_WithSseAndMissingOptions_CreatesCorrectTransportOptions() + [Theory] + [InlineData("connectionTimeout", "not_a_number")] + [InlineData("maxReconnectAttempts", "invalid")] + [InlineData("reconnectDelay", "bad_value")] + public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(string key, string value) { - // Arrange + // arrange var config = new McpServerConfig { Id = "test-server", @@ -310,59 +166,50 @@ public void McpFactory_WithSseAndMissingOptions_CreatesCorrectTransportOptions() Location = "http://localhost:8080", TransportOptions = new Dictionary { - ["connectionTimeout"] = "10", - ["header.test"] = "the_header_value" + [key] = value } }; - var defaultOptions = new SseClientTransportOptions(); + // act & assert + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions)); + } - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); + private sealed class NopTransport : IClientTransport + { + private readonly Channel _channel = Channel.CreateUnbounded(); - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; + public bool IsConnected => true; - // Assert - Assert.NotNull(transport); - Assert.Equal(TimeSpan.FromSeconds(10), transport.Options.ConnectionTimeout); - Assert.Equal(defaultOptions.MaxReconnectAttempts, transport.Options.MaxReconnectAttempts); - Assert.Equal(defaultOptions.ReconnectDelay, transport.Options.ReconnectDelay); - Assert.NotNull(transport.Options.AdditionalHeaders); - Assert.Equal("the_header_value", transport.Options.AdditionalHeaders["test"]); - } + public ChannelReader MessageReader => _channel.Reader; - [Theory] - [InlineData("connectionTimeout", "not_a_number")] - [InlineData("maxReconnectAttempts", "invalid")] - [InlineData("reconnectDelay", "bad_value")] - public void McpFactory_WithInvalidTransportOptions_ThrowsFormatException(string key, string value) - { - // arrange - var config = new McpServerConfig + public Task ConnectAsync(CancellationToken cancellationToken = default) => + Task.CompletedTask; + + public ValueTask DisposeAsync() => default; + + public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080", - TransportOptions = new Dictionary + switch (message) { - [key] = value + case JsonRpcRequest request: + _channel.Writer.TryWrite(new JsonRpcResponse + { + Id = ((JsonRpcRequest)message).Id, + Result = new InitializeResult() + { + Capabilities = new ServerCapabilities(), + ProtocolVersion = "2024-11-05", + ServerInfo = new Implementation() + { + Name = "NopTransport", + Version = "1.0.0" + }, + } + }); + break; } - }; - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); - - // act & assert - Assert.Throws(() => - factory.TransportFactoryMethod(config)); + return Task.CompletedTask; + } } } diff --git a/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs b/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs index 1ec22ff..2c271c7 100644 --- a/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs +++ b/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs @@ -8,8 +8,11 @@ namespace McpDotNet.Tests; public class ClientIntegrationTestFixture : IDisposable { public ILoggerFactory LoggerFactory { get; } - public McpClientFactory Factory { get; } public McpClientOptions DefaultOptions { get; } + public McpServerConfig EverythingServerConfig { get; } + public McpServerConfig TestServerConfig { get; } + + public static IEnumerable ClientIds => ["everything", "test_server"]; public ClientIntegrationTestFixture() { @@ -20,10 +23,9 @@ public ClientIntegrationTestFixture() DefaultOptions = new() { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - Capabilities = new() { Sampling = new(), Roots = new() } }; - var everythingServerConfig = new McpServerConfig + EverythingServerConfig = new() { Id = "everything", Name = "Everything", @@ -36,7 +38,7 @@ public ClientIntegrationTestFixture() } }; - var testServerConfig = new McpServerConfig + TestServerConfig = new() { Id = "test_server", Name = "TestServer", @@ -49,19 +51,21 @@ public ClientIntegrationTestFixture() }; if (!OperatingSystem.IsWindows()) - testServerConfig.TransportOptions["arguments"] = "TestServer.dll"; - - // Inject the mock transport into the factory - Factory = new McpClientFactory( - [everythingServerConfig, testServerConfig], - DefaultOptions, - LoggerFactory - ); + { + TestServerConfig.TransportOptions["arguments"] = "TestServer.dll"; + } } + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + McpClientFactory.CreateAsync(clientId switch + { + "everything" => EverythingServerConfig, + "test_server" => TestServerConfig, + _ => throw new ArgumentException($"Unknown client ID: {clientId}") + }, clientOptions ?? DefaultOptions, loggerFactory: LoggerFactory); + public void Dispose() { - Factory?.Dispose(); LoggerFactory?.Dispose(); GC.SuppressFinalize(this); } diff --git a/tests/mcpdotnet.Tests/ClientIntegrationTests.cs b/tests/mcpdotnet.Tests/ClientIntegrationTests.cs index cb0a728..7e293e5 100644 --- a/tests/mcpdotnet.Tests/ClientIntegrationTests.cs +++ b/tests/mcpdotnet.Tests/ClientIntegrationTests.cs @@ -15,11 +15,8 @@ public ClientIntegrationTests(ClientIntegrationTestFixture fixture) _fixture = fixture; } - public static IEnumerable GetClients() - { - yield return ["everything"]; - yield return ["test_server"]; - } + public static IEnumerable GetClients() => + ClientIntegrationTestFixture.ClientIds.Select(id => new object[] { id }); [Theory] [MemberData(nameof(GetClients))] @@ -28,7 +25,7 @@ public async Task ConnectAndPing_Stdio(string clientId) // Arrange // Act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); await client.PingAsync(CancellationToken.None); // Assert @@ -42,7 +39,7 @@ public async Task Connect_ShouldProvideServerFields(string clientId) // Arrange // Act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Assert Assert.NotNull(client.ServerCapabilities); @@ -58,7 +55,7 @@ public async Task ListTools_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var tools = await client.ListToolsAsync().ToListAsync(); // assert @@ -73,7 +70,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.CallToolAsync( "echo", new Dictionary @@ -97,7 +94,7 @@ public async Task ListPrompts_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var prompts = await client.ListPromptsAsync().ToListAsync(); // assert @@ -114,7 +111,7 @@ public async Task GetPrompt_Stdio_SimplePrompt(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetPromptAsync("simple_prompt", null, CancellationToken.None); // assert @@ -129,7 +126,7 @@ public async Task GetPrompt_Stdio_ComplexPrompt(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var arguments = new Dictionary { { "temperature", "0.7" }, @@ -149,7 +146,7 @@ public async Task GetPrompt_NonExistent_ThrowsException(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); await Assert.ThrowsAsync(() => client.GetPromptAsync("non_existent_prompt", null, CancellationToken.None)); } @@ -161,7 +158,7 @@ public async Task ListResources_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); List allResources = []; string? cursor = null; @@ -184,7 +181,7 @@ public async Task ReadResource_Stdio_TextResource(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Odd numbered resources are text in the everything server (despite the docs saying otherwise) // 1 is index 0, which is "even" in the 0-based index var result = await client.ReadResourceAsync("test://static/resource/1", CancellationToken.None); @@ -201,7 +198,7 @@ public async Task ReadResource_Stdio_BinaryResource(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Even numbered resources are binary in the everything server (despite the docs saying otherwise) // 2 is index 1, which is "odd" in the 0-based index var result = await client.ReadResourceAsync("test://static/resource/2", CancellationToken.None); @@ -218,7 +215,7 @@ public async Task GetCompletion_Stdio_ResourceReference(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetCompletionAsync(new Reference { Type = "ref/resource", @@ -240,7 +237,7 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetCompletionAsync(new Reference { Type = "ref/prompt", @@ -258,24 +255,32 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) [Theory] [MemberData(nameof(GetClients))] public async Task Sampling_Stdio(string clientId) - { - var client = await _fixture.Factory.GetClientAsync(clientId); - + { // Set up the sampling handler int samplingHandlerCalls = 0; - client.SetSamplingHandler((_, _) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + ClientInfo = new() { Name = "Sampling_Stdio", Version = "1.0.0" }, + Capabilities = new() { - Model = "test-model", - Role = "assistant", - Content = new Content + Sampling = new() { - Type = "text", - Text = "Test response" - } - }); + SamplingHandler = (_, _) => + { + samplingHandlerCalls++; + return Task.FromResult(new CreateMessageResult + { + Model = "test-model", + Role = "assistant", + Content = new Content + { + Type = "text", + Text = "Test response" + } + }); + }, + }, + }, }); // Call the server's sampleLLM tool which should trigger our sampling handler @@ -306,7 +311,7 @@ public async Task Sampling_Stdio(string clientId) // new() { Uri = "file:///test/root2", Name = "Test Root 2" } // }; - // var client = await _fixture.Factory.GetClientAsync(clientId); + // await using var client = await _fixture.Factory.GetClientAsync(clientId); // // Set up the roots handler // client.SetRootsHandler((request, ct) => @@ -330,9 +335,7 @@ public async Task Sampling_Stdio(string clientId) [MemberData(nameof(GetClients))] public async Task Notifications_Stdio(string clientId) { - var client = await _fixture.Factory.GetClientAsync(clientId); - - await client.ConnectAsync(); + await using var client = await _fixture.CreateClientAsync(clientId); // Verify we can send notifications without errors await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification); @@ -347,7 +350,7 @@ public async Task Notifications_Stdio(string clientId) public async Task CallTool_Stdio_MemoryServer() { // arrange - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "memory", Name = "memory", @@ -359,21 +362,17 @@ public async Task CallTool_Stdio_MemoryServer() } }; - var options = new McpClientOptions + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - using var factory = new McpClientFactory([config], options, _fixture.LoggerFactory); - var client = await factory.GetClientAsync("memory"); - - await client.ConnectAsync(); + await using var client = await McpClientFactory.CreateAsync(serverConfig, clientOptions, loggerFactory: _fixture.LoggerFactory); // act var result = await client.CallToolAsync( "read_graph", - [], - CancellationToken.None + [] ); // assert diff --git a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index bce2560..eeafa67 100644 --- a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -27,7 +27,7 @@ public void WithListToolsHandler_Sets_Handler() _builder.Object.WithListToolsHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListToolsHandler); } @@ -40,7 +40,7 @@ public void WithCallToolHandler_Sets_Handler() _builder.Object.WithCallToolHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.CallToolHandler); } @@ -53,7 +53,7 @@ public void WithListPromptsHandler_Sets_Handler() _builder.Object.WithListPromptsHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListPromptsHandler); } @@ -66,7 +66,7 @@ public void WithGetPromptHandler_Sets_Handler() _builder.Object.WithGetPromptHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.GetPromptHandler); } @@ -79,7 +79,7 @@ public void WithListResourcesHandler_Sets_Handler() _builder.Object.WithListResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListResourcesHandler); } @@ -92,7 +92,7 @@ public void WithReadResourceHandler_Sets_Handler() _builder.Object.WithReadResourceHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ReadResourceHandler); } @@ -105,7 +105,7 @@ public void WithGetCompletionHandler_Sets_Handler() _builder.Object.WithGetCompletionHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.GetCompletionHandler); } @@ -118,7 +118,7 @@ public void WithSubscribeToResourcesHandler_Sets_Handler() _builder.Object.WithSubscribeToResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.SubscribeToResourcesHandler); } @@ -131,7 +131,7 @@ public void WithUnsubscribeFromResourcesHandler_Sets_Handler() _builder.Object.WithUnsubscribeFromResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.UnsubscribeFromResourcesHandler); } diff --git a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 6d8b4f9..eaacf09 100644 --- a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -26,7 +26,7 @@ public void Adds_Tools_To_Server() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.NotNull(options.ListToolsHandler); Assert.NotNull(options.CallToolHandler); @@ -38,7 +38,7 @@ public async Task Can_List_Registered_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); @@ -69,7 +69,7 @@ public async Task Can_Call_Registered_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); Assert.NotNull(result); @@ -86,7 +86,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoArray", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); Assert.NotNull(result); @@ -103,7 +103,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnNull" }), CancellationToken.None); Assert.NotNull(result); @@ -120,7 +120,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnJson" }), CancellationToken.None); Assert.NotNull(result); @@ -137,7 +137,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnInteger" }), CancellationToken.None); Assert.NotNull(result); @@ -154,7 +154,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_Cancellation_Token() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; using var cts = new CancellationTokenSource(); var token = cts.Token; @@ -173,7 +173,7 @@ public async Task Can_Call_Registered_Tool_And_Returns_Cancelled_Response() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; using var cts = new CancellationTokenSource(); var token = cts.Token; @@ -190,7 +190,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoComplex", Arguments = new() { { "complex", JsonDocument.Parse("{\"Name\": \"Peter\", \"Age\": 25}").RootElement } } }), CancellationToken.None); Assert.NotNull(result); @@ -207,7 +207,7 @@ public async Task Throws_Exception_When_Tool_Fails() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var action = async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnError" }), CancellationToken.None); @@ -228,7 +228,7 @@ public async Task Can_Call_Registered_Tool_With_Dependency_Injection() _builder.Object.WithTool(); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var mcpServer = new Mock(); mcpServer.SetupGet(s => s.ServiceProvider).Returns(serviceProvider); @@ -249,7 +249,7 @@ public async Task Throws_Exception_On_Unknown_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "NotRegisteredTool" }), CancellationToken.None)); Assert.Equal("Unknown tool: NotRegisteredTool", exception.Message); @@ -261,7 +261,7 @@ public async Task Throws_Exception_Missing_Parameter() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo" }), CancellationToken.None)); Assert.Equal("Missing required argument 'message'.", exception.Message); @@ -291,7 +291,7 @@ public async Task Register_Tools_From_Current_Assembly() _builder.Object.WithTools(); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); @@ -316,7 +316,7 @@ public async Task Recognizes_Parameter_Types() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); diff --git a/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs b/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs index 95f5891..aa8d37b 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs @@ -1,54 +1,43 @@ -using McpDotNet.Protocol.Messages; -using McpDotNet.Protocol.Types; +using McpDotNet.Protocol.Types; using McpDotNet.Server; namespace McpDotNet.Tests.Server; -public class McpServerDelegatesTests +public class McpServerHandlerTests { [Fact] - public void Applies_All_Given_Delegates() + public void AllPropertiesAreSettable() { - var container = new McpServerDelegates(); - var server = new ExposeSetHandlersServer(); - - container.ListToolsHandler = (p, c) => Task.FromResult(new ListToolsResult()); - container.CallToolHandler = (p, c) => Task.FromResult(new CallToolResponse()); - container.ListPromptsHandler = (p, c) => Task.FromResult(new ListPromptsResult()); - container.GetPromptHandler = (p, c) => Task.FromResult(new GetPromptResult()); - container.ListResourcesHandler = (p, c) => Task.FromResult(new ListResourcesResult()); - container.ReadResourceHandler = (p, c) => Task.FromResult(new ReadResourceResult()); - container.GetCompletionHandler = (p, c) => Task.FromResult(new CompleteResult()); - container.SubscribeToResourcesHandler = (s, c) => Task.CompletedTask; - container.UnsubscribeFromResourcesHandler = (s, c) => Task.CompletedTask; - - container.Apply(server); - - Assert.Equal(container.ListToolsHandler, server.Handlers[OperationNames.ListTools]); - Assert.Equal(container.CallToolHandler, server.Handlers[OperationNames.CallTool]); - Assert.Equal(container.ListPromptsHandler, server.Handlers[OperationNames.ListPrompts]); - Assert.Equal(container.GetPromptHandler, server.Handlers[OperationNames.GetPrompt]); - Assert.Equal(container.ListResourcesHandler, server.Handlers[OperationNames.ListResources]); - Assert.Equal(container.ReadResourceHandler, server.Handlers[OperationNames.ReadResource]); - Assert.Equal(container.GetCompletionHandler, server.Handlers[OperationNames.GetCompletion]); - Assert.Equal(container.SubscribeToResourcesHandler, server.Handlers[OperationNames.SubscribeToResources]); - Assert.Equal(container.UnsubscribeFromResourcesHandler, server.Handlers[OperationNames.UnsubscribeFromResources]); - } - - private sealed class ExposeSetHandlersServer : IMcpServer - { - public Dictionary Handlers = []; - - public void SetOperationHandler(string operationName, Delegate handler) => Handlers[operationName] = handler; - - public ValueTask DisposeAsync() => default; - public bool IsInitialized => throw new NotImplementedException(); - public ClientCapabilities? ClientCapabilities => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? ServiceProvider => throw new NotImplementedException(); - public void AddNotificationHandler(string method, Func handler) => throw new NotImplementedException(); - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where T : class => throw new NotImplementedException(); - public Task StartAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + var handlers = new McpServerHandlers(); + + Assert.Null(handlers.ListToolsHandler); + Assert.Null(handlers.CallToolHandler); + Assert.Null(handlers.ListPromptsHandler); + Assert.Null(handlers.GetPromptHandler); + Assert.Null(handlers.ListResourcesHandler); + Assert.Null(handlers.ReadResourceHandler); + Assert.Null(handlers.GetCompletionHandler); + Assert.Null(handlers.SubscribeToResourcesHandler); + Assert.Null(handlers.UnsubscribeFromResourcesHandler); + + handlers.ListToolsHandler = (p, c) => Task.FromResult(new ListToolsResult()); + handlers.CallToolHandler = (p, c) => Task.FromResult(new CallToolResponse()); + handlers.ListPromptsHandler = (p, c) => Task.FromResult(new ListPromptsResult()); + handlers.GetPromptHandler = (p, c) => Task.FromResult(new GetPromptResult()); + handlers.ListResourcesHandler = (p, c) => Task.FromResult(new ListResourcesResult()); + handlers.ReadResourceHandler = (p, c) => Task.FromResult(new ReadResourceResult()); + handlers.GetCompletionHandler = (p, c) => Task.FromResult(new CompleteResult()); + handlers.SubscribeToResourcesHandler = (s, c) => Task.CompletedTask; + handlers.UnsubscribeFromResourcesHandler = (s, c) => Task.CompletedTask; + + Assert.NotNull(handlers.ListToolsHandler); + Assert.NotNull(handlers.CallToolHandler); + Assert.NotNull(handlers.ListPromptsHandler); + Assert.NotNull(handlers.GetPromptHandler); + Assert.NotNull(handlers.ListResourcesHandler); + Assert.NotNull(handlers.ReadResourceHandler); + Assert.NotNull(handlers.GetCompletionHandler); + Assert.NotNull(handlers.SubscribeToResourcesHandler); + Assert.NotNull(handlers.UnsubscribeFromResourcesHandler); } } diff --git a/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs b/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs index 59fc5e3..be348fa 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs @@ -1,8 +1,7 @@ using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Server; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging.Abstractions; using Moq; namespace McpDotNet.Tests.Server; @@ -10,16 +9,12 @@ namespace McpDotNet.Tests.Server; public class McpServerFactoryTests { private readonly Mock _serverTransport; - private readonly Mock _loggerFactory; - private readonly Mock> _serverDelegates; private readonly McpServerOptions _options; private readonly IServiceProvider _serviceProvider; public McpServerFactoryTests() { _serverTransport = new Mock(); - _loggerFactory = new Mock(); - _serverDelegates = new Mock>(); _options = new McpServerOptions { ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, @@ -30,54 +25,26 @@ public McpServerFactoryTests() } [Fact] - public void Constructor_Should_Initialize_With_Valid_Parameters() + public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider); + await using IMcpServer server = McpServerFactory.Create(_serverTransport.Object, _options, NullLoggerFactory.Instance); // Assert - Assert.NotNull(factory); + Assert.NotNull(server); } [Fact] public void Constructor_Throws_For_Null_ServerTransport() { // Arrange, Act & Assert - Assert.Throws(() => new McpServerFactory(null!, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider)); + Assert.Throws("serverTransport", () => McpServerFactory.Create(null!, _options, NullLoggerFactory.Instance)); } [Fact] public void Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => new McpServerFactory(_serverTransport.Object, null!, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider)); - } - - [Fact] - public void Constructor_Does_Not_Throw_For_Null_ServerDelegates() - { - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, null, _serviceProvider); - Assert.NotNull(factory); - } - - [Fact] - public void Constructor_Does_Not_Throw_For_Null_ServiceProvider() - { - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, null); - Assert.NotNull(factory); - } - - [Fact] - public void CreateServer_Return_IMcpServerInstance() - { - // Arrange - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider); - - // Act - var server = factory.CreateServer(); - - // Assert - Assert.NotNull(server); - Assert.IsAssignableFrom(server); + Assert.Throws("serverOptions", () => McpServerFactory.Create(_serverTransport.Object, null!, NullLoggerFactory.Instance)); } } diff --git a/tests/mcpdotnet.Tests/Server/McpServerTests.cs b/tests/mcpdotnet.Tests/Server/McpServerTests.cs index 67d5545..39037ba 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerTests.cs @@ -34,7 +34,7 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "2024", InitializationTimeout = TimeSpan.FromSeconds(30), - Capabilities = capabilities + Capabilities = capabilities, }; } @@ -63,10 +63,13 @@ public void Constructor_Throws_For_Null_Options() } [Fact] - public void Constructor_Throws_For_Null_LoggerFactory() + public async Task Constructor_Does_Not_Throw_For_Null_Logger() { - // Arrange, Act & Assert - Assert.Throws(() => new McpServer(_serverTransport.Object, _options, null!, _serviceProvider)); + // Arrange & Act + await using var server = new McpServer(_serverTransport.Object, _options, null, _serviceProvider); + + // Assert + Assert.NotNull(server); } [Fact] @@ -226,103 +229,114 @@ public async Task Throws_Exception_If_Not_Connected() [Fact] public async Task Can_Handle_Ping_Requests() { - await Can_Handle_Requests(null, "ping", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); - }); + await Can_Handle_Requests( + serverCapabilities: null, + method: "ping", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + }); } [Fact] public async Task Can_Handle_Initialize_Requests() { - await Can_Handle_Requests(null, "initialize", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); + await Can_Handle_Requests( + serverCapabilities: null, + method: "initialize", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (InitializeResult)response; - Assert.Equal("TestServer", result.ServerInfo.Name); - Assert.Equal("1.0", result.ServerInfo.Version); - Assert.Equal("2024", result.ProtocolVersion); - }); + var result = (InitializeResult)response; + Assert.Equal("TestServer", result.ServerInfo.Name); + Assert.Equal("1.0", result.ServerInfo.Version); + Assert.Equal("2024", result.ProtocolVersion); + }); } [Fact] public async Task Can_Handle_Completion_Requests() { - await Can_Handle_Requests(null, "completion/complete", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); + await Can_Handle_Requests( + serverCapabilities: null, + method: "completion/complete", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); - Assert.Empty(result.Completion.Values); - Assert.Equal(0, result.Completion.Total); - Assert.False(result.Completion.HasMore); - }); + var result = (CompleteResult)response; + Assert.NotNull(result.Completion); + Assert.Empty(result.Completion.Values); + Assert.Equal(0, result.Completion.Total); + Assert.False(result.Completion.HasMore); + }); } [Fact] public async Task Can_Handle_Completion_Requests_With_Handler() { - await Can_Handle_Requests(null, "completion/complete", - configureServer: server => - { - server.SetGetCompletionHandler((request, ct) => - { - return Task.FromResult(new CompleteResult - { - Completion = new() - { - Values = ["test"], - Total = 2, - HasMore = true - } - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); - Assert.NotEmpty(result.Completion.Values); - Assert.Equal("test", result.Completion.Values[0]); - Assert.Equal(2, result.Completion.Total); - Assert.True(result.Completion.HasMore); - }); + await Can_Handle_Requests( + serverCapabilities: null, + method: "completion/complete", + configureOptions: options => options with + { + GetCompletionHandler = (request, ct) => + Task.FromResult(new CompleteResult + { + Completion = new() + { + Values = ["test"], + Total = 2, + HasMore = true + } + }) + }, + assertResult: response => + { + Assert.IsType(response); + + var result = (CompleteResult)response; + Assert.NotNull(result.Completion); + Assert.NotEmpty(result.Completion.Values); + Assert.Equal("test", result.Completion.Values[0]); + Assert.Equal(2, result.Completion.Total); + Assert.True(result.Completion.HasMore); + }); } [Fact] public async Task Can_Handle_Resources_List_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Resources = new() }, "resources/list", - configureServer: server => - { - server.SetListResourcesHandler((request, ct) => - { - return Task.FromResult(new ListResourcesResult - { - Resources = [new() { Uri = "test", Name = "Test Resource" }] - }); - }); - - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (ListResourcesResult)response; - Assert.NotNull(result.Resources); - Assert.NotEmpty(result.Resources); - Assert.Equal("test", result.Resources[0].Uri); - }); + await Can_Handle_Requests( + new ServerCapabilities + { + Resources = new() + { + ListResourcesHandler = (request, ct) => + { + return Task.FromResult(new ListResourcesResult + { + Resources = [new() { Uri = "test", Name = "Test Resource" }] + }); + }, + ReadResourceHandler = (request, ct) => throw new NotImplementedException(), + } + }, + "resources/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + + var result = (ListResourcesResult)response; + Assert.NotNull(result.Resources); + Assert.NotEmpty(result.Resources); + Assert.Equal("test", result.Resources[0].Uri); + }); } [Fact] @@ -334,26 +348,32 @@ public async Task Can_Handle_Resources_List_Requests_Throws_Exception_If_No_Hand [Fact] public async Task Can_Handle_ResourcesRead_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Resources = new() }, "resources/read", - configureServer: server => - { - server.SetReadResourceHandler((request, ct) => - { - return Task.FromResult(new ReadResourceResult - { - Contents = [new() { Text = "test" }] - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (ReadResourceResult)response; - Assert.NotNull(result.Contents); - Assert.NotEmpty(result.Contents); - Assert.Equal("test", result.Contents[0].Text); - }); + await Can_Handle_Requests( + new ServerCapabilities + { + Resources = new() + { + ReadResourceHandler = (request, ct) => + { + return Task.FromResult(new ReadResourceResult + { + Contents = [new() { Text = "test" }] + }); + }, + ListResourcesHandler = (request, ct) => throw new NotImplementedException(), + } + }, + method: "resources/read", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + + var result = (ReadResourceResult)response; + Assert.NotNull(result.Contents); + Assert.NotEmpty(result.Contents); + Assert.Equal("test", result.Contents[0].Text); + }); } [Fact] @@ -365,26 +385,32 @@ public async Task Can_Handle_Resources_Read_Requests_Throws_Exception_If_No_Hand [Fact] public async Task Can_Handle_List_Prompts_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Prompts = new() }, "prompts/list", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetListPromptsHandler((request, ct) => + Prompts = new() { - return Task.FromResult(new ListPromptsResult + ListPromptsHandler = (request, ct) => { - Prompts = [new() { Name = "test" }] - }); - }); + return Task.FromResult(new ListPromptsResult + { + Prompts = [new() { Name = "test" }] + }); + }, + GetPromptHandler = (request, ct) => throw new NotImplementedException(), + }, }, - assertResult: response => - { - Assert.IsType(response); + method: "prompts/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (ListPromptsResult)response; - Assert.NotNull(result.Prompts); - Assert.NotEmpty(result.Prompts); - Assert.Equal("test", result.Prompts[0].Name); - }); + var result = (ListPromptsResult)response; + Assert.NotNull(result.Prompts); + Assert.NotEmpty(result.Prompts); + Assert.Equal("test", result.Prompts[0].Name); + }); } [Fact] @@ -396,24 +422,24 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle [Fact] public async Task Can_Handle_Get_Prompts_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Prompts = new() }, "prompts/get", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetGetPromptHandler((request, ct) => + Prompts = new() { - return Task.FromResult(new GetPromptResult - { - Description = "test" - }); - }); + GetPromptHandler = (request, ct) => Task.FromResult(new GetPromptResult { Description = "test" }), + ListPromptsHandler = (request, ct) => throw new NotImplementedException(), + } }, - assertResult: response => - { - Assert.IsType(response); + method: "prompts/get", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (GetPromptResult)response; - Assert.Equal("test", result.Description); - }); + var result = (GetPromptResult)response; + Assert.Equal("test", result.Description); + }); } [Fact] @@ -425,25 +451,31 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler [Fact] public async Task Can_Handle_List_Tools_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Tools = new() }, "tools/list", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetListToolsHandler((request, ct) => + Tools = new() { - return Task.FromResult(new ListToolsResult + ListToolsHandler = (request, ct) => { - Tools = [new() { Name = "test" }] - }); - }); + return Task.FromResult(new ListToolsResult + { + Tools = [new() { Name = "test" }] + }); + }, + CallToolHandler = (request, ct) => throw new NotImplementedException(), + } }, - assertResult: response => - { - Assert.IsType(response); + method: "tools/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (ListToolsResult)response; - Assert.NotEmpty(result.Tools); - Assert.Equal("test", result.Tools[0].Name); - }); + var result = (ListToolsResult)response; + Assert.NotEmpty(result.Tools); + Assert.Equal("test", result.Tools[0].Name); + }); } [Fact] @@ -455,25 +487,31 @@ public async Task Can_Handle_List_Tools_Requests_Throws_Exception_If_No_Handler_ [Fact] public async Task Can_Handle_Call_Tool_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Tools = new() }, "tools/call", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetCallToolHandler((request, ct) => + Tools = new() { - return Task.FromResult(new CallToolResponse + CallToolHandler = (request, ct) => { - Content = [new Content { Text = "test" }] - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); + return Task.FromResult(new CallToolResponse + { + Content = [new Content { Text = "test" }] + }); + }, + ListToolsHandler = (request, ct) => throw new NotImplementedException(), + } + }, + method: "tools/call", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (CallToolResponse)response; - Assert.NotEmpty(result.Content); - Assert.Equal("test", result.Content[0].Text); - }); + var result = (CallToolResponse)response; + Assert.NotEmpty(result.Content); + Assert.Equal("test", result.Content[0].Text); + }); } [Fact] @@ -482,17 +520,19 @@ public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_A await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, "tools/call", "CallTool handler not configured"); } - private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action configureServer, Action assertResult) + private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Func? configureOptions, Action assertResult) { await using var transport = new TestServerTransport(); - var options = serverCapabilities == null ? _options : CreateOptions(serverCapabilities); + var options = CreateOptions(serverCapabilities); + if (configureOptions is not null) + { + options = configureOptions(options); + } await using var server = new McpServer(transport, options, _loggerFactory.Object, _serviceProvider); await server.StartAsync(); - configureServer(server); - var receivedMessage = new TaskCompletionSource(); transport.OnMessageSent = (message) => @@ -521,32 +561,6 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - await using var server = new McpServer(transport, options, _loggerFactory.Object, _serviceProvider); - - await server.StartAsync(); - - var receivedMessage = new TaskCompletionSource(); - - transport.OnMessageSent = (message) => - { - if (message is JsonRpcError response && response.Id.AsNumber == 55) - receivedMessage.SetResult(response); - }; - - await transport.SendMessageAsync( - new JsonRpcRequest - { - Method = method, - Id = RequestId.FromNumber(55) - } - ); - - var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(1)); - Assert.NotNull(response); - Assert.IsType(response); - - var result = (JsonRpcError)response; - Assert.NotNull(result.Error); - Assert.Equal(expectedError, result.Error.Message); + Assert.Throws(() => McpServerFactory.Create(transport, options, _loggerFactory.Object, _serviceProvider)); } } diff --git a/tests/mcpdotnet.Tests/SseIntegrationTests.cs b/tests/mcpdotnet.Tests/SseIntegrationTests.cs index e291fa9..27748f5 100644 --- a/tests/mcpdotnet.Tests/SseIntegrationTests.cs +++ b/tests/mcpdotnet.Tests/SseIntegrationTests.cs @@ -36,14 +36,8 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -82,14 +76,8 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() Location = $"http://localhost:{port}/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Create client and run tests - var client = await factory.GetClientAsync("everything"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); var tools = await client.ListToolsAsync().ToListAsync(); // assert @@ -110,19 +98,6 @@ public async Task Sampling_Sse_EverythingServer() await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() - { - Name = "IntegrationTestClient", - Version = "1.0.0" - }, - Capabilities = new() - { - Sampling = new() - } - }; - var defaultConfig = new McpServerConfig { Id = "everything", @@ -132,29 +107,37 @@ public async Task Sampling_Sse_EverythingServer() Location = $"http://localhost:{port}/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - var client = await factory.GetClientAsync("everything"); - - // Set up the sampling handler int samplingHandlerCalls = 0; - client.SetSamplingHandler((_, _) => + var defaultOptions = new McpClientOptions { - samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + ClientInfo = new() { - Model = "test-model", - Role = "assistant", - Content = new Content + Name = "IntegrationTestClient", + Version = "1.0.0" + }, + Capabilities = new() + { + Sampling = new() { - Type = "text", - Text = "Test response" - } - }); - }); + SamplingHandler = (_, _) => + { + samplingHandlerCalls++; + return Task.FromResult(new CreateMessageResult + { + Model = "test-model", + Role = "assistant", + Content = new Content + { + Type = "text", + Text = "Test response" + } + }); + }, + }, + }, + }; + + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Call the server's sampleLLM tool which should trigger our sampling handler var result = await client.CallToolAsync( @@ -200,14 +183,8 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -245,14 +222,8 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -300,14 +271,8 @@ public async Task ConnectTwice_Throws() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); var mcpClient = (McpClient)client; var transport = (SseClientTransport)mcpClient.Transport;