diff --git a/README.MD b/README.MD index 4cb6e37..8cc534d 100644 --- a/README.MD +++ b/README.MD @@ -71,9 +71,7 @@ McpServerConfig config = new() } }; -var factory = new McpClientFactory([config], options, NullLoggerFactory.Instance); - -var client = await factory.GetClientAsync("everything"); +var client = await McpClientFactory.CreateAsync(config, options); // Print the list of tools available from the server. await foreach (var tool in client.ListToolsAsync()) @@ -136,55 +134,58 @@ using McpDotNet.Protocol.Types; using McpDotNet.Server; using Microsoft.Extensions.Logging.Abstractions; -var loggerFactory = NullLoggerFactory.Instance; McpServerOptions options = new() { ServerInfo = new() { Name = "MyServer", Version = "1.0.0" }, - Capabilities = new() { Tools = new() }, -}; -McpServerFactory factory = new(new StdioServerTransport("MyServer", loggerFactory), options, loggerFactory); -IMcpServer server = factory.CreateServer(); - -server.SetListToolsHandler(async (request, cancellationToken) => -{ - return new ListToolsResult() + Capabilities = new() { - Tools = - [ - new Tool() + Tools = new() + { + ListToolsHandler = async (request, cancellationToken) => { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() + return new ListToolsResult() { - Type = "object", - Properties = new Dictionary() + Tools = + [ + new Tool() + { + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = new JsonSchema() + { + Type = "object", + Properties = new Dictionary() + { + ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + } + }, + } + ] + }; + }, + + CallToolHandler = async (request, cancellationToken) => + { + if (request.Params?.Name == "echo") + { + if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + throw new McpServerException("Missing required argument 'message'"); } - }, - } - ] - }; -}); -server.SetCallToolHandler(async (request, cancellationToken) => -{ - if (request.Params?.Name == "echo") - { - if (request.Params.Arguments?.TryGetValue("message", out var message) is not true) - { - throw new McpServerException("Missing required argument 'message'"); - } + return new CallToolResponse() + { + Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] + }; + } - return new CallToolResponse() - { - Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] - }; - } + throw new McpServerException($"Unknown tool: '{request.Params?.Name}'"); + }, + } + }, +}; - throw new McpServerException($"Unknown tool: '{request.Params?.Name}'"); -}); +await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("MyServer"), options); await server.StartAsync(); @@ -192,7 +193,6 @@ await server.StartAsync(); await Task.Delay(Timeout.Infinite); ``` - ## Roadmap - Expand documentation with detailed guides for: diff --git a/samples/anthropic/tools/ToolsConsole/Program.cs b/samples/anthropic/tools/ToolsConsole/Program.cs index 667fdf2..ba7cbc3 100644 --- a/samples/anthropic/tools/ToolsConsole/Program.cs +++ b/samples/anthropic/tools/ToolsConsole/Program.cs @@ -1,24 +1,21 @@ using Anthropic.SDK; using Anthropic.SDK.Constants; using Anthropic.SDK.Messaging; -using System.Linq; using McpDotNet; using McpDotNet.Client; using McpDotNet.Configuration; using McpDotNet.Protocol.Transport; -using Microsoft.Extensions.Logging.Abstractions; internal class Program { private static async Task GetMcpClientAsync() { - - McpClientOptions options = new() + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "SimpleToolsConsole", Version = "1.0.0" } }; - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "everything", Name = "Everything", @@ -30,13 +27,7 @@ private static async Task GetMcpClientAsync() } }; - var factory = new McpClientFactory( - [config], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(serverConfig, clientOptions); } private static async Task Main(string[] args) @@ -44,7 +35,7 @@ private static async Task Main(string[] args) try { Console.WriteLine("Initializing MCP 'everything' server"); - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); Console.WriteLine("MCP 'everything' server initialized"); Console.WriteLine("Listing tools..."); var tools = await client.ListToolsAsync().ToListAsync(); @@ -60,10 +51,10 @@ private static async Task Main(string[] args) Console.WriteLine("Asking Claude to call the Echo Tool..."); - var messages = new List - { + List messages = + [ new Message(RoleType.User, "Please call the echo tool with the string 'Hello MCP!' and show me the echoed response.") - }; + ]; var parameters = new MessageParameters() { diff --git a/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs b/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs index 73894b0..2ed9f75 100644 --- a/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs +++ b/samples/microsoft.extensions.ai/tools/ToolsConsole/Program.cs @@ -3,20 +3,18 @@ using McpDotNet.Extensions.AI; using McpDotNet.Protocol.Transport; using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging.Abstractions; using OpenAI; internal class Program { private static async Task GetMcpClientAsync() { - - McpClientOptions options = new() + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "SimpleToolsConsole", Version = "1.0.0" } }; - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "everything", Name = "Everything", @@ -28,13 +26,7 @@ private static async Task GetMcpClientAsync() } }; - var factory = new McpClientFactory( - [config], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(serverConfig, clientOptions); } private static async Task Main(string[] args) @@ -42,7 +34,7 @@ private static async Task Main(string[] args) try { Console.WriteLine("Initializing MCP 'everything' server"); - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); Console.WriteLine("MCP 'everything' server initialized"); Console.WriteLine("Listing tools..."); var mappedTools = await client.ListToolsAsync().Select(t => t.ToAITool(client)).ToListAsync(); diff --git a/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs new file mode 100644 index 0000000..968c31e --- /dev/null +++ b/src/Common/Polyfills/System/Runtime/CompilerServices/CallerArgumentExpressionAttribute.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace System.Runtime.CompilerServices; + +[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] +internal sealed class CallerArgumentExpressionAttribute : Attribute +{ + public CallerArgumentExpressionAttribute(string parameterName) + { + ParameterName = parameterName; + } + + public string ParameterName { get; } +} diff --git a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs index be96fd1..93e34fd 100644 --- a/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs +++ b/src/Common/Polyfills/System/Threading/CancellationTokenSourceExtensions.cs @@ -1,5 +1,3 @@ -using System.Text; - namespace System.Threading.Tasks; internal static class CancellationTokenSourceExtensions diff --git a/src/McpDotNet.Extensions.AI/McpSessionScope.cs b/src/McpDotNet.Extensions.AI/McpSessionScope.cs index 0a528e2..5f03de0 100644 --- a/src/McpDotNet.Extensions.AI/McpSessionScope.cs +++ b/src/McpDotNet.Extensions.AI/McpSessionScope.cs @@ -2,7 +2,6 @@ using McpDotNet.Configuration; using Microsoft.Extensions.AI; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Extensions.AI; @@ -45,7 +44,7 @@ public static async Task CreateAsync(McpServerConfig serverConf } var scope = new McpSessionScope(); - var client = await scope.AddClientAsync(serverConfig, options, loggerFactory).ConfigureAwait(false); + var client = await AddClientAsync(serverConfig, options, loggerFactory).ConfigureAwait(false); scope.Tools = []; await foreach (var tool in client.ListToolsAsync().ConfigureAwait(false)) @@ -81,7 +80,7 @@ public static async Task CreateAsync(IEnumerable CreateAsync(IEnumerable AddClientAsync(McpServerConfig config, - McpClientOptions? options, + private static Task AddClientAsync( + McpServerConfig serverConfig, + McpClientOptions? clientOptions, ILoggerFactory? loggerFactory = null) { - using var factory = new McpClientFactory([config], - options ?? new() { ClientInfo = new() { Name = "AnonymousClient", Version = "1.0.0.0" } }, - loggerFactory ?? NullLoggerFactory.Instance); - factory.DisposeClientsOnDispose = false; - var client = await factory.GetClientAsync(config.Id).ConfigureAwait(false); - _clients.Add(client); - return client; + return McpClientFactory.CreateAsync( + serverConfig, + clientOptions ?? new() { ClientInfo = new() { Name = "AnonymousClient", Version = "1.0.0.0" } }, + loggerFactory: loggerFactory); } /// diff --git a/src/mcpdotnet/Client/IMcpClient.cs b/src/mcpdotnet/Client/IMcpClient.cs index 0d7442f..0a49572 100644 --- a/src/mcpdotnet/Client/IMcpClient.cs +++ b/src/mcpdotnet/Client/IMcpClient.cs @@ -30,20 +30,6 @@ public interface IMcpClient : IAsyncDisposable /// string? ServerInstructions { get; } - /// Sets a handler for the named operation. - /// The name of the operation. - /// The handler. Each operation requires a specific delegate signature. - /// - /// - /// Each operation may have only a single handler. Setting a handler for an operation that already has one - /// will replace the existing handler. - /// - /// - /// provides constants for common operations. - /// - /// - void SetOperationHandler(string operationName, Delegate handler); - /// /// Adds a handler for server notifications of a specific method. /// @@ -60,13 +46,6 @@ public interface IMcpClient : IAsyncDisposable /// void AddNotificationHandler(string method, Func handler); - /// - /// Establishes a connection to the server. - /// - /// A token to cancel the operation. - /// A task representing the asynchronous operation. - Task ConnectAsync(CancellationToken cancellationToken = default); - /// /// Sends a generic JSON-RPC request to the server. /// diff --git a/src/mcpdotnet/Client/McpClient.cs b/src/mcpdotnet/Client/McpClient.cs index c7e4e24..4195f5d 100644 --- a/src/mcpdotnet/Client/McpClient.cs +++ b/src/mcpdotnet/Client/McpClient.cs @@ -6,8 +6,7 @@ using McpDotNet.Protocol.Types; using McpDotNet.Shared; using Microsoft.Extensions.Logging; - -#pragma warning disable CA1508 // Avoid dead conditional code +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Client; @@ -15,15 +14,11 @@ namespace McpDotNet.Client; internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient { private readonly McpClientOptions _options; - private readonly McpServerConfig _serverConfig; - private readonly ILogger _logger; + private readonly ILogger _logger; private readonly IClientTransport _clientTransport; private volatile bool _isInitializing; - private Func>? _samplingHandler; - private Func>? _rootsHandler; - /// /// Initializes a new instance of the class. /// @@ -31,26 +26,37 @@ internal sealed class McpClient : McpJsonRpcEndpoint, IMcpClient /// Options for the client, defining protocol version and capabilities. /// The server configuration. /// The logger factory. - public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory loggerFactory) + public McpClient(IClientTransport transport, McpClientOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory) : base(transport, loggerFactory) { _options = options; - _serverConfig = serverConfig; - _logger = loggerFactory.CreateLogger(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _clientTransport = transport; EndpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - if (options.Capabilities?.Sampling is not null) + if (options.Capabilities?.Sampling is { } samplingCapability) { + if (samplingCapability.SamplingHandler is not { } samplingHandler) + { + throw new InvalidOperationException($"Sampling capability was set but it did not provide a handler."); + } + SetRequestHandler( - "sampling/createMessage", request => HandleRequest("Sampling", _samplingHandler, request)); + "sampling/createMessage", + request => samplingHandler(request, CancellationTokenSource?.Token ?? default)); } - if (options.Capabilities?.Roots is not null) + if (options.Capabilities?.Roots is { } rootsCapability) { + if (rootsCapability.RootsHandler is not { } rootsHandler) + { + throw new InvalidOperationException($"Roots capability was set but it did not provide a handler."); + } + SetRequestHandler( - "roots/list", request => HandleRequest("Roots", _rootsHandler, request)); + "roots/list", + request => rootsHandler(request, CancellationTokenSource?.Token ?? default)); } } @@ -66,47 +72,6 @@ public McpClient(IClientTransport transport, McpClientOptions options, McpServer /// public override string EndpointName { get; } - public void SetOperationHandler(string operationName, Delegate handler) - { - if (operationName is null) - { - throw new ArgumentNullException(nameof(operationName)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - if (!TrySetOperationHandler(OperationNames.Sampling, operationName, handler, ref _samplingHandler) && - !TrySetOperationHandler(OperationNames.Roots, operationName, handler, ref _rootsHandler)) - { - throw new ArgumentException($"Unknown operation '{operationName}'", nameof(operationName)); - } - - static bool TrySetOperationHandler( - string targetOperationName, - string operationName, - Delegate handler, - ref Func>? field) - { - if (operationName == targetOperationName) - { - if (handler is Func> typed) - { - field = typed; - return true; - } - - throw new ArgumentException( - $"Handler must be of type {typeof(Func>)}", - nameof(handler)); - } - - return false; - } - } - /// public async Task ConnectAsync(CancellationToken cancellationToken = default) { @@ -195,19 +160,4 @@ await SendMessageAsync( throw new McpClientException("Initialization timed out"); } } - - private async Task HandleRequest( - string friendlyName, - Func>? handler, - TRequest args) - { - if (handler is not null) - { - return await handler(args, CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - } - - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.HandlerNotConfigured(friendlyName, EndpointName); - throw new McpClientException($"{friendlyName} handler not configured."); - } } diff --git a/src/mcpdotnet/Client/McpClientExtensions.cs b/src/mcpdotnet/Client/McpClientExtensions.cs index 6b8a93a..7e4703c 100644 --- a/src/mcpdotnet/Client/McpClientExtensions.cs +++ b/src/mcpdotnet/Client/McpClientExtensions.cs @@ -1,8 +1,7 @@ using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Types; -using Microsoft.Extensions.Logging; +using McpDotNet.Utils; using System.Runtime.CompilerServices; -using System.Xml.XPath; namespace McpDotNet.Client; @@ -20,10 +19,7 @@ public static class McpClientExtensions /// A token to cancel the operation. public static Task SendNotificationAsync(this IMcpClient client, string method, object? parameters = null, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendMessageAsync( new JsonRpcNotification { Method = method, Params = parameters }, @@ -38,10 +34,7 @@ public static Task SendNotificationAsync(this IMcpClient client, string method, /// A task that completes when the ping is successful. public static Task PingAsync(this IMcpClient client, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("ping", null), @@ -80,10 +73,7 @@ public static async IAsyncEnumerable ListToolsAsync( /// A task containing the server's response with tool information. public static Task ListToolsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("tools/list", CreateCursorDictionary(cursor)), @@ -122,10 +112,7 @@ public static async IAsyncEnumerable ListPromptsAsync( /// A task containing the server's response with prompt information. public static Task ListPromptsAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("prompts/list", CreateCursorDictionary(cursor)), @@ -142,10 +129,7 @@ public static Task ListPromptsAsync(this IMcpClient client, s /// A task containing the prompt's content and messages. public static Task GetPromptAsync(this IMcpClient client, string name, Dictionary? arguments = null, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("prompts/get", CreateParametersDictionary(name, arguments)), @@ -183,10 +167,7 @@ public static async IAsyncEnumerable ListResourcesAsync( /// A token to cancel the operation. public static Task ListResourcesAsync(this IMcpClient client, string? cursor, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/list", CreateCursorDictionary(cursor)), @@ -201,10 +182,7 @@ public static Task ListResourcesAsync(this IMcpClient clien /// A token to cancel the operation. public static Task ReadResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/read", new() { ["uri"] = uri }), @@ -221,22 +199,9 @@ public static Task ReadResourceAsync(this IMcpClient client, /// A token to cancel the operation. public static Task GetCompletionAsync(this IMcpClient client, Reference reference, string argumentName, string argumentValue, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - if (reference is null) - { - throw new ArgumentNullException(nameof(reference)); - } - - if (string.IsNullOrWhiteSpace(argumentName)) - { - throw argumentName is null ? - new ArgumentNullException(nameof(argumentName)) : - new ArgumentException("Argument name cannot be empty.", nameof(argumentName)); - } + Throw.IfNull(client); + Throw.IfNull(reference); + Throw.IfNullOrWhiteSpace(argumentName); if (!reference.Validate(out string? validationMessage)) { @@ -260,10 +225,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re /// A token to cancel the operation. public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/subscribe", new() { ["uri"] = uri }), @@ -278,10 +240,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, /// A token to cancel the operation. public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string uri, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("resources/unsubscribe", new() { ["uri"] = uri }), @@ -298,42 +257,13 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u /// A task containing the tool's response. public static Task CallToolAsync(this IMcpClient client, string toolName, Dictionary arguments, CancellationToken cancellationToken = default) { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } + Throw.IfNull(client); return client.SendRequestAsync( CreateRequest("tools/call", CreateParametersDictionary(toolName, arguments)), cancellationToken); } - /// Sets the handler for server sampling requests. - /// The client. - /// The sampling request handler. - public static void SetSamplingHandler(this IMcpClient client, Func> handler) - { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - client.SetOperationHandler(OperationNames.Sampling, handler); - } - - /// Sets the handler for server roots requests. - /// The client. - /// The roots request handler. - public static void SetRootsHandler(this IMcpClient client, Func> handler) - { - if (client is null) - { - throw new ArgumentNullException(nameof(client)); - } - - client.SetOperationHandler(OperationNames.Roots, handler); - } - private static JsonRpcRequest CreateRequest(string method, Dictionary? parameters) => new JsonRpcRequest { diff --git a/src/mcpdotnet/Client/McpClientFactory.cs b/src/mcpdotnet/Client/McpClientFactory.cs index fa1bc26..8f499fb 100644 --- a/src/mcpdotnet/Client/McpClientFactory.cs +++ b/src/mcpdotnet/Client/McpClientFactory.cs @@ -3,236 +3,129 @@ using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Client; -/// -/// Factory for creating MCP clients based on configuration. It caches clients for reuse, so it is safe to call GetClientAsync multiple times. -/// Call GetClientAsync to get a client for a specific server (by ID), which will create a new client and connect if it doesn't already exist. -/// All server configurations must be passed in the constructor. -/// Capabilities (as defined in client options) are shared across all clients, as the client host can always decide not to use certain capabilities. -/// -public class McpClientFactory : IDisposable -{ - private const string ARGUMENTS_OPTIONS_KEY = "arguments"; - private const string COMMAND_OPTIONS_KEY = "command"; - private readonly Dictionary _serverConfigs; - private readonly McpClientOptions _clientOptions; - private readonly Dictionary _clients = []; - private readonly Func _transportFactoryMethod; - private readonly Func _clientFactoryMethod; - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - private bool _isDisposed; - - /// - /// Gets or sets a value indicating whether clients should be disposed when the factory is disposed. - /// - public bool DisposeClientsOnDispose { get; set; } = true; - /// - /// Initializes a new instance of the class. - /// It is not necessary to pass factory methods for creating transports and clients, as default implementations are provided. - /// Custom factory methods can be provided for mocking or to use custom transport or client implementations. - /// - /// Configuration objects for each server the factory should support. - /// A configuration object which specifies client capabilities and protocol version. +/// Provides factory methods for creating MCP clients. +public static class McpClientFactory +{ + /// Creates an , connecting it to the specified server. + /// Configuration for the target server to which the client should connect. + /// A client configuration object which specifies client capabilities and protocol version. + /// An optional factory method which returns transport implementations based on a server configuration. /// A logger factory for creating loggers for clients. - /// An optional factory method which returns transport implementations based on a server configuration. - /// An optional factory method which creates a client based on client options and transport implementation. - public McpClientFactory( - IEnumerable serverConfigs, + /// A token to cancel the operation. + /// An that's connected to the specified server. + /// is . + /// is . + /// contains invalid information. + /// returns an invalid transport. + public static async Task CreateAsync( + McpServerConfig serverConfig, McpClientOptions clientOptions, + Func? createTransportFunc = null, ILoggerFactory? loggerFactory = null, - Func? transportFactoryMethod = null, - Func? clientFactoryMethod = null) + CancellationToken cancellationToken = default) { - if (serverConfigs is null) - { - throw new ArgumentNullException(nameof(serverConfigs)); - } + Throw.IfNull(serverConfig); + Throw.IfNull(clientOptions); - if (clientOptions is null) - { - throw new ArgumentNullException(nameof(clientOptions)); - } + createTransportFunc ??= CreateTransport; + + string endpointName = $"Client ({serverConfig.Id}: {serverConfig.Name})"; - loggerFactory ??= NullLoggerFactory.Instance; + var logger = loggerFactory?.CreateLogger(typeof(McpClientFactory)) ?? NullLogger.Instance; + logger.CreatingClient(endpointName); - _serverConfigs = serverConfigs.ToDictionary(c => c.Id); - _clientOptions = clientOptions; - _loggerFactory = loggerFactory; - _logger = loggerFactory.CreateLogger(); - _transportFactoryMethod = transportFactoryMethod ?? CreateTransport; - _clientFactoryMethod = clientFactoryMethod ?? ((transport, serverConfig, options) => new McpClient(transport, options, serverConfig, loggerFactory)); + var transport = + createTransportFunc(serverConfig, loggerFactory) ?? + throw new InvalidOperationException($"{nameof(createTransportFunc)} returned a null transport."); - // Initialize commands for stdio transport, this is to run commands in a shell even if specified directly, as otherwise - // the stdio protocol will not work correctly. - _logger.InitializingStdioCommands(); - foreach (var config in _serverConfigs.Values.Where(c => c.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase))) + try { - InitializeCommand(config); + McpClient client = new(transport, clientOptions, serverConfig, loggerFactory); + try + { + await client.ConnectAsync(cancellationToken).ConfigureAwait(false); + logger.ClientCreated(endpointName); + return client; + } + catch + { + await client.DisposeAsync().ConfigureAwait(false); + throw; + } } - } - - /// - /// Gets or creates a client for the specified server. The first time a server is requested, a new client is created and connected. - /// Note that this will often spawn the server process during connection, so in some cases you want to call this method only when needed. - /// In other cases, you may want to call it ahead of time to ensure the server is ready when needed, as it may take some time to start up. - /// - /// The ID of the server to connect to. It must have been passed in the serverConfigs when constructing the factory. - /// A token to cancel the operation. - public async Task GetClientAsync(string serverId, CancellationToken cancellationToken = default) - { - if (!_serverConfigs.TryGetValue(serverId, out var config)) + catch { - _logger.ServerNotFound(serverId); - throw new ArgumentException($"Server with ID '{serverId}' not found.", nameof(serverId)); + await transport.DisposeAsync().ConfigureAwait(false); + throw; } - - string endpointName = $"Client ({serverId}: {config.Name})"; - - if (_clients.TryGetValue(serverId, out var existingClient)) - { - _logger.ClientExists(endpointName); - return existingClient; - } - - _logger.CreatingClient(endpointName); - - var transport = _transportFactoryMethod(config); - var client = _clientFactoryMethod(transport, config, _clientOptions); - await client.ConnectAsync(cancellationToken).ConfigureAwait(false); - - _logger.ClientCreated(endpointName); - - _clients[serverId] = client; - return client; } - internal Func TransportFactoryMethod => _transportFactoryMethod; - - private IClientTransport CreateTransport(McpServerConfig config) + private static IClientTransport CreateTransport(McpServerConfig serverConfig, ILoggerFactory? loggerFactory) { - string endpointName = $"Client ({config.Id}: {config.Name})"; - - var options = string.Join(", ", config.TransportOptions?.Select(kv => $"{kv.Key}={kv.Value}") ?? []); - _logger.CreatingTransport(endpointName, config.TransportType, options); - - if (string.Equals(config.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase)) + if (string.Equals(serverConfig.TransportType, TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase)) { + string? command = serverConfig.TransportOptions?.GetValueOrDefault("command"); + if (string.IsNullOrWhiteSpace(command)) + { + command = serverConfig.Location; + if (string.IsNullOrWhiteSpace(command)) + { + throw new ArgumentException("Command is required for stdio transport.", nameof(serverConfig)); + } + } + + string? arguments = serverConfig.TransportOptions?.GetValueOrDefault("arguments"); + + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && + serverConfig.TransportType.Equals(TransportTypes.StdIo, StringComparison.OrdinalIgnoreCase) && + !string.IsNullOrEmpty(command) && + !string.Equals(Path.GetFileName(command), "cmd.exe", StringComparison.OrdinalIgnoreCase)) + { + // On Windows, for stdio, we need to wrap non-shell commands with cmd.exe /c {command} (usually npx or uvicorn). + // The stdio transport will not work correctly if the command is not run in a shell. + arguments = string.IsNullOrWhiteSpace(arguments) ? + $"/c {command}" : + $"/c {command} {arguments}"; + command = "cmd.exe"; + } + return new StdioClientTransport(new StdioClientTransportOptions { - Command = GetCommand(config), - Arguments = config.TransportOptions?.GetValueOrDefault(ARGUMENTS_OPTIONS_KEY)?.Split(' '), - WorkingDirectory = config.TransportOptions?.GetValueOrDefault("workingDirectory"), - EnvironmentVariables = config.TransportOptions? + Command = command!, + Arguments = arguments, + WorkingDirectory = serverConfig.TransportOptions?.GetValueOrDefault("workingDirectory"), + EnvironmentVariables = serverConfig.TransportOptions? .Where(kv => kv.Key.StartsWith("env:", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring(4), kv => kv.Value), - ShutdownTimeout = TimeSpan.TryParse(config.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout - }, config, _loggerFactory); + .ToDictionary(kv => kv.Key.Substring("env:".Length), kv => kv.Value), + ShutdownTimeout = TimeSpan.TryParse(serverConfig.TransportOptions?.GetValueOrDefault("shutdownTimeout"), CultureInfo.InvariantCulture, out var timespan) ? timespan : StdioClientTransportOptions.DefaultShutdownTimeout + }, serverConfig, loggerFactory); } - if (string.Equals(config.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) || - string.Equals(config.TransportType, "http", StringComparison.OrdinalIgnoreCase)) + if (string.Equals(serverConfig.TransportType, TransportTypes.Sse, StringComparison.OrdinalIgnoreCase) || + string.Equals(serverConfig.TransportType, "http", StringComparison.OrdinalIgnoreCase)) { - return new SseClientTransport( - new SseClientTransportOptions - { - ConnectionTimeout = TimeSpan.FromSeconds(ParseOrDefault(config.TransportOptions, "connectionTimeout", 30)), - MaxReconnectAttempts = ParseOrDefault(config.TransportOptions, "maxReconnectAttempts", 3), - ReconnectDelay = TimeSpan.FromSeconds(ParseOrDefault(config.TransportOptions, "reconnectDelay", 5)), - AdditionalHeaders = config.TransportOptions? - .Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal)) - .ToDictionary(kv => kv.Key.Substring(7), kv => kv.Value) - }, config, _loggerFactory); - } - - throw new ArgumentException($"Unsupported transport type '{config.TransportType}'.", nameof(config)); - } - - private static int ParseOrDefault(Dictionary? options, string key, int defaultValue) - { - if (options?.TryGetValue(key, out var value) ?? false) - { - if (!int.TryParse(value, out var result)) - throw new FormatException($"Invalid value '{value}' for option '{key}'"); - return result; - } - return defaultValue; - } - - private static string GetCommand(McpServerConfig config) - { - var command = config.TransportOptions?.GetValueOrDefault(COMMAND_OPTIONS_KEY); - - if (string.IsNullOrEmpty(command)) - return config.Location!; - - return RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "cmd.exe" : command!; - } - - /// - /// Initializes a non-shell command by injecting a /c {command} argument, as the command will be run in a shell. - /// - private void InitializeCommand(McpServerConfig config) - { - string endpointName = $"Client ({config.Id}: {config.Name})"; - - // If the command is empty or already contains cmd.exe, we don't need to do anything - var command = config.TransportOptions?.GetValueOrDefault(COMMAND_OPTIONS_KEY); - - if (string.IsNullOrEmpty(command) || !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - { - return; - } - - // On Windows, we need to wrap non-shell commands with cmd.exe /c - if (command!.IndexOf("cmd.exe", StringComparison.OrdinalIgnoreCase) >= 0) - { - _logger.SkippingShellWrapper(endpointName); - return; - } - - // If the command is not empty and does not contain cmd.exe, we need to inject /c {command} (usually npx or uvicorn) - // This is because the stdio transport will not work correctly if the command is not run in a shell - _logger.PromotingCommandToShellArgumentForStdio(endpointName, command, config.TransportOptions!.GetValueOrDefault(ARGUMENTS_OPTIONS_KEY) ?? ""); - config.TransportOptions![ARGUMENTS_OPTIONS_KEY] = config.TransportOptions.TryGetValue(ARGUMENTS_OPTIONS_KEY, out var args) - ? $"/c {command} {args}" - : $"/c {command}"; - } - - /// - /// Disposes all clients created by the factory. - /// - /// - protected virtual void Dispose(bool disposing) - { - if (!_isDisposed) - { - if (disposing && DisposeClientsOnDispose) - DisposeClients(); - - _isDisposed = true; - } - } - - private void DisposeClients() - { - foreach (var client in _clients.Values) - { - client?.DisposeAsync().AsTask().Wait(); + return new SseClientTransport(new SseClientTransportOptions + { + ConnectionTimeout = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "connectionTimeout", 30)), + MaxReconnectAttempts = ParseInt32OrDefault(serverConfig.TransportOptions, "maxReconnectAttempts", 3), + ReconnectDelay = TimeSpan.FromSeconds(ParseInt32OrDefault(serverConfig.TransportOptions, "reconnectDelay", 5)), + AdditionalHeaders = serverConfig.TransportOptions? + .Where(kv => kv.Key.StartsWith("header.", StringComparison.Ordinal)) + .ToDictionary(kv => kv.Key.Substring("header.".Length), kv => kv.Value) + }, serverConfig, loggerFactory); + + static int ParseInt32OrDefault(Dictionary? options, string key, int defaultValue) => + options?.TryGetValue(key, out var value) is not true ? defaultValue : + int.TryParse(value, out var result) ? result : + throw new ArgumentException($"Invalid value '{value}' for option '{key}' in transport options.", nameof(serverConfig)); } - _clients.Clear(); - } - - /// - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); + throw new ArgumentException($"Unsupported transport type '{serverConfig.TransportType}'.", nameof(serverConfig)); } } \ No newline at end of file diff --git a/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs b/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs index 5559d71..5bcea8d 100644 --- a/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs +++ b/src/mcpdotnet/Configuration/DefaultMcpServerBuilder.cs @@ -1,4 +1,5 @@ -using Microsoft.Extensions.DependencyInjection; +using McpDotNet.Utils; +using Microsoft.Extensions.DependencyInjection; namespace McpDotNet.Configuration; @@ -17,6 +18,8 @@ internal class DefaultMcpServerBuilder : IMcpServerBuilder /// public DefaultMcpServerBuilder(IServiceCollection services) { - Services = services ?? throw new ArgumentNullException(nameof(services)); + Throw.IfNull(services); + + Services = services; } } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs index fe0b955..9ffcf9a 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Handler.cs @@ -1,6 +1,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Types; using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -17,12 +18,9 @@ public static partial class McpServerBuilderExtensions /// The handler. public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListToolsHandler = handler); + builder.Services.Configure(s => s.ListToolsHandler = handler); return builder; } @@ -33,12 +31,9 @@ public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder buil /// The handler. public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.CallToolHandler = handler); + builder.Services.Configure(s => s.CallToolHandler = handler); return builder; } @@ -49,12 +44,9 @@ public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder build /// The handler. public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListPromptsHandler = handler); + builder.Services.Configure(s => s.ListPromptsHandler = handler); return builder; } @@ -65,12 +57,9 @@ public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder bu /// The handler. public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.GetPromptHandler = handler); + builder.Services.Configure(s => s.GetPromptHandler = handler); return builder; } @@ -81,12 +70,9 @@ public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder buil /// The handler. public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ListResourcesHandler = handler); + builder.Services.Configure(s => s.ListResourcesHandler = handler); return builder; } @@ -97,28 +83,22 @@ public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder /// The handler. public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.ReadResourceHandler = handler); + builder.Services.Configure(s => s.ReadResourceHandler = handler); return builder; } /// - /// Sets the handler for get resources requests. + /// Sets the handler for get completion requests. /// /// The builder instance. /// The handler. public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.GetCompletionHandler = handler); + builder.Services.Configure(s => s.GetCompletionHandler = handler); return builder; } @@ -129,12 +109,9 @@ public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder /// The handler. public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); + builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); return builder; } @@ -145,12 +122,9 @@ public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerB /// The handler. public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); - builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); + builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); return builder; } } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs index bbc030f..9b4b5a0 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Tools.cs @@ -6,6 +6,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Types; using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -40,15 +41,14 @@ public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder) /// Types with marked methods to add as tools to the server. public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params Type[] toolTypes) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); if (toolTypes is null || toolTypes.Length == 0) + { throw new ArgumentException("At least one tool type must be provided.", nameof(toolTypes)); + } - var tools = new List(); + List tools = []; Dictionary, CancellationToken, Task>> callbacks = []; foreach (var type in toolTypes) @@ -114,7 +114,9 @@ public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder bui } if (toolTypes.Count == 0) + { throw new ArgumentException("No types with marked methods found in the assembly.", nameof(assembly)); + } return WithTools(builder, [.. toolTypes]); } diff --git a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs index 6a1a165..2eeae4e 100644 --- a/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs +++ b/src/mcpdotnet/Configuration/McpServerBuilderExtensions.Transports.cs @@ -1,5 +1,6 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.DependencyInjection; namespace McpDotNet; @@ -15,10 +16,7 @@ public static partial class McpServerBuilderExtensions /// The builder instance. public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder) { - if (builder is null) - { - throw new ArgumentNullException(nameof(builder)); - } + Throw.IfNull(builder); builder.Services.AddSingleton(); return builder; diff --git a/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs b/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs index 7f9e8a5..79b1776 100644 --- a/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs +++ b/src/mcpdotnet/Configuration/McpServerServiceCollectionExtension.cs @@ -1,10 +1,12 @@ -using System.Reflection; -using McpDotNet.Configuration; +using McpDotNet.Configuration; using McpDotNet.Hosting; using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Server; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using System.Reflection; namespace McpDotNet; @@ -36,11 +38,21 @@ public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, A public static IMcpServerBuilder AddMcpServer(this IServiceCollection services, McpServerOptions serverOptions) { services.AddSingleton(serverOptions); - services.AddSingleton(); services.AddHostedService(); services.AddOptions(); + services.AddSingleton(services => + { + IServerTransport serverTransport = services.GetRequiredService(); + McpServerOptions options = services.GetRequiredService(); + ILoggerFactory? loggerFactory = services.GetService(); + + if (services.GetService>() is { } handlersOptions) + { + options = handlersOptions.Value.OverwriteWithSetHandlers(options); + } - services.AddSingleton(sp => sp.GetRequiredService().CreateServer()); + return McpServerFactory.Create(serverTransport, options, loggerFactory, services); + }); return new DefaultMcpServerBuilder(services); } diff --git a/src/mcpdotnet/Hosting/McpServerHostedService.cs b/src/mcpdotnet/Hosting/McpServerHostedService.cs index 11d4006..fcb1666 100644 --- a/src/mcpdotnet/Hosting/McpServerHostedService.cs +++ b/src/mcpdotnet/Hosting/McpServerHostedService.cs @@ -1,4 +1,5 @@ using McpDotNet.Server; +using McpDotNet.Utils; using Microsoft.Extensions.Hosting; namespace McpDotNet.Hosting; @@ -17,7 +18,9 @@ public class McpServerHostedService : BackgroundService /// public McpServerHostedService(IMcpServer server) { - _server = server ?? throw new ArgumentNullException(nameof(server)); + Throw.IfNull(server); + + _server = server; } /// diff --git a/src/mcpdotnet/Logging/Log.cs b/src/mcpdotnet/Logging/Log.cs index 551c5aa..4092efe 100644 --- a/src/mcpdotnet/Logging/Log.cs +++ b/src/mcpdotnet/Logging/Log.cs @@ -8,60 +8,15 @@ namespace McpDotNet.Logging; /// internal static partial class Log { - [LoggerMessage(Level = LogLevel.Debug, Message = "Client {clientId} initializing connection to server {serverId}")] - internal static partial void ClientConnecting(this ILogger logger, string clientId, string serverId); - [LoggerMessage(Level = LogLevel.Information, Message = "Server {endpointName} capabilities received: {capabilities}, server info: {serverInfo}")] internal static partial void ServerCapabilitiesReceived(this ILogger logger, string endpointName, string capabilities, string serverInfo); - [LoggerMessage(Level = LogLevel.Warning, Message = "Request {requestId} timed out")] - internal static partial void RequestTimeout(this ILogger logger, int requestId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Command for {endpointName} already contains shell wrapper, skipping argument injection")] - internal static partial void SkippingShellWrapper(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Server connection config for Id={serverId} not found")] - internal static partial void ServerNotFound(this ILogger logger, string serverId); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Client for {endpointName} already created, returning cached client")] - internal static partial void ClientExists(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Creating client for {endpointName}")] internal static partial void CreatingClient(this ILogger logger, string endpointName); [LoggerMessage(Level = LogLevel.Information, Message = "Client for {endpointName} created and connected")] internal static partial void ClientCreated(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Creating transport for {endpointName} with type {transportType} and options {options}")] - internal static partial void CreatingTransport(this ILogger logger, string endpointName, string transportType, string options); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Promoting command for {endpointName} to shell argument for stdio transport with command {command} and arguments {arguments}")] - internal static partial void PromotingCommandToShellArgumentForStdio(this ILogger logger, string endpointName, string command, string arguments); - - [LoggerMessage(Level = LogLevel.Debug, Message = "Initializing stdio commands")] - internal static partial void InitializingStdioCommands(this ILogger logger); - - [LoggerMessage(Level = LogLevel.Error, Message = "{handlerName} handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void HandlerNotConfigured(this ILogger logger, string handlerName, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List tools handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListToolsHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Call tool handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void CallToolHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List prompts handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListPromptsHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Get prompt handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void GetPromptHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "List resources handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ListResourcesHandlerNotConfigured(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Error, Message = "Read resource handler not configured for server {endpointName}, always set a handler when using this capability")] - internal static partial void ReadResourceHandlerNotConfigured(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Error, Message = "Client server {endpointName} already initializing")] internal static partial void ClientAlreadyInitializing(this ILogger logger, string endpointName); @@ -77,33 +32,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Error, Message = "Client {endpointName} initialization timeout")] internal static partial void ClientInitializationTimeout(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Information, Message = "Pinging sent ({endpointName})")] - internal static partial void PingingServer(this ILogger logger, string endpointName); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing tools for {endpointName} with cursor {cursor}")] - internal static partial void ListingTools(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing prompts for {endpointName} with cursor {cursor}")] - internal static partial void ListingPrompts(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Getting prompt {name} for {endpointName} with arguments {arguments}")] - internal static partial void GettingPrompt(this ILogger logger, string endpointName, string name, string arguments); - - [LoggerMessage(Level = LogLevel.Information, Message = "Listing resources for {endpointName} with cursor {cursor}")] - internal static partial void ListingResources(this ILogger logger, string endpointName, string cursor); - - [LoggerMessage(Level = LogLevel.Information, Message = "Reading resource {uri} for {endpointName}")] - internal static partial void ReadingResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Subscribing to resource {uri} for {endpointName}")] - internal static partial void SubscribingToResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Unsubscribing from resource {uri} for {endpointName}")] - internal static partial void UnsubscribingFromResource(this ILogger logger, string endpointName, string uri); - - [LoggerMessage(Level = LogLevel.Information, Message = "Calling tool {toolName} for {endpointName} with arguments {arguments}")] - internal static partial void CallingTool(this ILogger logger, string endpointName, string toolName, string arguments); - [LoggerMessage(Level = LogLevel.Information, Message = "Endpoint message processing cancelled for {endpointName}")] internal static partial void EndpointMessageProcessingCancelled(this ILogger logger, string endpointName); @@ -152,9 +80,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Error, Message = "Request invalid response type for {endpointName} with method {method}")] internal static partial void RequestInvalidResponseType(this ILogger logger, string endpointName, string method); - [LoggerMessage(Level = LogLevel.Error, Message = "Request params type conversion error for {endpointName} with method {method}: expected {expectedType}")] - internal static partial void RequestParamsTypeConversionError(this ILogger logger, string endpointName, string method, Type expectedType); - [LoggerMessage(Level = LogLevel.Information, Message = "Cleaning up endpoint {endpointName}")] internal static partial void CleaningUpEndpoint(this ILogger logger, string endpointName); @@ -168,7 +93,7 @@ internal static partial class Log internal static partial void TransportConnecting(this ILogger logger, string endpointName); [LoggerMessage(Level = LogLevel.Information, Message = "Creating process for transport for {endpointName} with command {command}, arguments {arguments}, environment {environment}, working directory {workingDirectory}, shutdown timeout {shutdownTimeout}")] - internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string arguments, string environment, string workingDirectory, string shutdownTimeout); + internal static partial void CreateProcessForTransport(this ILogger logger, string endpointName, string command, string? arguments, string environment, string workingDirectory, string shutdownTimeout); [LoggerMessage(Level = LogLevel.Error, Message = "Transport for {endpointName} error: {data}")] internal static partial void TransportError(this ILogger logger, string endpointName, string data); @@ -239,9 +164,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Debug, Message = "Transport waiting for shutdown for {endpointName}")] internal static partial void TransportWaitingForShutdown(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Warning, Message = "Transport killing process for {endpointName}")] - internal static partial void TransportKillingProcess(this ILogger logger, string endpointName); - [LoggerMessage(Level = LogLevel.Error, Message = "Transport shutdown failed for {endpointName}")] internal static partial void TransportShutdownFailed(this ILogger logger, string endpointName, Exception exception); @@ -266,9 +188,6 @@ internal static partial class Log [LoggerMessage(Level = LogLevel.Debug, Message = "Sending message to {endpointName}: {message}")] internal static partial void SendingMessage(this ILogger logger, string endpointName, string message); - [LoggerMessage(Level = LogLevel.Information, Message = "Sending notification for {endpointName}: {method}")] - internal static partial void SendingNotification(this ILogger logger, string endpointName, string method); - [LoggerMessage( EventId = 7000, Level = LogLevel.Error, @@ -310,50 +229,6 @@ public static partial void TransportEndpointEventParseFailed( string data, Exception exception); - [LoggerMessage( - EventId = 7004, - Level = LogLevel.Trace, - Message = "Invalid completion reference {reference} for {endpointName}: {validationMessage}" - )] - public static partial void InvalidCompletionReference( - this ILogger logger, - string endpointName, - string reference, - string validationMessage); - - [LoggerMessage( - EventId = 7005, - Level = LogLevel.Trace, - Message = "Invalid completion argument name {argumentName} for {endpointName}" - )] - public static partial void InvalidCompletionArgumentName( - this ILogger logger, - string endpointName, - string argumentName); - - [LoggerMessage( - EventId = 7006, - Level = LogLevel.Trace, - Message = "Invalid completion argument value {argumentValue} for {argumentName} for {endpointName}" - )] - public static partial void InvalidCompletionArgumentValue( - this ILogger logger, - string endpointName, - string argumentValue, - string argumentName); - - [LoggerMessage( - EventId = 7007, - Level = LogLevel.Debug, - Message = "Getting completion for {endpointName} with reference {reference}, argument name {argumentName}, argument value {argumentValue}" - )] - public static partial void GettingCompletion( - this ILogger logger, - string endpointName, - string reference, - string argumentName, - string argumentValue); - [LoggerMessage( EventId = 7008, Level = LogLevel.Error, diff --git a/src/mcpdotnet/Protocol/Messages/OperationNames.cs b/src/mcpdotnet/Protocol/Messages/OperationNames.cs index f8f3bcc..903256c 100644 --- a/src/mcpdotnet/Protocol/Messages/OperationNames.cs +++ b/src/mcpdotnet/Protocol/Messages/OperationNames.cs @@ -1,8 +1,4 @@ -using McpDotNet.Protocol.Messages; -using McpDotNet.Protocol.Types; -using Microsoft.Extensions.Logging; - -namespace McpDotNet.Protocol.Messages; +namespace McpDotNet.Protocol.Messages; /// Provides names of standard operations for use with registering handlers. /// diff --git a/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs b/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs index 4599e95..33b6d9e 100644 --- a/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs +++ b/src/mcpdotnet/Protocol/Messages/RequestIdConverter.cs @@ -1,4 +1,5 @@ -using System.Text.Json; +using McpDotNet.Utils; +using System.Text.Json; using System.Text.Json.Serialization; namespace McpDotNet.Protocol.Messages; @@ -22,14 +23,15 @@ public override RequestId Read(ref Utf8JsonReader reader, Type typeToConvert, Js /// public override void Write(Utf8JsonWriter writer, RequestId value, JsonSerializerOptions options) { - if (writer is null) - { - throw new ArgumentNullException(nameof(writer)); - } + Throw.IfNull(writer); if (value.IsString) + { writer.WriteStringValue(value.AsString); + } else + { writer.WriteNumberValue(value.AsNumber); + } } } \ No newline at end of file diff --git a/src/mcpdotnet/Protocol/Transport/PasteArguments.cs b/src/mcpdotnet/Protocol/Transport/PasteArguments.cs deleted file mode 100644 index be4750c..0000000 --- a/src/mcpdotnet/Protocol/Transport/PasteArguments.cs +++ /dev/null @@ -1,102 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -// Copied from: -// https://github.com/dotnet/runtime/blob/d2650b6ae7023a2d9d2c74c56116f1f18472ab04/src/libraries/System.Private.CoreLib/src/System/PasteArguments.cs -// and changed from using ValueStringBuilder to StringBuilder. - -using System.Text; - -namespace McpDotNet.Protocol.Transport -{ - internal static partial class PasteArguments - { - internal static void AppendArgument(StringBuilder stringBuilder, string argument) - { - if (stringBuilder.Length != 0) - { - stringBuilder.Append(' '); - } - - // Parsing rules for non-argv[0] arguments: - // - Backslash is a normal character except followed by a quote. - // - 2N backslashes followed by a quote ==> N literal backslashes followed by unescaped quote - // - 2N+1 backslashes followed by a quote ==> N literal backslashes followed by a literal quote - // - Parsing stops at first whitespace outside of quoted region. - // - (post 2008 rule): A closing quote followed by another quote ==> literal quote, and parsing remains in quoting mode. - if (argument.Length != 0 && ContainsNoWhitespaceOrQuotes(argument)) - { - // Simple case - no quoting or changes needed. - stringBuilder.Append(argument); - } - else - { - stringBuilder.Append(Quote); - int idx = 0; - while (idx < argument.Length) - { - char c = argument[idx++]; - if (c == Backslash) - { - int numBackSlash = 1; - while (idx < argument.Length && argument[idx] == Backslash) - { - idx++; - numBackSlash++; - } - - if (idx == argument.Length) - { - // We'll emit an end quote after this so must double the number of backslashes. - stringBuilder.Append(Backslash, numBackSlash * 2); - } - else if (argument[idx] == Quote) - { - // Backslashes will be followed by a quote. Must double the number of backslashes. - stringBuilder.Append(Backslash, numBackSlash * 2 + 1); - stringBuilder.Append(Quote); - idx++; - } - else - { - // Backslash will not be followed by a quote, so emit as normal characters. - stringBuilder.Append(Backslash, numBackSlash); - } - - continue; - } - - if (c == Quote) - { - // Escape the quote so it appears as a literal. This also guarantees that we won't end up generating a closing quote followed - // by another quote (which parses differently pre-2008 vs. post-2008.) - stringBuilder.Append(Backslash); - stringBuilder.Append(Quote); - continue; - } - - stringBuilder.Append(c); - } - - stringBuilder.Append(Quote); - } - } - - private static bool ContainsNoWhitespaceOrQuotes(string s) - { - for (int i = 0; i < s.Length; i++) - { - char c = s[i]; - if (char.IsWhiteSpace(c) || c == Quote) - { - return false; - } - } - - return true; - } - - private const char Quote = '\"'; - private const char Backslash = '\\'; - } -} \ No newline at end of file diff --git a/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs b/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs index 9cc4c7a..94e9d28 100644 --- a/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/SseClientTransport.cs @@ -5,6 +5,7 @@ using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Messages; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -54,27 +55,16 @@ public SseClientTransport(SseClientTransportOptions transportOptions, McpServerC public SseClientTransport(SseClientTransportOptions transportOptions, McpServerConfig serverConfig, HttpClient httpClient, ILoggerFactory? loggerFactory, bool ownsHttpClient = false) : base(loggerFactory) { - if (transportOptions is null) - { - throw new ArgumentNullException(nameof(transportOptions)); - } - - if (serverConfig is null) - { - throw new ArgumentNullException(nameof(serverConfig)); - } - - if (httpClient is null) - { - throw new ArgumentNullException(nameof(httpClient)); - } + Throw.IfNull(transportOptions); + Throw.IfNull(serverConfig); + Throw.IfNull(httpClient); _options = transportOptions; _serverConfig = serverConfig; _sseEndpoint = new Uri(serverConfig.Location!); _httpClient = httpClient; _connectionCts = new CancellationTokenSource(); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; _connectionEstablished = new TaskCompletionSource(); _ownsHttpClient = ownsHttpClient; diff --git a/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs b/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs index cb705c2..d8613d5 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioClientTransport.cs @@ -1,6 +1,4 @@ using System.Diagnostics; -using System.Runtime.InteropServices; -using System.Text; using System.Text.Json; using McpDotNet.Configuration; using McpDotNet.Logging; @@ -10,6 +8,8 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +#pragma warning disable CA2213 // Disposable fields should be disposed + namespace McpDotNet.Protocol.Transport; /// @@ -37,15 +37,8 @@ public sealed class StdioClientTransport : TransportBase, IClientTransport public StdioClientTransport(StdioClientTransportOptions options, McpServerConfig serverConfig, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { - if (options is null) - { - throw new ArgumentNullException(nameof(options)); - } - - if (serverConfig is null) - { - throw new ArgumentNullException(nameof(serverConfig)); - } + Throw.IfNull(options); + Throw.IfNull(serverConfig); _options = options; _serverConfig = serverConfig; @@ -71,23 +64,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) var startInfo = new ProcessStartInfo { FileName = _options.Command, - UseShellExecute = false, RedirectStandardInput = true, RedirectStandardOutput = true, RedirectStandardError = true, - CreateNoWindow = RuntimeInformation.IsOSPlatform(OSPlatform.Windows), + UseShellExecute = false, + CreateNoWindow = true, WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory, }; - if (_options.Arguments is { Length: > 0 }) + if (!string.IsNullOrWhiteSpace(_options.Arguments)) { - StringBuilder argsBuilder = new(); - foreach (var arg in _options.Arguments) - { - PasteArguments.AppendArgument(argsBuilder, arg); - } - - startInfo.Arguments = argsBuilder.ToString(); + startInfo.Arguments = _options.Arguments; } if (_options.EnvironmentVariables != null) @@ -105,10 +92,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) _process = new Process { StartInfo = startInfo }; // Set up error logging - _process.ErrorDataReceived += (sender, args) => - { - _logger.TransportError(EndpointName, args.Data ?? "(no data)"); - }; + _process.ErrorDataReceived += (sender, args) => _logger.TransportError(EndpointName, args.Data ?? "(no data)"); if (!_process.Start()) { @@ -120,7 +104,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default) _process.BeginErrorReadLine(); // Start reading messages in the background - _readTask = Task.Run(async () => await ReadMessagesAsync(_shutdownCts.Token).ConfigureAwait(false), CancellationToken.None); + _readTask = Task.Run(() => ReadMessagesAsync(_shutdownCts.Token), CancellationToken.None); _logger.TransportReadingMessages(EndpointName); SetConnected(true); @@ -272,10 +256,10 @@ private async Task CleanupAsync(CancellationToken cancellationToken) _process = null; } - if (_shutdownCts != null) + if (_shutdownCts is { } shutdownCts) { - await _shutdownCts.CancelAsync().ConfigureAwait(false); - _shutdownCts.Dispose(); + await shutdownCts.CancelAsync().ConfigureAwait(false); + shutdownCts.Dispose(); _shutdownCts = null; } diff --git a/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs b/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs index 35f303d..1ee77e5 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioClientTransportOptions.cs @@ -19,7 +19,7 @@ public record StdioClientTransportOptions /// /// Arguments to pass to the server process. /// - public string[]? Arguments { get; set; } = []; + public string? Arguments { get; set; } /// /// The working directory for the server process. diff --git a/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs b/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs index cc3a09f..30045ff 100644 --- a/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs +++ b/src/mcpdotnet/Protocol/Transport/StdioServerTransport.cs @@ -1,13 +1,14 @@ using System.Text.Json; -using McpDotNet.Configuration; using McpDotNet.Logging; using McpDotNet.Protocol.Messages; using McpDotNet.Server; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; #pragma warning disable CA2208 // Instantiate argument exceptions correctly +#pragma warning disable CA2213 // Disposable fields should be disposed namespace McpDotNet.Protocol.Transport; @@ -104,12 +105,15 @@ public StdioServerTransport(McpServerOptions serverOptions, TextReader input, Te public StdioServerTransport(string serverName, TextReader input, TextWriter output, ILoggerFactory? loggerFactory = null) : base(loggerFactory) { - _serverName = serverName ?? throw new ArgumentNullException(nameof(serverName)); - _input = input ?? throw new ArgumentNullException(nameof(input)); - _output = output ?? throw new ArgumentNullException(nameof(output)); + Throw.IfNull(serverName); + Throw.IfNull(input); + Throw.IfNull(output); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _serverName = serverName; + _input = input; + _output = output; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; } @@ -274,15 +278,8 @@ private async Task CleanupAsync(CancellationToken cancellationToken) /// Validates the and extracts from it the server name to use. private static string GetServerName(McpServerOptions serverOptions) { - if (serverOptions is null) - { - throw new ArgumentNullException(nameof(serverOptions)); - } - - if (serverOptions.ServerInfo is null) - { - throw new ArgumentNullException($"{nameof(serverOptions)}.{nameof(serverOptions.ServerInfo)}"); - } + Throw.IfNull(serverOptions); + Throw.IfNull(serverOptions.ServerInfo); return serverOptions.ServerInfo.Name; } diff --git a/src/mcpdotnet/Protocol/Transport/TransportBase.cs b/src/mcpdotnet/Protocol/Transport/TransportBase.cs index d55807b..b57733c 100644 --- a/src/mcpdotnet/Protocol/Transport/TransportBase.cs +++ b/src/mcpdotnet/Protocol/Transport/TransportBase.cs @@ -26,7 +26,7 @@ protected TransportBase(ILoggerFactory? loggerFactory) SingleReader = true, SingleWriter = true, }); - _logger = loggerFactory is not null ? loggerFactory.CreateLogger() : NullLogger.Instance; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; } /// diff --git a/src/mcpdotnet/Protocol/Types/Capabilities.cs b/src/mcpdotnet/Protocol/Types/Capabilities.cs index 20a5e13..b623f29 100644 --- a/src/mcpdotnet/Protocol/Types/Capabilities.cs +++ b/src/mcpdotnet/Protocol/Types/Capabilities.cs @@ -1,6 +1,8 @@ -using System.Text.Json.Serialization; +using McpDotNet.Server; +using System.Text.Json.Serialization; namespace McpDotNet.Protocol.Types; + /// /// Represents the capabilities that a client may support. /// See the schema for details @@ -37,6 +39,10 @@ public record RootsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// Gets or sets the handler for sampling requests. + [JsonIgnore] + public Func>? RootsHandler { get; init; } } /// @@ -46,6 +52,10 @@ public record RootsCapability public record SamplingCapability { // Currently empty in the spec, but may be extended in the future + + /// Gets or sets the handler for sampling requests. + [JsonIgnore] + public Func>? SamplingHandler { get; init; } } /// @@ -68,6 +78,18 @@ public record PromptsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list prompts requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListPromptsHandler { get; init; } + + /// + /// Gets or sets the handler for get prompt requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? GetPromptHandler { get; init; } } /// @@ -87,6 +109,30 @@ public record ResourcesCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list resources requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListResourcesHandler { get; init; } + + /// + /// Gets or sets the handler for read resources requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ReadResourceHandler { get; init; } + + /// + /// Gets or sets the handler for subscribe to resources messages. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; init; } + + /// + /// Gets or sets the handler for unsubscribe from resources messages. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; init; } } /// @@ -100,4 +146,16 @@ public record ToolsCapability /// [JsonPropertyName("listChanged")] public bool? ListChanged { get; init; } + + /// + /// Gets or sets the handler for list tools requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? ListToolsHandler { get; init; } + + /// + /// Gets or sets the handler for call tool requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? CallToolHandler { get; init; } } \ No newline at end of file diff --git a/src/mcpdotnet/Server/IMcpServer.cs b/src/mcpdotnet/Server/IMcpServer.cs index 8672701..024b883 100644 --- a/src/mcpdotnet/Server/IMcpServer.cs +++ b/src/mcpdotnet/Server/IMcpServer.cs @@ -28,20 +28,6 @@ public interface IMcpServer : IAsyncDisposable /// IServiceProvider? ServiceProvider { get; } - /// Sets a handler for the named operation. - /// The name of the operation. - /// The handler. Each operation requires a specific delegate signature. - /// - /// - /// Each operation may have only a single handler. Setting a handler for an operation that already has one - /// will replace the existing handler. - /// - /// - /// provides constants for common operations. - /// - /// - void SetOperationHandler(string operationName, Delegate handler); - /// /// Adds a handler for client notifications of a specific method. /// diff --git a/src/mcpdotnet/Server/IMcpServerFactory.cs b/src/mcpdotnet/Server/IMcpServerFactory.cs deleted file mode 100644 index a8c1e70..0000000 --- a/src/mcpdotnet/Server/IMcpServerFactory.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace McpDotNet.Server; - -/// -/// Factory for creating instances. -/// -public interface IMcpServerFactory -{ - /// - /// Creates a new server instance. - /// - /// - IMcpServer CreateServer(); -} diff --git a/src/mcpdotnet/Server/McpServer.cs b/src/mcpdotnet/Server/McpServer.cs index 2c525c4..ddd2be6 100644 --- a/src/mcpdotnet/Server/McpServer.cs +++ b/src/mcpdotnet/Server/McpServer.cs @@ -1,11 +1,11 @@ - -using System.Text.Json.Nodes; +using System.Text.Json.Nodes; using McpDotNet.Logging; -using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Shared; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Server; @@ -15,17 +15,7 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer private readonly IServerTransport _serverTransport; private readonly McpServerOptions _options; private volatile bool _isInitializing; - private readonly ILogger _logger; - - private Func, CancellationToken, Task>? _listToolsHandler; - private Func, CancellationToken, Task>? _callToolHandler; - private Func, CancellationToken, Task>? _listPromptsHandler; - private Func, CancellationToken, Task>? _getPromptHandler; - private Func, CancellationToken, Task>? _listResourcesHandler; - private Func, CancellationToken, Task>? _readResourceHandler; - private Func, CancellationToken, Task>? _getCompletionHandler; - private Func, CancellationToken, Task>? _subscribeToResourcesHandler; - private Func, CancellationToken, Task>? _unsubscribeFromResourcesHandler; + private readonly ILogger _logger; /// /// Creates a new instance of . @@ -36,22 +26,29 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer /// Logger factory to use for logging /// Optional service provider to use for dependency injection /// - public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFactory loggerFactory, IServiceProvider? serviceProvider) + public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFactory? loggerFactory, IServiceProvider? serviceProvider) : base(transport, loggerFactory) { + Throw.IfNull(options); + _serverTransport = transport; - _options = options ?? throw new ArgumentNullException(nameof(options)); - _logger = loggerFactory.CreateLogger(); + _options = options; + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; ServerInstructions = options.ServerInstructions; ServiceProvider = serviceProvider; + AddNotificationHandler("notifications/initialized", _ => + { + IsInitialized = true; + return Task.CompletedTask; + }); + + SetInitializeHandler(options); + SetCompletionHandler(options); + SetPingHandler(); SetToolsHandler(options); SetPromptsHandler(options); SetResourcesHandler(options); - SetCompletionHandler(); - SetInitializeHandler(options); - SetPingHandler(); - AddNotificationHandler(); } public ClientCapabilities? ClientCapabilities { get; set; } @@ -69,54 +66,6 @@ public McpServer(IServerTransport transport, McpServerOptions options, ILoggerFa public override string EndpointName => $"Server ({_options.ServerInfo.Name} {_options.ServerInfo.Version}), Client ({ClientInfo?.Name} {ClientInfo?.Version})"; - public void SetOperationHandler(string operationName, Delegate handler) - { - if (operationName is null) - { - throw new ArgumentNullException(nameof(operationName)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - if (!TrySetOperationHandler(OperationNames.ListTools, operationName, handler, ref _listToolsHandler) && - !TrySetOperationHandler(OperationNames.CallTool, operationName, handler, ref _callToolHandler) && - !TrySetOperationHandler(OperationNames.ListPrompts, operationName, handler, ref _listPromptsHandler) && - !TrySetOperationHandler(OperationNames.GetPrompt, operationName, handler, ref _getPromptHandler) && - !TrySetOperationHandler(OperationNames.ListResources, operationName, handler, ref _listResourcesHandler) && - !TrySetOperationHandler(OperationNames.ReadResource, operationName, handler, ref _readResourceHandler) && - !TrySetOperationHandler(OperationNames.GetCompletion, operationName, handler, ref _getCompletionHandler) && - !TrySetOperationHandler(OperationNames.SubscribeToResources, operationName, handler, ref _subscribeToResourcesHandler) && - !TrySetOperationHandler(OperationNames.UnsubscribeFromResources, operationName, handler, ref _unsubscribeFromResourcesHandler)) - { - throw new ArgumentException($"Unknown operation '{operationName}'", nameof(operationName)); - } - - static bool TrySetOperationHandler( - string targetOperationName, - string operationName, - Delegate handler, - ref Func? field) - { - if (operationName == targetOperationName) - { - if (handler is Func typed) - { - field = typed; - return true; - } - - throw new ArgumentException( - $"Handler must be of type {typeof(Func)}", - nameof(handler)); - } - - return false; - } - } - /// public async Task StartAsync(CancellationToken cancellationToken = default) { @@ -153,15 +102,6 @@ public async Task StartAsync(CancellationToken cancellationToken = default) } } - private void AddNotificationHandler() - { - AddNotificationHandler("notifications/initialized", (notification) => - { - IsInitialized = true; - return Task.CompletedTask; - }); - } - private void SetPingHandler() { SetRequestHandler("ping", @@ -171,136 +111,94 @@ private void SetPingHandler() private void SetInitializeHandler(McpServerOptions options) { SetRequestHandler("initialize", - (request) => - { - ClientCapabilities = request?.Capabilities ?? new(); - ClientInfo = request?.ClientInfo; - return Task.FromResult(new InitializeResult() - { - ProtocolVersion = options.ProtocolVersion, - Instructions = ServerInstructions, - ServerInfo = _options.ServerInfo, - Capabilities = options.Capabilities ?? new ServerCapabilities(), - }); - }); + request => + { + ClientCapabilities = request?.Capabilities ?? new(); + ClientInfo = request?.ClientInfo; + return Task.FromResult(new InitializeResult() + { + ProtocolVersion = options.ProtocolVersion, + Instructions = ServerInstructions, + ServerInfo = _options.ServerInfo, + Capabilities = options.Capabilities ?? new ServerCapabilities(), + }); + }); } - private void SetCompletionHandler() + private void SetCompletionHandler(McpServerOptions options) { + // This capability is not optional, so return an empty result if there is no handler. SetRequestHandler("completion/complete", - async (request) => - { - if (_getCompletionHandler is null) - { - // This capability is not optional, so return an empty result if there is no handler - return new CompleteResult() - { - Completion = new() - { - Values = [], - Total = 0, - HasMore = false - } - }; - } - - return await _getCompletionHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + options.GetCompletionHandler is { } handler ? + request => handler(new(this, request), CancellationTokenSource?.Token ?? default) : + request => Task.FromResult(new CompleteResult() { Completion = new() { Values = [], Total = 0, HasMore = false } })); } private void SetResourcesHandler(McpServerOptions options) { - if (options.Capabilities?.Resources is not null) + if (options.Capabilities?.Resources is not { } resourcesCapability) { - SetRequestHandler("resources/list", - async (request) => - { - if (_listResourcesHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListResourcesHandlerNotConfigured(EndpointName); - throw new McpServerException("ListResources handler not configured"); - } - - return await _listResourcesHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("resources/read", - async (request) => - { - if (_readResourceHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ReadResourceHandlerNotConfigured(EndpointName); - throw new McpServerException("ReadResource handler not configured"); - } - - return await _readResourceHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (resourcesCapability.ListResourcesHandler is not { } listResourcesHandler || + resourcesCapability.ReadResourceHandler is not { } readResourceHandler) + { + throw new McpServerException("Resources capability was enabled, but ListResources and/or ReadResource handlers were not specified."); + } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("resources/list", request => listResourcesHandler(new(this, request), cancellationToken)); + SetRequestHandler("resources/read", request => readResourceHandler(new(this, request), cancellationToken)); + + if (resourcesCapability.Subscribe is not true) + { + return; + } + + var subscribeHandler = resourcesCapability.SubscribeToResourcesHandler; + var unsubscribeHandler = resourcesCapability.UnsubscribeFromResourcesHandler; + if (subscribeHandler is null || unsubscribeHandler is null) + { + throw new McpServerException("Resources capability was enabled with subscribe support, but SubscribeToResources and/or UnsubscribeFromResources handlers were not specified."); } + + // TODO: Implement Subscribe support } private void SetPromptsHandler(McpServerOptions options) { - if (options.Capabilities?.Prompts is not null) + if (options.Capabilities?.Prompts is not { } promptsCapability) { - SetRequestHandler("prompts/list", - async (request) => - { - if (_listPromptsHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListPromptsHandlerNotConfigured(EndpointName); - throw new McpServerException("ListPrompts handler not configured"); - } - - return await _listPromptsHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("prompts/get", - async (request) => - { - if (_getPromptHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.GetPromptHandlerNotConfigured(EndpointName); - throw new McpServerException("GetPrompt handler not configured"); - } - - return await _getPromptHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (promptsCapability.ListPromptsHandler is not { } listPromptsHandler || + promptsCapability.GetPromptHandler is not { } getPromptHandler) + { + throw new McpServerException("Prompts capability was enabled, but ListPrompts and/or GetPrompt handlers were not specified."); } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("prompts/list", request => listPromptsHandler(new(this, request), cancellationToken)); + SetRequestHandler("prompts/get", request => getPromptHandler(new(this, request), cancellationToken)); } private void SetToolsHandler(McpServerOptions options) { - if (options.Capabilities?.Tools is not null) + if (options.Capabilities?.Tools is not { } toolsCapability) { - SetRequestHandler("tools/list", - async (request) => - { - if (_listToolsHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.ListToolsHandlerNotConfigured(EndpointName); - throw new McpServerException("ListTools handler not configured"); - } - - return await _listToolsHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + return; + } - SetRequestHandler("tools/call", - async (request) => - { - if (_callToolHandler is null) - { - // Setting the capability, but not a handler means we have nothing to return to the server - _logger.CallToolHandlerNotConfigured(EndpointName); - throw new McpServerException("CallTool handler not configured"); - } - - return await _callToolHandler(new(this, request), CancellationTokenSource?.Token ?? CancellationToken.None).ConfigureAwait(false); - }); + if (toolsCapability.ListToolsHandler is not { } listToolsHandler || + toolsCapability.CallToolHandler is not { } callToolHandler) + { + throw new McpServerException("ListTools and/or CallTool handlers were specified but the Tools capability was not enabled."); } + + CancellationToken cancellationToken = CancellationTokenSource?.Token ?? default; + SetRequestHandler("tools/list", request => listToolsHandler(new(this, request), cancellationToken)); + SetRequestHandler("tools/call", request => callToolHandler(new(this, request), cancellationToken)); } } diff --git a/src/mcpdotnet/Server/McpServerDelegates.cs b/src/mcpdotnet/Server/McpServerDelegates.cs deleted file mode 100644 index 419eabd..0000000 --- a/src/mcpdotnet/Server/McpServerDelegates.cs +++ /dev/null @@ -1,93 +0,0 @@ -using McpDotNet.Protocol.Types; - -namespace McpDotNet.Server; - -/// -/// Container for delegates that can be applied to an MCP server. -/// -public class McpServerDelegates -{ - /// - /// Gets or sets the handler for list tools requests. - /// - public Func, CancellationToken, Task>? ListToolsHandler { get; set; } - - /// - /// Gets or sets the handler for call tool requests. - /// - public Func, CancellationToken, Task>? CallToolHandler { get; set; } - - /// - /// Gets or sets the handler for list prompts requests. - /// - public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } - - /// - /// Gets or sets the handler for get prompt requests. - /// - public Func, CancellationToken, Task>? GetPromptHandler { get; set; } - - /// - /// Gets or sets the handler for list resources requests. - /// - public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } - - /// - /// Gets or sets the handler for read resources requests. - /// - public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } - - /// - /// Gets or sets the handler for get resources requests. - /// - public Func, CancellationToken, Task>? GetCompletionHandler { get; set; } - - /// - /// Gets or sets the handler for subscribe to resources messages. - /// - public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } - - /// - /// Gets or sets the handler for subscribe to resources messages. - /// - public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } - - /// - /// Applies the delegates to the server. - /// - /// - public void Apply(IMcpServer server) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (ListToolsHandler != null) - server.SetListToolsHandler(ListToolsHandler); - - if (CallToolHandler != null) - server.SetCallToolHandler(CallToolHandler); - - if (ListPromptsHandler != null) - server.SetListPromptsHandler(ListPromptsHandler); - - if (GetPromptHandler != null) - server.SetGetPromptHandler(GetPromptHandler); - - if (ListResourcesHandler != null) - server.SetListResourcesHandler(ListResourcesHandler); - - if (ReadResourceHandler != null) - server.SetReadResourceHandler(ReadResourceHandler); - - if (GetCompletionHandler != null) - server.SetGetCompletionHandler(GetCompletionHandler); - - if (SubscribeToResourcesHandler != null) - server.SetSubscribeToResourcesHandler(SubscribeToResourcesHandler); - - if (UnsubscribeFromResourcesHandler != null) - server.SetUnsubscribeFromResourcesHandler(UnsubscribeFromResourcesHandler); - } -} diff --git a/src/mcpdotnet/Server/McpServerExtensions.cs b/src/mcpdotnet/Server/McpServerExtensions.cs index 5ed5f4c..06e9ded 100644 --- a/src/mcpdotnet/Server/McpServerExtensions.cs +++ b/src/mcpdotnet/Server/McpServerExtensions.cs @@ -1,5 +1,6 @@ using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Types; +using McpDotNet.Utils; namespace McpDotNet.Server; @@ -12,10 +13,7 @@ public static class McpServerExtensions public static Task RequestSamplingAsync( this IMcpServer server, CreateMessageRequestParams request, CancellationToken cancellationToken) { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } + Throw.IfNull(server); if (server.ClientCapabilities?.Sampling is null) { @@ -33,10 +31,7 @@ public static Task RequestSamplingAsync( public static Task RequestRootsAsync( this IMcpServer server, ListRootsRequestParams request, CancellationToken cancellationToken) { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } + Throw.IfNull(server); if (server.ClientCapabilities?.Roots is null) { @@ -47,175 +42,4 @@ public static Task RequestRootsAsync( new JsonRpcRequest { Method = "roots/list", Params = request }, cancellationToken); } - - /// - /// Sets the handler for list tools requests. - /// - public static void SetListToolsHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListTools, handler); - } - - /// - /// Sets the handler for call tool requests. - /// - public static void SetCallToolHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.CallTool, handler); - } - - /// - /// Sets the handler for list prompts requests. - /// - public static void SetListPromptsHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListPrompts, handler); - } - - /// - /// Sets the handler for get prompt requests. - /// - public static void SetGetPromptHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.GetPrompt, handler); - } - - /// - /// Sets the handler for list resources requests. - /// - public static void SetListResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ListResources, handler); - } - - /// - /// Sets the handler for read resource requests. - /// - public static void SetReadResourceHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.ReadResource, handler); - } - - /// - /// Sets the handler for get completion requests. - /// - public static void SetGetCompletionHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.GetCompletion, handler); - } - - /// - /// Sets the handler for subscribe to resources requests. - /// - public static void SetSubscribeToResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.SubscribeToResources, handler); - } - - /// - /// Sets the handler for unsubscribe from resources requests. - /// - public static void SetUnsubscribeFromResourcesHandler( - this IMcpServer server, Func, CancellationToken, Task> handler) - { - if (server is null) - { - throw new ArgumentNullException(nameof(server)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } - - server.SetOperationHandler(OperationNames.UnsubscribeFromResources, handler); - } } diff --git a/src/mcpdotnet/Server/McpServerFactory.cs b/src/mcpdotnet/Server/McpServerFactory.cs index 26e1e45..a8e37d2 100644 --- a/src/mcpdotnet/Server/McpServerFactory.cs +++ b/src/mcpdotnet/Server/McpServerFactory.cs @@ -1,62 +1,36 @@ using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Options; namespace McpDotNet.Server; /// -/// Factory for creating instances. -/// This is the main entry point for creating a server. -/// Pass the server transport, options, and logger factory to the constructor. Server instructions are optional. -/// -/// Then call CreateServer to create a new server instance. -/// You can create multiple servers with the same factory, but the transport must be able to handle multiple connections. -/// -/// You must register handlers for all supported capabilities on the server instance, before calling BeginListeningAsync. +/// Provides a factory for creating instances. /// -public class McpServerFactory : IMcpServerFactory +public static class McpServerFactory { - private readonly IServerTransport _serverTransport; - private readonly McpServerOptions _options; - private readonly ILoggerFactory _loggerFactory; - private readonly McpServerDelegates? _serverDelegates; - private readonly IServiceProvider? _serviceProvider; - /// /// Initializes a new instance of the class. /// /// Transport to use for the server - /// Configuration options for this server, including capabilities. - /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// + /// Configuration options for this server, including capabilities. + /// Make sure to accurately reflect exactly what capabilities the server supports and does not support. + /// /// Optional service provider to create new instances. /// Logger factory to use for logging - /// - public McpServerFactory( + /// An that's started and ready to receive connections. + /// is . + /// is . + public static IMcpServer Create( IServerTransport serverTransport, - McpServerOptions options, + McpServerOptions serverOptions, ILoggerFactory? loggerFactory = null, - IOptions? serverDelegates = null, IServiceProvider? serviceProvider = null) { - _serverTransport = serverTransport ?? throw new ArgumentNullException(nameof(serverTransport)); - _options = options ?? throw new ArgumentNullException(nameof(options)); - _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; - _serverDelegates = serverDelegates?.Value; - _serviceProvider = serviceProvider; - } - - /// - /// Creates a new server instance. - /// - /// NB! You must register handlers for all supported capabilities on the server instance, before calling BeginListeningAsync. - /// - public IMcpServer CreateServer() - { - var server = new McpServer(_serverTransport, _options, _loggerFactory, _serviceProvider); - - _serverDelegates?.Apply(server); + Throw.IfNull(serverTransport); + Throw.IfNull(serverOptions); - return server; + return new McpServer(serverTransport, serverOptions, loggerFactory, serviceProvider); } } diff --git a/src/mcpdotnet/Server/McpServerHandlers.cs b/src/mcpdotnet/Server/McpServerHandlers.cs new file mode 100644 index 0000000..507b9b4 --- /dev/null +++ b/src/mcpdotnet/Server/McpServerHandlers.cs @@ -0,0 +1,135 @@ +using McpDotNet.Protocol.Types; + +namespace McpDotNet.Server; + +/// +/// Container for handlers used in the creation of an MCP server. +/// +public sealed class McpServerHandlers +{ + /// + /// Gets or sets the handler for list tools requests. + /// + public Func, CancellationToken, Task>? ListToolsHandler { get; set; } + + /// + /// Gets or sets the handler for call tool requests. + /// + public Func, CancellationToken, Task>? CallToolHandler { get; set; } + + /// + /// Gets or sets the handler for list prompts requests. + /// + public Func, CancellationToken, Task>? ListPromptsHandler { get; set; } + + /// + /// Gets or sets the handler for get prompt requests. + /// + public Func, CancellationToken, Task>? GetPromptHandler { get; set; } + + /// + /// Gets or sets the handler for list resources requests. + /// + public Func, CancellationToken, Task>? ListResourcesHandler { get; set; } + + /// + /// Gets or sets the handler for read resources requests. + /// + public Func, CancellationToken, Task>? ReadResourceHandler { get; set; } + + /// + /// Gets or sets the handler for get completion requests. + /// + public Func, CancellationToken, Task>? GetCompletionHandler { get; set; } + + /// + /// Gets or sets the handler for subscribe to resources messages. + /// + public Func, CancellationToken, Task>? SubscribeToResourcesHandler { get; set; } + + /// + /// Gets or sets the handler for unsubscribe from resources messages. + /// + public Func, CancellationToken, Task>? UnsubscribeFromResourcesHandler { get; set; } + + /// + /// Overwrite any handlers in McpServerOptions with non-null handlers from this instance. + /// + /// + /// + internal McpServerOptions OverwriteWithSetHandlers(McpServerOptions options) + { + PromptsCapability? promptsCapability = options.Capabilities?.Prompts; + if (ListPromptsHandler is not null || GetPromptHandler is not null) + { + promptsCapability = promptsCapability is null ? + new() + { + ListPromptsHandler = ListPromptsHandler, + GetPromptHandler = GetPromptHandler, + } : + promptsCapability with + { + ListPromptsHandler = ListPromptsHandler ?? promptsCapability.ListPromptsHandler, + GetPromptHandler = GetPromptHandler ?? promptsCapability.GetPromptHandler, + }; + } + + ResourcesCapability? resourcesCapability = options.Capabilities?.Resources; + if (ListResourcesHandler is not null || + ReadResourceHandler is not null || + SubscribeToResourcesHandler is not null || + UnsubscribeFromResourcesHandler is not null) + { + resourcesCapability = resourcesCapability is null ? + new() + { + ListResourcesHandler = ListResourcesHandler, + ReadResourceHandler = ReadResourceHandler, + SubscribeToResourcesHandler = SubscribeToResourcesHandler, + UnsubscribeFromResourcesHandler = UnsubscribeFromResourcesHandler, + } : + resourcesCapability with + { + ListResourcesHandler = ListResourcesHandler ?? resourcesCapability.ListResourcesHandler, + ReadResourceHandler = ReadResourceHandler ?? resourcesCapability.ReadResourceHandler, + SubscribeToResourcesHandler = SubscribeToResourcesHandler ?? resourcesCapability.SubscribeToResourcesHandler, + UnsubscribeFromResourcesHandler = UnsubscribeFromResourcesHandler ?? resourcesCapability.UnsubscribeFromResourcesHandler, + }; + } + + ToolsCapability? toolsCapability = options.Capabilities?.Tools; + if (ListToolsHandler is not null || CallToolHandler is not null) + { + toolsCapability = toolsCapability is null ? + new() + { + ListToolsHandler = ListToolsHandler, + CallToolHandler = CallToolHandler, + } : + toolsCapability with + { + ListToolsHandler = ListToolsHandler ?? toolsCapability.ListToolsHandler, + CallToolHandler = CallToolHandler ?? toolsCapability.CallToolHandler, + }; + } + + return options with + { + GetCompletionHandler = GetCompletionHandler ?? options.GetCompletionHandler, + Capabilities = options.Capabilities is null ? + new() + { + Prompts = promptsCapability, + Resources = resourcesCapability, + Tools = toolsCapability, + } : + options.Capabilities with + { + Prompts = promptsCapability, + Resources = resourcesCapability, + Tools = toolsCapability, + }, + }; + } +} diff --git a/src/mcpdotnet/Server/McpServerOptions.cs b/src/mcpdotnet/Server/McpServerOptions.cs index 76f5e1d..40bfa4f 100644 --- a/src/mcpdotnet/Server/McpServerOptions.cs +++ b/src/mcpdotnet/Server/McpServerOptions.cs @@ -1,5 +1,6 @@  using McpDotNet.Protocol.Types; +using System.Text.Json.Serialization; namespace McpDotNet.Server; @@ -34,4 +35,10 @@ public record McpServerOptions /// Optional server instructions to send to clients /// public string ServerInstructions { get; init; } = string.Empty; + + /// + /// Gets or sets the handler for get completion requests. + /// + [JsonIgnore] + public Func, CancellationToken, Task>? GetCompletionHandler { get; init; } } diff --git a/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs b/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs index 85eb384..880668b 100644 --- a/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs +++ b/src/mcpdotnet/Shared/McpJsonRpcEndpoint.cs @@ -4,8 +4,10 @@ using McpDotNet.Logging; using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; +using McpDotNet.Utils; using McpDotNet.Utils.Json; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; namespace McpDotNet.Shared; @@ -24,7 +26,7 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable private readonly Dictionary>> _requestHandlers = []; private int _nextRequestId; private readonly JsonSerializerOptions _jsonOptions; - private readonly ILogger _logger; + private readonly ILogger _logger; private bool _isDisposed; /// @@ -32,19 +34,18 @@ internal abstract class McpJsonRpcEndpoint : IAsyncDisposable /// /// An MCP transport implementation. /// The logger factory. - protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory loggerFactory) + protected McpJsonRpcEndpoint(ITransport transport, ILoggerFactory? loggerFactory = null) { - if (transport is null) - { - throw new ArgumentNullException(nameof(transport)); - } + Throw.IfNull(transport); + + loggerFactory ??= NullLoggerFactory.Instance; _transport = transport; _pendingRequests = new(); _notificationHandlers = new(); _nextRequestId = 1; _jsonOptions = JsonSerializerOptionsExtensions.DefaultOptions; - _logger = loggerFactory.CreateLogger(); + _logger = (ILogger?)loggerFactory?.CreateLogger() ?? NullLogger.Instance; } /// @@ -282,10 +283,7 @@ public async Task SendRequestAsync(JsonRpcRequest request, Can public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - if (message is null) - { - throw new ArgumentNullException(nameof(message)); - } + Throw.IfNull(message); if (!_transport.IsConnected) { @@ -333,15 +331,8 @@ public async ValueTask DisposeAsync() /// Handler to be called when a request with specified method identifier is received protected void SetRequestHandler(string method, Func> handler) { - if (method is null) - { - throw new ArgumentNullException(nameof(method)); - } - - if (handler is null) - { - throw new ArgumentNullException(nameof(handler)); - } + Throw.IfNull(method); + Throw.IfNull(handler); _requestHandlers[method] = async (request) => { diff --git a/src/mcpdotnet/Utils/Throw.cs b/src/mcpdotnet/Utils/Throw.cs new file mode 100644 index 0000000..5671759 --- /dev/null +++ b/src/mcpdotnet/Utils/Throw.cs @@ -0,0 +1,41 @@ +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace McpDotNet.Utils; + +/// Provides helper methods for throwing exceptions. +internal static class Throw +{ + // NOTE: Most of these should be replaced with extension statics for the relevant extension + // type as downlevel polyfills once the C# 14 extension everything feature is available. + + public static void IfNull([NotNull] object? arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg is null) + { + ThrowArgumentNullException(parameterName); + } + } + + public static void IfNullOrWhiteSpace([NotNull] string? arg, [CallerArgumentExpression(nameof(arg))] string? parameterName = null) + { + if (arg is null || arg.AsSpan().IsWhiteSpace()) + { + ThrowArgumentNullOrWhiteSpaceException(arg); + } + } + + [DoesNotReturn] + private static void ThrowArgumentNullOrWhiteSpaceException(string? parameterName) + { + if (parameterName is null) + { + ThrowArgumentNullException(parameterName); + } + + throw new ArgumentException("Value cannot be empty or composed entirely of whitespace.", parameterName); + } + + [DoesNotReturn] + private static void ThrowArgumentNullException(string? parameterName) => throw new ArgumentNullException(parameterName); +} diff --git a/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs b/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs index b88d6a8..8931b03 100644 --- a/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs +++ b/tests/McpDotNet.Extensions.AI.Tests/IntegrationTests.cs @@ -1,7 +1,6 @@ using McpDotNet.Client; using McpDotNet.Configuration; using Microsoft.Extensions.AI; -using Microsoft.Extensions.Logging.Abstractions; using OpenAI; namespace McpDotNet.Extensions.AI.Tests; @@ -35,19 +34,13 @@ private static async Task GetMcpClientAsync() ClientInfo = new() { Name = "McpDotNet.Extensions.AI.Tests", Version = "1.0.0" } }; - var factory = new McpClientFactory( - [GetEverythingServerConfig()], - options, - NullLoggerFactory.Instance - ); - - return await factory.GetClientAsync("everything"); + return await McpClientFactory.CreateAsync(GetEverythingServerConfig(), options); } [Fact] public async Task IntegrateWithMeai_UsingEverythingServer_ToolsAreProperlyCalled() { - var client = await GetMcpClientAsync(); + await using var client = await GetMcpClientAsync(); var mappedTools = await client.ListToolsAsync().Select(t => t.ToAITool(client)).ToListAsync(); IChatClient openaiClient = new OpenAIClient(_openAIKey) diff --git a/tests/mcpdotnet.TestServer/Program.cs b/tests/mcpdotnet.TestServer/Program.cs index 7d77bb2..bc6bd1b 100644 --- a/tests/mcpdotnet.TestServer/Program.cs +++ b/tests/mcpdotnet.TestServer/Program.cs @@ -34,24 +34,17 @@ private static async Task Main(string[] args) ServerInfo = new Implementation() { Name = "TestServer", Version = "1.0.0" }, Capabilities = new ServerCapabilities() { - Tools = new(), - Resources = new(), - Prompts = new(), + Tools = ConfigureTools(), + Resources = ConfigureResources(), + Prompts = ConfigurePrompts(), }, ProtocolVersion = "2024-11-05", ServerInstructions = "This is a test server with only stub functionality", + GetCompletionHandler = ConfigureCompletion(), }; var loggerFactory = CreateLoggerFactory(); - McpServerFactory factory = new(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); - IMcpServer server = factory.CreateServer(); - - Log.Logger.Information("Server object created, registering handlers."); - - ConfigureTools(server); - ConfigureResources(server); - ConfigurePrompts(server); - ConfigureCompletion(server); + await using IMcpServer server = McpServerFactory.Create(new StdioServerTransport("TestServer", loggerFactory), options, loggerFactory); Log.Logger.Information("Server initialized."); @@ -60,184 +53,187 @@ private static async Task Main(string[] args) Log.Logger.Information("Server started."); // Run until process is stopped by the client (parent process) - while (true) // NOSONAR - { - await Task.Delay(1000); - } + await Task.Delay(Timeout.Infinite); } - private static void ConfigureTools(IMcpServer server) + private static ToolsCapability ConfigureTools() { - server.SetListToolsHandler((request, cancellationToken) => + return new() { - return Task.FromResult(new ListToolsResult() + ListToolsHandler = (request, cancellationToken) => { - Tools = - [ - new Tool() - { - Name = "echo", - Description = "Echoes the input back to the client.", - InputSchema = new JsonSchema() + return Task.FromResult(new ListToolsResult() + { + Tools = + [ + new Tool() { - Type = "object", - Properties = new Dictionary() + Name = "echo", + Description = "Echoes the input back to the client.", + InputSchema = new JsonSchema() { - ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } - } + Type = "object", + Properties = new Dictionary() + { + ["message"] = new JsonSchemaProperty() { Type = "string", Description = "The input to echo back." } + } + }, }, - }, - new Tool() - { - Name = "sampleLLM", - Description = "Samples from an LLM using MCP's sampling feature.", - InputSchema = new JsonSchema() + new Tool() { - Type = "object", - Properties = new Dictionary() + Name = "sampleLLM", + Description = "Samples from an LLM using MCP's sampling feature.", + InputSchema = new JsonSchema() { - ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, - ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } - } - }, - } - ] - }); - }); + Type = "object", + Properties = new Dictionary() + { + ["prompt"] = new JsonSchemaProperty() { Type = "string", Description = "The prompt to send to the LLM" }, + ["maxTokens"] = new JsonSchemaProperty() { Type = "number", Description = "Maximum number of tokens to generate" } + } + }, + } + ] + }); + }, - server.SetCallToolHandler(async (request, cancellationToken) => - { - if (request.Params?.Name == "echo") + CallToolHandler = async (request, cancellationToken) => { - if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + if (request.Params?.Name == "echo") { - throw new McpServerException("Missing required argument 'message'"); + if (request.Params?.Arguments is null || !request.Params.Arguments.TryGetValue("message", out var message)) + { + throw new McpServerException("Missing required argument 'message'"); + } + return new CallToolResponse() + { + Content = [new Content() { Text = "Echo: " + message?.ToString(), Type = "text" }] + }; } - return new CallToolResponse() - { - Content = [new Content() { Text = "Echo: " + message?.ToString(), Type = "text" }] - }; - } - else if (request.Params?.Name == "sampleLLM") - { - if (request.Params?.Arguments is null || - !request.Params.Arguments.TryGetValue("prompt", out var prompt) || - !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + else if (request.Params?.Name == "sampleLLM") { - throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); - } - var sampleResult = await server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), - cancellationToken); + if (request.Params?.Arguments is null || + !request.Params.Arguments.TryGetValue("prompt", out var prompt) || + !request.Params.Arguments.TryGetValue("maxTokens", out var maxTokens)) + { + throw new McpServerException("Missing required arguments 'prompt' and 'maxTokens'"); + } + var sampleResult = await request.Server.RequestSamplingAsync(CreateRequestSamplingParams(prompt?.ToString() ?? "", "sampleLLM", Convert.ToInt32(maxTokens?.ToString())), + cancellationToken); - return new CallToolResponse() + return new CallToolResponse() + { + Content = [new Content() { Text = $"LLM sampling result: {sampleResult.Content.Text}", Type = "text" }] + }; + } + else { - Content = [new Content() { Text = $"LLM sampling result: {sampleResult.Content.Text}", Type = "text" }] - }; - } - else - { - throw new McpServerException($"Unknown tool: {request.Params?.Name}"); + throw new McpServerException($"Unknown tool: {request.Params?.Name}"); + } } - }); + }; } - private static void ConfigurePrompts(IMcpServer server) + private static PromptsCapability ConfigurePrompts() { - server.SetListPromptsHandler((request, cancellationToken) => - { - return Task.FromResult(new ListPromptsResult() - { - Prompts = [ - new Prompt() - { - Name = "simple_prompt", - Description = "A prompt without arguments" - }, - new Prompt() - { - Name = "complex_prompt", - Description = "A prompt with arguments", - Arguments = - [ - new PromptArgument() - { - Name = "temperature", - Description = "Temperature setting", - Required = true - }, - new PromptArgument() - { - Name = "style", - Description = "Output style", - Required = false - } - ] - } - ] - }); - }); - - server.SetGetPromptHandler((request, cancellationToken) => + return new() { - List messages = new(); - if (request.Params?.Name == "simple_prompt") + ListPromptsHandler = (request, cancellationToken) => { - messages.Add(new PromptMessage() + return Task.FromResult(new ListPromptsResult() { - Role = Role.User, - Content = new Content() - { - Type = "text", - Text = "This is a simple prompt without arguments." - } + Prompts = [ + new Prompt() + { + Name = "simple_prompt", + Description = "A prompt without arguments" + }, + new Prompt() + { + Name = "complex_prompt", + Description = "A prompt with arguments", + Arguments = + [ + new PromptArgument() + { + Name = "temperature", + Description = "Temperature setting", + Required = true + }, + new PromptArgument() + { + Name = "style", + Description = "Output style", + Required = false + } + ] + } + ] }); - } - else if (request.Params?.Name == "complex_prompt") + }, + + GetPromptHandler = (request, cancellationToken) => { - string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; - string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; - messages.Add(new PromptMessage() + List messages = []; + if (request.Params?.Name == "simple_prompt") { - Role = Role.User, - Content = new Content() + messages.Add(new PromptMessage() { - Type = "text", - Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" - } - }); - messages.Add(new PromptMessage() + Role = Role.User, + Content = new Content() + { + Type = "text", + Text = "This is a simple prompt without arguments." + } + }); + } + else if (request.Params?.Name == "complex_prompt") { - Role = Role.Assistant, - Content = new Content() + string temperature = request.Params.Arguments?["temperature"]?.ToString() ?? "unknown"; + string style = request.Params.Arguments?["style"]?.ToString() ?? "unknown"; + messages.Add(new PromptMessage() { - Type = "text", - Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" - } - }); - messages.Add(new PromptMessage() - { - Role = Role.User, - Content = new Content() + Role = Role.User, + Content = new Content() + { + Type = "text", + Text = $"This is a complex prompt with arguments: temperature={temperature}, style={style}" + } + }); + messages.Add(new PromptMessage() { - Type = "image", - Data = MCP_TINY_IMAGE, - MimeType = "image/png" - } + Role = Role.Assistant, + Content = new Content() + { + Type = "text", + Text = "I understand. You've provided a complex prompt with temperature and style arguments. How would you like me to proceed?" + } + }); + messages.Add(new PromptMessage() + { + Role = Role.User, + Content = new Content() + { + Type = "image", + Data = MCP_TINY_IMAGE, + MimeType = "image/png" + } + }); + } + else + { + throw new McpServerException($"Unknown prompt: {request.Params?.Name}"); + } + + return Task.FromResult(new GetPromptResult() + { + Messages = messages }); } - else - { - throw new McpServerException($"Unknown prompt: {request.Params?.Name}"); - } - - return Task.FromResult(new GetPromptResult() - { - Messages = messages - }); - }); + }; } - private static void ConfigureResources(IMcpServer server) + private static ResourcesCapability ConfigureResources() { List resources = []; List resourceContents = []; @@ -279,54 +275,56 @@ private static void ConfigureResources(IMcpServer server) const int pageSize = 10; - server.SetListResourcesHandler((request, cancellationToken) => + return new() { - int startIndex = 0; - request ??= new(server, new()); - if (request.Params?.Cursor is not null) + ListResourcesHandler = (request, cancellationToken) => { - try + int startIndex = 0; + if (request.Params?.Cursor is not null) { - var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); - startIndex = Convert.ToInt32(startIndexAsString); + try + { + var startIndexAsString = Encoding.UTF8.GetString(Convert.FromBase64String(request.Params.Cursor)); + startIndex = Convert.ToInt32(startIndexAsString); + } + catch + { + throw new McpServerException("Invalid cursor"); + } } - catch + + int endIndex = Math.Min(startIndex + pageSize, resources.Count); + string? nextCursor = null; + + if (endIndex < resources.Count) { - throw new McpServerException("Invalid cursor"); + nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); } - } - - int endIndex = Math.Min(startIndex + pageSize, resources.Count); - string? nextCursor = null; + return Task.FromResult(new ListResourcesResult() + { + NextCursor = nextCursor, + Resources = resources.GetRange(startIndex, endIndex - startIndex) + }); + }, - if (endIndex < resources.Count) - { - nextCursor = Convert.ToBase64String(Encoding.UTF8.GetBytes(endIndex.ToString())); - } - return Task.FromResult(new ListResourcesResult() + ReadResourceHandler = (request, cancellationToken) => { - NextCursor = nextCursor, - Resources = resources.GetRange(startIndex, endIndex - startIndex) - }); - }); + if (request.Params?.Uri is null) + { + throw new McpServerException("Missing required argument 'uri'"); + } + ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) + ?? throw new McpServerException("Resource not found"); - server.SetReadResourceHandler((request, cancellationToken) => - { - if (request.Params?.Uri is null) - { - throw new McpServerException("Missing required argument 'uri'"); + return Task.FromResult(new ReadResourceResult() + { + Contents = [contents] + }); } - ResourceContents contents = resourceContents.FirstOrDefault(r => r.Uri == request.Params.Uri) - ?? throw new McpServerException("Resource not found"); - - return Task.FromResult(new ReadResourceResult() - { - Contents = [contents] - }); - }); + }; } - private static void ConfigureCompletion(IMcpServer server) + private static Func, CancellationToken, Task> ConfigureCompletion() { List sampleResourceIds = ["1", "2", "3", "4", "5"]; Dictionary> exampleCompletions = new() @@ -335,7 +333,7 @@ private static void ConfigureCompletion(IMcpServer server) {"temperature", ["0", "0.5", "0.7", "1.0"]}, }; - server.SetGetCompletionHandler((request, cancellationToken) => + return (request, cancellationToken) => { if (request.Params?.Ref?.Type == "ref/resource") { @@ -344,7 +342,7 @@ private static void ConfigureCompletion(IMcpServer server) return Task.FromResult(new CompleteResult() { Completion = new() { Values = [] } }); // Filter resource IDs that start with the input value - var values = sampleResourceIds.Where(id => id.StartsWith(request.Params.Argument.Value)).ToArray(); + var values = sampleResourceIds.Where(id => id.StartsWith(request.Params!.Argument.Value)).ToArray(); return Task.FromResult(new CompleteResult() { Completion = new() { Values = values, HasMore = false, Total = values.Length } }); } @@ -360,7 +358,7 @@ private static void ConfigureCompletion(IMcpServer server) } throw new McpServerException($"Unknown reference type: {request.Params?.Ref.Type}"); - }); + }; } static CreateMessageRequestParams CreateRequestSamplingParams(string context, string uri, int maxTokens = 100) diff --git a/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs b/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs index 3c15020..e76f90a 100644 --- a/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs +++ b/tests/mcpdotnet.Tests/Client/McpClientFactoryTests.cs @@ -3,8 +3,7 @@ using McpDotNet.Configuration; using McpDotNet.Protocol.Messages; using McpDotNet.Protocol.Transport; -using Microsoft.Extensions.Logging.Abstractions; -using Moq; +using McpDotNet.Protocol.Types; namespace McpDotNet.Tests.Client; @@ -16,10 +15,43 @@ public class McpClientFactoryTests }; [Fact] - public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() + public async Task CreateAsync_WithInvalidArgs_Throws() + { + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync((McpServerConfig)null!, _defaultOptions)); + + await Assert.ThrowsAsync("clientOptions", () => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = TransportTypes.StdIo, + }, (McpClientOptions)null!)); + + await Assert.ThrowsAsync("serverConfig", () => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = "somethingunsupported", + }, + _defaultOptions)); + + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync( + new McpServerConfig() + { + Name = "name", + Id = "id", + TransportType = TransportTypes.StdIo, + }, + _defaultOptions, + (_, __) => null!)); + } + + [Fact] + public async Task CreateAsync_WithValidStdioConfig_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -32,30 +64,11 @@ public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() } }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - // Inject the mock transport into the factory - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object - ); - // Act - var client = await factory.GetClientAsync("test-server"); + var client = await McpClientFactory.CreateAsync( + serverConfig, + _defaultOptions, + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); @@ -63,137 +76,33 @@ public async Task GetClientAsync_WithValidStdioConfig_CreatesNewClient() } [Fact] - public async Task GetClientAsync_CalledTwice_ReturnsSameInstance() + public async Task CreateAsync_WithNoTransportOptions_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", TransportType = TransportTypes.StdIo, - Location = "/path/to/server" + Location = "/path/to/server", }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - using var factory = new McpClientFactory([config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object); - // Act - var client1 = await factory.GetClientAsync("test-server"); - var client2 = await factory.GetClientAsync("test-server"); - - // Assert - Assert.Same(client1, client2); - } - - [Fact] - public async Task GetClientAsync_WithInvalidServerId_ThrowsArgumentException() - { - // Arrange - using var factory = new McpClientFactory([], + var client = await McpClientFactory.CreateAsync( + serverConfig, _defaultOptions, - NullLoggerFactory.Instance); - - // Act & Assert - await Assert.ThrowsAsync( - () => factory.GetClientAsync("non-existent-server") - ); - } - - [Fact] - public async Task GetClientAsync_WithUnsupportedTransport_ThrowsArgumentException() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = "unsupported", - Location = "/path/to/server" - }; - - using var factory = new McpClientFactory([config], _defaultOptions, - NullLoggerFactory.Instance); - - // Act & Assert - await Assert.ThrowsAsync( - () => factory.GetClientAsync("test-server") - ); - } - - [Fact] - public async Task GetClientAsync_WithNoTransportOptions_CreatesClientWithDefaults() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.StdIo, - Location = "/path/to/server" - }; - - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - using var factory = new McpClientFactory([config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object); - - // Act - var client = await factory.GetClientAsync("test-server"); + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); + // We could add more assertions here about the client's configuration } [Fact] - public void Constructor_WithDuplicateServerIds_ThrowsArgumentException() - { - // Arrange - McpServerConfig[] configs = - [ - new McpServerConfig { Id = "duplicate", Name = "duplicate", TransportType = TransportTypes.StdIo, Location = "/path1" }, - new McpServerConfig { Id = "duplicate", Name = "duplicate", TransportType = TransportTypes.StdIo, Location = "/path2" } - ]; - - // Act & Assert - Assert.Throws(() => new McpClientFactory(configs, _defaultOptions, - NullLoggerFactory.Instance)); - } - - [Fact] - public async Task GetClientAsync_WithSseTransport_CanCreateClient() + public async Task CreateAsync_WithValidSseConfig_CreatesNewClient() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -201,40 +110,22 @@ public async Task GetClientAsync_WithSseTransport_CanCreateClient() Location = "http://localhost:8080" }; - // Create a mock transport - var mockTransport = new Mock(); - mockTransport.Setup(t => t.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockTransport.Setup(t => t.IsConnected).Returns(true); - mockTransport.Setup(t => t.MessageReader).Returns(Mock.Of>()); - - // Create a mock client - var mockClient = new Mock(); - mockClient.Setup(c => c.ConnectAsync(It.IsAny())) - .Returns(Task.CompletedTask); - mockClient.Setup(c => c.IsInitialized).Returns(true); - - // Inject the mock transport into the factory - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance, - transportFactoryMethod: _ => mockTransport.Object, - clientFactoryMethod: (_, _, _) => mockClient.Object - ); - // Act - var client = await factory.GetClientAsync("test-server"); + var client = await McpClientFactory.CreateAsync( + serverConfig, + _defaultOptions, + (_, __) => new NopTransport()); // Assert Assert.NotNull(client); + // We could add more assertions here about the client's configuration } [Fact] - public void McpFactory_WithSse_CreatesCorrectTransportOptions() + public async Task CreateAsync_WithSse_CreatesCorrectTransportOptions() { // Arrange - var config = new McpServerConfig + var serverConfig = new McpServerConfig { Id = "test-server", Name = "Test Server", @@ -250,58 +141,23 @@ public void McpFactory_WithSse_CreatesCorrectTransportOptions() }; // Act - using var factory = new McpClientFactory( - [config], + var client = await McpClientFactory.CreateAsync( + serverConfig, _defaultOptions, - NullLoggerFactory.Instance - ); - - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; + (_, __) => new NopTransport()); // Assert - Assert.NotNull(transport); - Assert.Equal(TimeSpan.FromSeconds(10), transport.Options.ConnectionTimeout); - Assert.Equal(2, transport.Options.MaxReconnectAttempts); - Assert.Equal(TimeSpan.FromSeconds(5), transport.Options.ReconnectDelay); - Assert.NotNull(transport.Options.AdditionalHeaders); - Assert.Equal("the_header_value", transport.Options.AdditionalHeaders["test"]); - } - - [Fact] - public void McpFactory_WithSseAndNoOptions_CreatesDefaultTransportOptions() - { - // Arrange - var config = new McpServerConfig - { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080" - }; - - var defaultOptions = new SseClientTransportOptions(); - - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); - - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; - - // Assert - Assert.NotNull(transport); - Assert.True(transport.Options.ConnectionTimeout == defaultOptions.ConnectionTimeout); - Assert.True(transport.Options.MaxReconnectAttempts == defaultOptions.MaxReconnectAttempts); - Assert.True(transport.Options.ReconnectDelay == defaultOptions.ReconnectDelay); - Assert.True(transport.Options.AdditionalHeaders == null && defaultOptions.AdditionalHeaders == null); + Assert.NotNull(client); + // We could add more assertions here about the client's configuration } - [Fact] - public void McpFactory_WithSseAndMissingOptions_CreatesCorrectTransportOptions() + [Theory] + [InlineData("connectionTimeout", "not_a_number")] + [InlineData("maxReconnectAttempts", "invalid")] + [InlineData("reconnectDelay", "bad_value")] + public async Task McpFactory_WithInvalidTransportOptions_ThrowsFormatException(string key, string value) { - // Arrange + // arrange var config = new McpServerConfig { Id = "test-server", @@ -310,59 +166,50 @@ public void McpFactory_WithSseAndMissingOptions_CreatesCorrectTransportOptions() Location = "http://localhost:8080", TransportOptions = new Dictionary { - ["connectionTimeout"] = "10", - ["header.test"] = "the_header_value" + [key] = value } }; - var defaultOptions = new SseClientTransportOptions(); + // act & assert + await Assert.ThrowsAsync(() => McpClientFactory.CreateAsync(config, _defaultOptions)); + } - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); + private sealed class NopTransport : IClientTransport + { + private readonly Channel _channel = Channel.CreateUnbounded(); - var transport = factory.TransportFactoryMethod(config) as SseClientTransport; + public bool IsConnected => true; - // Assert - Assert.NotNull(transport); - Assert.Equal(TimeSpan.FromSeconds(10), transport.Options.ConnectionTimeout); - Assert.Equal(defaultOptions.MaxReconnectAttempts, transport.Options.MaxReconnectAttempts); - Assert.Equal(defaultOptions.ReconnectDelay, transport.Options.ReconnectDelay); - Assert.NotNull(transport.Options.AdditionalHeaders); - Assert.Equal("the_header_value", transport.Options.AdditionalHeaders["test"]); - } + public ChannelReader MessageReader => _channel.Reader; - [Theory] - [InlineData("connectionTimeout", "not_a_number")] - [InlineData("maxReconnectAttempts", "invalid")] - [InlineData("reconnectDelay", "bad_value")] - public void McpFactory_WithInvalidTransportOptions_ThrowsFormatException(string key, string value) - { - // arrange - var config = new McpServerConfig + public Task ConnectAsync(CancellationToken cancellationToken = default) => + Task.CompletedTask; + + public ValueTask DisposeAsync() => default; + + public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) { - Id = "test-server", - Name = "Test Server", - TransportType = TransportTypes.Sse, - Location = "http://localhost:8080", - TransportOptions = new Dictionary + switch (message) { - [key] = value + case JsonRpcRequest request: + _channel.Writer.TryWrite(new JsonRpcResponse + { + Id = ((JsonRpcRequest)message).Id, + Result = new InitializeResult() + { + Capabilities = new ServerCapabilities(), + ProtocolVersion = "2024-11-05", + ServerInfo = new Implementation() + { + Name = "NopTransport", + Version = "1.0.0" + }, + } + }); + break; } - }; - // Act - using var factory = new McpClientFactory( - [config], - _defaultOptions, - NullLoggerFactory.Instance - ); - - // act & assert - Assert.Throws(() => - factory.TransportFactoryMethod(config)); + return Task.CompletedTask; + } } } diff --git a/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs b/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs index 1ec22ff..2c271c7 100644 --- a/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs +++ b/tests/mcpdotnet.Tests/ClientIntegrationTestFixture.cs @@ -8,8 +8,11 @@ namespace McpDotNet.Tests; public class ClientIntegrationTestFixture : IDisposable { public ILoggerFactory LoggerFactory { get; } - public McpClientFactory Factory { get; } public McpClientOptions DefaultOptions { get; } + public McpServerConfig EverythingServerConfig { get; } + public McpServerConfig TestServerConfig { get; } + + public static IEnumerable ClientIds => ["everything", "test_server"]; public ClientIntegrationTestFixture() { @@ -20,10 +23,9 @@ public ClientIntegrationTestFixture() DefaultOptions = new() { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" }, - Capabilities = new() { Sampling = new(), Roots = new() } }; - var everythingServerConfig = new McpServerConfig + EverythingServerConfig = new() { Id = "everything", Name = "Everything", @@ -36,7 +38,7 @@ public ClientIntegrationTestFixture() } }; - var testServerConfig = new McpServerConfig + TestServerConfig = new() { Id = "test_server", Name = "TestServer", @@ -49,19 +51,21 @@ public ClientIntegrationTestFixture() }; if (!OperatingSystem.IsWindows()) - testServerConfig.TransportOptions["arguments"] = "TestServer.dll"; - - // Inject the mock transport into the factory - Factory = new McpClientFactory( - [everythingServerConfig, testServerConfig], - DefaultOptions, - LoggerFactory - ); + { + TestServerConfig.TransportOptions["arguments"] = "TestServer.dll"; + } } + public Task CreateClientAsync(string clientId, McpClientOptions? clientOptions = null) => + McpClientFactory.CreateAsync(clientId switch + { + "everything" => EverythingServerConfig, + "test_server" => TestServerConfig, + _ => throw new ArgumentException($"Unknown client ID: {clientId}") + }, clientOptions ?? DefaultOptions, loggerFactory: LoggerFactory); + public void Dispose() { - Factory?.Dispose(); LoggerFactory?.Dispose(); GC.SuppressFinalize(this); } diff --git a/tests/mcpdotnet.Tests/ClientIntegrationTests.cs b/tests/mcpdotnet.Tests/ClientIntegrationTests.cs index cb0a728..7e293e5 100644 --- a/tests/mcpdotnet.Tests/ClientIntegrationTests.cs +++ b/tests/mcpdotnet.Tests/ClientIntegrationTests.cs @@ -15,11 +15,8 @@ public ClientIntegrationTests(ClientIntegrationTestFixture fixture) _fixture = fixture; } - public static IEnumerable GetClients() - { - yield return ["everything"]; - yield return ["test_server"]; - } + public static IEnumerable GetClients() => + ClientIntegrationTestFixture.ClientIds.Select(id => new object[] { id }); [Theory] [MemberData(nameof(GetClients))] @@ -28,7 +25,7 @@ public async Task ConnectAndPing_Stdio(string clientId) // Arrange // Act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); await client.PingAsync(CancellationToken.None); // Assert @@ -42,7 +39,7 @@ public async Task Connect_ShouldProvideServerFields(string clientId) // Arrange // Act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Assert Assert.NotNull(client.ServerCapabilities); @@ -58,7 +55,7 @@ public async Task ListTools_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var tools = await client.ListToolsAsync().ToListAsync(); // assert @@ -73,7 +70,7 @@ public async Task CallTool_Stdio_EchoServer(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.CallToolAsync( "echo", new Dictionary @@ -97,7 +94,7 @@ public async Task ListPrompts_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var prompts = await client.ListPromptsAsync().ToListAsync(); // assert @@ -114,7 +111,7 @@ public async Task GetPrompt_Stdio_SimplePrompt(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetPromptAsync("simple_prompt", null, CancellationToken.None); // assert @@ -129,7 +126,7 @@ public async Task GetPrompt_Stdio_ComplexPrompt(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var arguments = new Dictionary { { "temperature", "0.7" }, @@ -149,7 +146,7 @@ public async Task GetPrompt_NonExistent_ThrowsException(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); await Assert.ThrowsAsync(() => client.GetPromptAsync("non_existent_prompt", null, CancellationToken.None)); } @@ -161,7 +158,7 @@ public async Task ListResources_Stdio(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); List allResources = []; string? cursor = null; @@ -184,7 +181,7 @@ public async Task ReadResource_Stdio_TextResource(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Odd numbered resources are text in the everything server (despite the docs saying otherwise) // 1 is index 0, which is "even" in the 0-based index var result = await client.ReadResourceAsync("test://static/resource/1", CancellationToken.None); @@ -201,7 +198,7 @@ public async Task ReadResource_Stdio_BinaryResource(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); // Even numbered resources are binary in the everything server (despite the docs saying otherwise) // 2 is index 1, which is "odd" in the 0-based index var result = await client.ReadResourceAsync("test://static/resource/2", CancellationToken.None); @@ -218,7 +215,7 @@ public async Task GetCompletion_Stdio_ResourceReference(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetCompletionAsync(new Reference { Type = "ref/resource", @@ -240,7 +237,7 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) // arrange // act - var client = await _fixture.Factory.GetClientAsync(clientId); + await using var client = await _fixture.CreateClientAsync(clientId); var result = await client.GetCompletionAsync(new Reference { Type = "ref/prompt", @@ -258,24 +255,32 @@ public async Task GetCompletion_Stdio_PromptReference(string clientId) [Theory] [MemberData(nameof(GetClients))] public async Task Sampling_Stdio(string clientId) - { - var client = await _fixture.Factory.GetClientAsync(clientId); - + { // Set up the sampling handler int samplingHandlerCalls = 0; - client.SetSamplingHandler((_, _) => + await using var client = await _fixture.CreateClientAsync(clientId, new() { - samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + ClientInfo = new() { Name = "Sampling_Stdio", Version = "1.0.0" }, + Capabilities = new() { - Model = "test-model", - Role = "assistant", - Content = new Content + Sampling = new() { - Type = "text", - Text = "Test response" - } - }); + SamplingHandler = (_, _) => + { + samplingHandlerCalls++; + return Task.FromResult(new CreateMessageResult + { + Model = "test-model", + Role = "assistant", + Content = new Content + { + Type = "text", + Text = "Test response" + } + }); + }, + }, + }, }); // Call the server's sampleLLM tool which should trigger our sampling handler @@ -306,7 +311,7 @@ public async Task Sampling_Stdio(string clientId) // new() { Uri = "file:///test/root2", Name = "Test Root 2" } // }; - // var client = await _fixture.Factory.GetClientAsync(clientId); + // await using var client = await _fixture.Factory.GetClientAsync(clientId); // // Set up the roots handler // client.SetRootsHandler((request, ct) => @@ -330,9 +335,7 @@ public async Task Sampling_Stdio(string clientId) [MemberData(nameof(GetClients))] public async Task Notifications_Stdio(string clientId) { - var client = await _fixture.Factory.GetClientAsync(clientId); - - await client.ConnectAsync(); + await using var client = await _fixture.CreateClientAsync(clientId); // Verify we can send notifications without errors await client.SendNotificationAsync(NotificationMethods.RootsUpdatedNotification); @@ -347,7 +350,7 @@ public async Task Notifications_Stdio(string clientId) public async Task CallTool_Stdio_MemoryServer() { // arrange - var config = new McpServerConfig + McpServerConfig serverConfig = new() { Id = "memory", Name = "memory", @@ -359,21 +362,17 @@ public async Task CallTool_Stdio_MemoryServer() } }; - var options = new McpClientOptions + McpClientOptions clientOptions = new() { ClientInfo = new() { Name = "IntegrationTestClient", Version = "1.0.0" } }; - using var factory = new McpClientFactory([config], options, _fixture.LoggerFactory); - var client = await factory.GetClientAsync("memory"); - - await client.ConnectAsync(); + await using var client = await McpClientFactory.CreateAsync(serverConfig, clientOptions, loggerFactory: _fixture.LoggerFactory); // act var result = await client.CallToolAsync( "read_graph", - [], - CancellationToken.None + [] ); // assert diff --git a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index bce2560..eeafa67 100644 --- a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -27,7 +27,7 @@ public void WithListToolsHandler_Sets_Handler() _builder.Object.WithListToolsHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListToolsHandler); } @@ -40,7 +40,7 @@ public void WithCallToolHandler_Sets_Handler() _builder.Object.WithCallToolHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.CallToolHandler); } @@ -53,7 +53,7 @@ public void WithListPromptsHandler_Sets_Handler() _builder.Object.WithListPromptsHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListPromptsHandler); } @@ -66,7 +66,7 @@ public void WithGetPromptHandler_Sets_Handler() _builder.Object.WithGetPromptHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.GetPromptHandler); } @@ -79,7 +79,7 @@ public void WithListResourcesHandler_Sets_Handler() _builder.Object.WithListResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ListResourcesHandler); } @@ -92,7 +92,7 @@ public void WithReadResourceHandler_Sets_Handler() _builder.Object.WithReadResourceHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.ReadResourceHandler); } @@ -105,7 +105,7 @@ public void WithGetCompletionHandler_Sets_Handler() _builder.Object.WithGetCompletionHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.GetCompletionHandler); } @@ -118,7 +118,7 @@ public void WithSubscribeToResourcesHandler_Sets_Handler() _builder.Object.WithSubscribeToResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.SubscribeToResourcesHandler); } @@ -131,7 +131,7 @@ public void WithUnsubscribeFromResourcesHandler_Sets_Handler() _builder.Object.WithUnsubscribeFromResourcesHandler(handler); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.Equal(handler, options.UnsubscribeFromResourcesHandler); } diff --git a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 6d8b4f9..eaacf09 100644 --- a/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/mcpdotnet.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -26,7 +26,7 @@ public void Adds_Tools_To_Server() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; Assert.NotNull(options.ListToolsHandler); Assert.NotNull(options.CallToolHandler); @@ -38,7 +38,7 @@ public async Task Can_List_Registered_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); @@ -69,7 +69,7 @@ public async Task Can_Call_Registered_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); Assert.NotNull(result); @@ -86,7 +86,7 @@ public async Task Can_Call_Registered_Tool_With_Array_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoArray", Arguments = new() { { "message", "Peter" } } }), CancellationToken.None); Assert.NotNull(result); @@ -103,7 +103,7 @@ public async Task Can_Call_Registered_Tool_With_Null_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnNull" }), CancellationToken.None); Assert.NotNull(result); @@ -120,7 +120,7 @@ public async Task Can_Call_Registered_Tool_With_Json_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnJson" }), CancellationToken.None); Assert.NotNull(result); @@ -137,7 +137,7 @@ public async Task Can_Call_Registered_Tool_With_Int_Result() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnInteger" }), CancellationToken.None); Assert.NotNull(result); @@ -154,7 +154,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_Cancellation_Token() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; using var cts = new CancellationTokenSource(); var token = cts.Token; @@ -173,7 +173,7 @@ public async Task Can_Call_Registered_Tool_And_Returns_Cancelled_Response() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; using var cts = new CancellationTokenSource(); var token = cts.Token; @@ -190,7 +190,7 @@ public async Task Can_Call_Registered_Tool_And_Pass_ComplexType() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.CallToolHandler!(new(Mock.Of(), new() { Name = "EchoComplex", Arguments = new() { { "complex", JsonDocument.Parse("{\"Name\": \"Peter\", \"Age\": 25}").RootElement } } }), CancellationToken.None); Assert.NotNull(result); @@ -207,7 +207,7 @@ public async Task Throws_Exception_When_Tool_Fails() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var action = async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "ReturnError" }), CancellationToken.None); @@ -228,7 +228,7 @@ public async Task Can_Call_Registered_Tool_With_Dependency_Injection() _builder.Object.WithTool(); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var mcpServer = new Mock(); mcpServer.SetupGet(s => s.ServiceProvider).Returns(serviceProvider); @@ -249,7 +249,7 @@ public async Task Throws_Exception_On_Unknown_Tool() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "NotRegisteredTool" }), CancellationToken.None)); Assert.Equal("Unknown tool: NotRegisteredTool", exception.Message); @@ -261,7 +261,7 @@ public async Task Throws_Exception_Missing_Parameter() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var exception = await Assert.ThrowsAsync(async () => await options.CallToolHandler!(new(Mock.Of(), new() { Name = "Echo" }), CancellationToken.None)); Assert.Equal("Missing required argument 'message'.", exception.Message); @@ -291,7 +291,7 @@ public async Task Register_Tools_From_Current_Assembly() _builder.Object.WithTools(); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); @@ -316,7 +316,7 @@ public async Task Recognizes_Parameter_Types() _builder.Object.WithTools(typeof(EchoTool)); var serviceProvider = _services.BuildServiceProvider(); - var options = serviceProvider.GetRequiredService>().Value; + var options = serviceProvider.GetRequiredService>().Value; var result = await options.ListToolsHandler!(new(Mock.Of(), new()), CancellationToken.None); Assert.NotNull(result); diff --git a/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs b/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs index 95f5891..aa8d37b 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerDelegatesTests.cs @@ -1,54 +1,43 @@ -using McpDotNet.Protocol.Messages; -using McpDotNet.Protocol.Types; +using McpDotNet.Protocol.Types; using McpDotNet.Server; namespace McpDotNet.Tests.Server; -public class McpServerDelegatesTests +public class McpServerHandlerTests { [Fact] - public void Applies_All_Given_Delegates() + public void AllPropertiesAreSettable() { - var container = new McpServerDelegates(); - var server = new ExposeSetHandlersServer(); - - container.ListToolsHandler = (p, c) => Task.FromResult(new ListToolsResult()); - container.CallToolHandler = (p, c) => Task.FromResult(new CallToolResponse()); - container.ListPromptsHandler = (p, c) => Task.FromResult(new ListPromptsResult()); - container.GetPromptHandler = (p, c) => Task.FromResult(new GetPromptResult()); - container.ListResourcesHandler = (p, c) => Task.FromResult(new ListResourcesResult()); - container.ReadResourceHandler = (p, c) => Task.FromResult(new ReadResourceResult()); - container.GetCompletionHandler = (p, c) => Task.FromResult(new CompleteResult()); - container.SubscribeToResourcesHandler = (s, c) => Task.CompletedTask; - container.UnsubscribeFromResourcesHandler = (s, c) => Task.CompletedTask; - - container.Apply(server); - - Assert.Equal(container.ListToolsHandler, server.Handlers[OperationNames.ListTools]); - Assert.Equal(container.CallToolHandler, server.Handlers[OperationNames.CallTool]); - Assert.Equal(container.ListPromptsHandler, server.Handlers[OperationNames.ListPrompts]); - Assert.Equal(container.GetPromptHandler, server.Handlers[OperationNames.GetPrompt]); - Assert.Equal(container.ListResourcesHandler, server.Handlers[OperationNames.ListResources]); - Assert.Equal(container.ReadResourceHandler, server.Handlers[OperationNames.ReadResource]); - Assert.Equal(container.GetCompletionHandler, server.Handlers[OperationNames.GetCompletion]); - Assert.Equal(container.SubscribeToResourcesHandler, server.Handlers[OperationNames.SubscribeToResources]); - Assert.Equal(container.UnsubscribeFromResourcesHandler, server.Handlers[OperationNames.UnsubscribeFromResources]); - } - - private sealed class ExposeSetHandlersServer : IMcpServer - { - public Dictionary Handlers = []; - - public void SetOperationHandler(string operationName, Delegate handler) => Handlers[operationName] = handler; - - public ValueTask DisposeAsync() => default; - public bool IsInitialized => throw new NotImplementedException(); - public ClientCapabilities? ClientCapabilities => throw new NotImplementedException(); - public Implementation? ClientInfo => throw new NotImplementedException(); - public IServiceProvider? ServiceProvider => throw new NotImplementedException(); - public void AddNotificationHandler(string method, Func handler) => throw new NotImplementedException(); - public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default) => throw new NotImplementedException(); - public Task SendRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) where T : class => throw new NotImplementedException(); - public Task StartAsync(CancellationToken cancellationToken = default) => throw new NotImplementedException(); + var handlers = new McpServerHandlers(); + + Assert.Null(handlers.ListToolsHandler); + Assert.Null(handlers.CallToolHandler); + Assert.Null(handlers.ListPromptsHandler); + Assert.Null(handlers.GetPromptHandler); + Assert.Null(handlers.ListResourcesHandler); + Assert.Null(handlers.ReadResourceHandler); + Assert.Null(handlers.GetCompletionHandler); + Assert.Null(handlers.SubscribeToResourcesHandler); + Assert.Null(handlers.UnsubscribeFromResourcesHandler); + + handlers.ListToolsHandler = (p, c) => Task.FromResult(new ListToolsResult()); + handlers.CallToolHandler = (p, c) => Task.FromResult(new CallToolResponse()); + handlers.ListPromptsHandler = (p, c) => Task.FromResult(new ListPromptsResult()); + handlers.GetPromptHandler = (p, c) => Task.FromResult(new GetPromptResult()); + handlers.ListResourcesHandler = (p, c) => Task.FromResult(new ListResourcesResult()); + handlers.ReadResourceHandler = (p, c) => Task.FromResult(new ReadResourceResult()); + handlers.GetCompletionHandler = (p, c) => Task.FromResult(new CompleteResult()); + handlers.SubscribeToResourcesHandler = (s, c) => Task.CompletedTask; + handlers.UnsubscribeFromResourcesHandler = (s, c) => Task.CompletedTask; + + Assert.NotNull(handlers.ListToolsHandler); + Assert.NotNull(handlers.CallToolHandler); + Assert.NotNull(handlers.ListPromptsHandler); + Assert.NotNull(handlers.GetPromptHandler); + Assert.NotNull(handlers.ListResourcesHandler); + Assert.NotNull(handlers.ReadResourceHandler); + Assert.NotNull(handlers.GetCompletionHandler); + Assert.NotNull(handlers.SubscribeToResourcesHandler); + Assert.NotNull(handlers.UnsubscribeFromResourcesHandler); } } diff --git a/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs b/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs index 59fc5e3..be348fa 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerFactoryTests.cs @@ -1,8 +1,7 @@ using McpDotNet.Protocol.Transport; using McpDotNet.Protocol.Types; using McpDotNet.Server; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; +using Microsoft.Extensions.Logging.Abstractions; using Moq; namespace McpDotNet.Tests.Server; @@ -10,16 +9,12 @@ namespace McpDotNet.Tests.Server; public class McpServerFactoryTests { private readonly Mock _serverTransport; - private readonly Mock _loggerFactory; - private readonly Mock> _serverDelegates; private readonly McpServerOptions _options; private readonly IServiceProvider _serviceProvider; public McpServerFactoryTests() { _serverTransport = new Mock(); - _loggerFactory = new Mock(); - _serverDelegates = new Mock>(); _options = new McpServerOptions { ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, @@ -30,54 +25,26 @@ public McpServerFactoryTests() } [Fact] - public void Constructor_Should_Initialize_With_Valid_Parameters() + public async Task Create_Should_Initialize_With_Valid_Parameters() { // Arrange & Act - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider); + await using IMcpServer server = McpServerFactory.Create(_serverTransport.Object, _options, NullLoggerFactory.Instance); // Assert - Assert.NotNull(factory); + Assert.NotNull(server); } [Fact] public void Constructor_Throws_For_Null_ServerTransport() { // Arrange, Act & Assert - Assert.Throws(() => new McpServerFactory(null!, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider)); + Assert.Throws("serverTransport", () => McpServerFactory.Create(null!, _options, NullLoggerFactory.Instance)); } [Fact] public void Constructor_Throws_For_Null_Options() { // Arrange, Act & Assert - Assert.Throws(() => new McpServerFactory(_serverTransport.Object, null!, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider)); - } - - [Fact] - public void Constructor_Does_Not_Throw_For_Null_ServerDelegates() - { - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, null, _serviceProvider); - Assert.NotNull(factory); - } - - [Fact] - public void Constructor_Does_Not_Throw_For_Null_ServiceProvider() - { - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, null); - Assert.NotNull(factory); - } - - [Fact] - public void CreateServer_Return_IMcpServerInstance() - { - // Arrange - var factory = new McpServerFactory(_serverTransport.Object, _options, _loggerFactory.Object, _serverDelegates.Object, _serviceProvider); - - // Act - var server = factory.CreateServer(); - - // Assert - Assert.NotNull(server); - Assert.IsAssignableFrom(server); + Assert.Throws("serverOptions", () => McpServerFactory.Create(_serverTransport.Object, null!, NullLoggerFactory.Instance)); } } diff --git a/tests/mcpdotnet.Tests/Server/McpServerTests.cs b/tests/mcpdotnet.Tests/Server/McpServerTests.cs index 67d5545..39037ba 100644 --- a/tests/mcpdotnet.Tests/Server/McpServerTests.cs +++ b/tests/mcpdotnet.Tests/Server/McpServerTests.cs @@ -34,7 +34,7 @@ private static McpServerOptions CreateOptions(ServerCapabilities? capabilities = ServerInfo = new Implementation { Name = "TestServer", Version = "1.0" }, ProtocolVersion = "2024", InitializationTimeout = TimeSpan.FromSeconds(30), - Capabilities = capabilities + Capabilities = capabilities, }; } @@ -63,10 +63,13 @@ public void Constructor_Throws_For_Null_Options() } [Fact] - public void Constructor_Throws_For_Null_LoggerFactory() + public async Task Constructor_Does_Not_Throw_For_Null_Logger() { - // Arrange, Act & Assert - Assert.Throws(() => new McpServer(_serverTransport.Object, _options, null!, _serviceProvider)); + // Arrange & Act + await using var server = new McpServer(_serverTransport.Object, _options, null, _serviceProvider); + + // Assert + Assert.NotNull(server); } [Fact] @@ -226,103 +229,114 @@ public async Task Throws_Exception_If_Not_Connected() [Fact] public async Task Can_Handle_Ping_Requests() { - await Can_Handle_Requests(null, "ping", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); - }); + await Can_Handle_Requests( + serverCapabilities: null, + method: "ping", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + }); } [Fact] public async Task Can_Handle_Initialize_Requests() { - await Can_Handle_Requests(null, "initialize", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); + await Can_Handle_Requests( + serverCapabilities: null, + method: "initialize", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (InitializeResult)response; - Assert.Equal("TestServer", result.ServerInfo.Name); - Assert.Equal("1.0", result.ServerInfo.Version); - Assert.Equal("2024", result.ProtocolVersion); - }); + var result = (InitializeResult)response; + Assert.Equal("TestServer", result.ServerInfo.Name); + Assert.Equal("1.0", result.ServerInfo.Version); + Assert.Equal("2024", result.ProtocolVersion); + }); } [Fact] public async Task Can_Handle_Completion_Requests() { - await Can_Handle_Requests(null, "completion/complete", - configureServer: server => { }, - assertResult: response => - { - Assert.IsType(response); + await Can_Handle_Requests( + serverCapabilities: null, + method: "completion/complete", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); - Assert.Empty(result.Completion.Values); - Assert.Equal(0, result.Completion.Total); - Assert.False(result.Completion.HasMore); - }); + var result = (CompleteResult)response; + Assert.NotNull(result.Completion); + Assert.Empty(result.Completion.Values); + Assert.Equal(0, result.Completion.Total); + Assert.False(result.Completion.HasMore); + }); } [Fact] public async Task Can_Handle_Completion_Requests_With_Handler() { - await Can_Handle_Requests(null, "completion/complete", - configureServer: server => - { - server.SetGetCompletionHandler((request, ct) => - { - return Task.FromResult(new CompleteResult - { - Completion = new() - { - Values = ["test"], - Total = 2, - HasMore = true - } - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (CompleteResult)response; - Assert.NotNull(result.Completion); - Assert.NotEmpty(result.Completion.Values); - Assert.Equal("test", result.Completion.Values[0]); - Assert.Equal(2, result.Completion.Total); - Assert.True(result.Completion.HasMore); - }); + await Can_Handle_Requests( + serverCapabilities: null, + method: "completion/complete", + configureOptions: options => options with + { + GetCompletionHandler = (request, ct) => + Task.FromResult(new CompleteResult + { + Completion = new() + { + Values = ["test"], + Total = 2, + HasMore = true + } + }) + }, + assertResult: response => + { + Assert.IsType(response); + + var result = (CompleteResult)response; + Assert.NotNull(result.Completion); + Assert.NotEmpty(result.Completion.Values); + Assert.Equal("test", result.Completion.Values[0]); + Assert.Equal(2, result.Completion.Total); + Assert.True(result.Completion.HasMore); + }); } [Fact] public async Task Can_Handle_Resources_List_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Resources = new() }, "resources/list", - configureServer: server => - { - server.SetListResourcesHandler((request, ct) => - { - return Task.FromResult(new ListResourcesResult - { - Resources = [new() { Uri = "test", Name = "Test Resource" }] - }); - }); - - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (ListResourcesResult)response; - Assert.NotNull(result.Resources); - Assert.NotEmpty(result.Resources); - Assert.Equal("test", result.Resources[0].Uri); - }); + await Can_Handle_Requests( + new ServerCapabilities + { + Resources = new() + { + ListResourcesHandler = (request, ct) => + { + return Task.FromResult(new ListResourcesResult + { + Resources = [new() { Uri = "test", Name = "Test Resource" }] + }); + }, + ReadResourceHandler = (request, ct) => throw new NotImplementedException(), + } + }, + "resources/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + + var result = (ListResourcesResult)response; + Assert.NotNull(result.Resources); + Assert.NotEmpty(result.Resources); + Assert.Equal("test", result.Resources[0].Uri); + }); } [Fact] @@ -334,26 +348,32 @@ public async Task Can_Handle_Resources_List_Requests_Throws_Exception_If_No_Hand [Fact] public async Task Can_Handle_ResourcesRead_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Resources = new() }, "resources/read", - configureServer: server => - { - server.SetReadResourceHandler((request, ct) => - { - return Task.FromResult(new ReadResourceResult - { - Contents = [new() { Text = "test" }] - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); - - var result = (ReadResourceResult)response; - Assert.NotNull(result.Contents); - Assert.NotEmpty(result.Contents); - Assert.Equal("test", result.Contents[0].Text); - }); + await Can_Handle_Requests( + new ServerCapabilities + { + Resources = new() + { + ReadResourceHandler = (request, ct) => + { + return Task.FromResult(new ReadResourceResult + { + Contents = [new() { Text = "test" }] + }); + }, + ListResourcesHandler = (request, ct) => throw new NotImplementedException(), + } + }, + method: "resources/read", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); + + var result = (ReadResourceResult)response; + Assert.NotNull(result.Contents); + Assert.NotEmpty(result.Contents); + Assert.Equal("test", result.Contents[0].Text); + }); } [Fact] @@ -365,26 +385,32 @@ public async Task Can_Handle_Resources_Read_Requests_Throws_Exception_If_No_Hand [Fact] public async Task Can_Handle_List_Prompts_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Prompts = new() }, "prompts/list", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetListPromptsHandler((request, ct) => + Prompts = new() { - return Task.FromResult(new ListPromptsResult + ListPromptsHandler = (request, ct) => { - Prompts = [new() { Name = "test" }] - }); - }); + return Task.FromResult(new ListPromptsResult + { + Prompts = [new() { Name = "test" }] + }); + }, + GetPromptHandler = (request, ct) => throw new NotImplementedException(), + }, }, - assertResult: response => - { - Assert.IsType(response); + method: "prompts/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (ListPromptsResult)response; - Assert.NotNull(result.Prompts); - Assert.NotEmpty(result.Prompts); - Assert.Equal("test", result.Prompts[0].Name); - }); + var result = (ListPromptsResult)response; + Assert.NotNull(result.Prompts); + Assert.NotEmpty(result.Prompts); + Assert.Equal("test", result.Prompts[0].Name); + }); } [Fact] @@ -396,24 +422,24 @@ public async Task Can_Handle_List_Prompts_Requests_Throws_Exception_If_No_Handle [Fact] public async Task Can_Handle_Get_Prompts_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Prompts = new() }, "prompts/get", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetGetPromptHandler((request, ct) => + Prompts = new() { - return Task.FromResult(new GetPromptResult - { - Description = "test" - }); - }); + GetPromptHandler = (request, ct) => Task.FromResult(new GetPromptResult { Description = "test" }), + ListPromptsHandler = (request, ct) => throw new NotImplementedException(), + } }, - assertResult: response => - { - Assert.IsType(response); + method: "prompts/get", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (GetPromptResult)response; - Assert.Equal("test", result.Description); - }); + var result = (GetPromptResult)response; + Assert.Equal("test", result.Description); + }); } [Fact] @@ -425,25 +451,31 @@ public async Task Can_Handle_Get_Prompts_Requests_Throws_Exception_If_No_Handler [Fact] public async Task Can_Handle_List_Tools_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Tools = new() }, "tools/list", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetListToolsHandler((request, ct) => + Tools = new() { - return Task.FromResult(new ListToolsResult + ListToolsHandler = (request, ct) => { - Tools = [new() { Name = "test" }] - }); - }); + return Task.FromResult(new ListToolsResult + { + Tools = [new() { Name = "test" }] + }); + }, + CallToolHandler = (request, ct) => throw new NotImplementedException(), + } }, - assertResult: response => - { - Assert.IsType(response); + method: "tools/list", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (ListToolsResult)response; - Assert.NotEmpty(result.Tools); - Assert.Equal("test", result.Tools[0].Name); - }); + var result = (ListToolsResult)response; + Assert.NotEmpty(result.Tools); + Assert.Equal("test", result.Tools[0].Name); + }); } [Fact] @@ -455,25 +487,31 @@ public async Task Can_Handle_List_Tools_Requests_Throws_Exception_If_No_Handler_ [Fact] public async Task Can_Handle_Call_Tool_Requests() { - await Can_Handle_Requests(new ServerCapabilities { Tools = new() }, "tools/call", - configureServer: server => + await Can_Handle_Requests( + new ServerCapabilities { - server.SetCallToolHandler((request, ct) => + Tools = new() { - return Task.FromResult(new CallToolResponse + CallToolHandler = (request, ct) => { - Content = [new Content { Text = "test" }] - }); - }); - }, - assertResult: response => - { - Assert.IsType(response); + return Task.FromResult(new CallToolResponse + { + Content = [new Content { Text = "test" }] + }); + }, + ListToolsHandler = (request, ct) => throw new NotImplementedException(), + } + }, + method: "tools/call", + configureOptions: null, + assertResult: response => + { + Assert.IsType(response); - var result = (CallToolResponse)response; - Assert.NotEmpty(result.Content); - Assert.Equal("test", result.Content[0].Text); - }); + var result = (CallToolResponse)response; + Assert.NotEmpty(result.Content); + Assert.Equal("test", result.Content[0].Text); + }); } [Fact] @@ -482,17 +520,19 @@ public async Task Can_Handle_Call_Tool_Requests_Throws_Exception_If_No_Handler_A await Throws_Exception_If_No_Handler_Assigned(new ServerCapabilities { Tools = new() }, "tools/call", "CallTool handler not configured"); } - private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Action configureServer, Action assertResult) + private async Task Can_Handle_Requests(ServerCapabilities? serverCapabilities, string method, Func? configureOptions, Action assertResult) { await using var transport = new TestServerTransport(); - var options = serverCapabilities == null ? _options : CreateOptions(serverCapabilities); + var options = CreateOptions(serverCapabilities); + if (configureOptions is not null) + { + options = configureOptions(options); + } await using var server = new McpServer(transport, options, _loggerFactory.Object, _serviceProvider); await server.StartAsync(); - configureServer(server); - var receivedMessage = new TaskCompletionSource(); transport.OnMessageSent = (message) => @@ -521,32 +561,6 @@ private async Task Throws_Exception_If_No_Handler_Assigned(ServerCapabilities se await using var transport = new TestServerTransport(); var options = CreateOptions(serverCapabilities); - await using var server = new McpServer(transport, options, _loggerFactory.Object, _serviceProvider); - - await server.StartAsync(); - - var receivedMessage = new TaskCompletionSource(); - - transport.OnMessageSent = (message) => - { - if (message is JsonRpcError response && response.Id.AsNumber == 55) - receivedMessage.SetResult(response); - }; - - await transport.SendMessageAsync( - new JsonRpcRequest - { - Method = method, - Id = RequestId.FromNumber(55) - } - ); - - var response = await receivedMessage.Task.WaitAsync(TimeSpan.FromSeconds(1)); - Assert.NotNull(response); - Assert.IsType(response); - - var result = (JsonRpcError)response; - Assert.NotNull(result.Error); - Assert.Equal(expectedError, result.Error.Message); + Assert.Throws(() => McpServerFactory.Create(transport, options, _loggerFactory.Object, _serviceProvider)); } } diff --git a/tests/mcpdotnet.Tests/SseIntegrationTests.cs b/tests/mcpdotnet.Tests/SseIntegrationTests.cs index e291fa9..27748f5 100644 --- a/tests/mcpdotnet.Tests/SseIntegrationTests.cs +++ b/tests/mcpdotnet.Tests/SseIntegrationTests.cs @@ -36,14 +36,8 @@ public async Task ConnectAndReceiveMessage_InMemoryServer() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -82,14 +76,8 @@ public async Task ConnectAndReceiveMessage_EverythingServerWithSse() Location = $"http://localhost:{port}/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Create client and run tests - var client = await factory.GetClientAsync("everything"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); var tools = await client.ListToolsAsync().ToListAsync(); // assert @@ -110,19 +98,6 @@ public async Task Sampling_Sse_EverythingServer() await using var fixture = new EverythingSseServerFixture(port); await fixture.StartAsync(); - var defaultOptions = new McpClientOptions - { - ClientInfo = new() - { - Name = "IntegrationTestClient", - Version = "1.0.0" - }, - Capabilities = new() - { - Sampling = new() - } - }; - var defaultConfig = new McpServerConfig { Id = "everything", @@ -132,29 +107,37 @@ public async Task Sampling_Sse_EverythingServer() Location = $"http://localhost:{port}/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - var client = await factory.GetClientAsync("everything"); - - // Set up the sampling handler int samplingHandlerCalls = 0; - client.SetSamplingHandler((_, _) => + var defaultOptions = new McpClientOptions { - samplingHandlerCalls++; - return Task.FromResult(new CreateMessageResult + ClientInfo = new() { - Model = "test-model", - Role = "assistant", - Content = new Content + Name = "IntegrationTestClient", + Version = "1.0.0" + }, + Capabilities = new() + { + Sampling = new() { - Type = "text", - Text = "Test response" - } - }); - }); + SamplingHandler = (_, _) => + { + samplingHandlerCalls++; + return Task.FromResult(new CreateMessageResult + { + Model = "test-model", + Role = "assistant", + Content = new Content + { + Type = "text", + Text = "Test response" + } + }); + }, + }, + }, + }; + + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Call the server's sampleLLM tool which should trigger our sampling handler var result = await client.CallToolAsync( @@ -200,14 +183,8 @@ public async Task ConnectAndReceiveMessage_InMemoryServer_WithFullEndpointEventU Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -245,14 +222,8 @@ public async Task ConnectAndReceiveNotification_InMemoryServer() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); // Wait for SSE connection to be established await server.WaitForConnectionAsync(TimeSpan.FromSeconds(10)); @@ -300,14 +271,8 @@ public async Task ConnectTwice_Throws() Location = "http://localhost:5000/sse" }; - using var factory = new McpClientFactory( - [defaultConfig], - defaultOptions, - loggerFactory - ); - // Act - var client = await factory.GetClientAsync("test_server"); + var client = await McpClientFactory.CreateAsync(defaultConfig, defaultOptions, loggerFactory: loggerFactory); var mcpClient = (McpClient)client; var transport = (SseClientTransport)mcpClient.Transport;