diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index be01e7a643..8265b35fdb 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -33,7 +33,15 @@ import fastmcp from fastmcp.client.auth.bearer import BearerAuth from fastmcp.client.auth.oauth import OAuth -from fastmcp.mcp_config import MCPConfig, infer_transport_type_from_url +from fastmcp.mcp_config import ( + MCPConfig, + MCPServerTypes, + RemoteMCPServer, + StdioMCPServer, + TransformingRemoteMCPServer, + TransformingStdioMCPServer, + infer_transport_type_from_url, +) from fastmcp.server.dependencies import get_http_headers from fastmcp.server.server import FastMCP from fastmcp.server.tasks.capabilities import get_task_capabilities @@ -959,7 +967,6 @@ class MCPConfigTransport(ClientTransport): Examples: ```python from fastmcp import Client - from fastmcp.utilities.mcp_config import MCPConfig # Create a config with multiple servers config = { @@ -989,47 +996,95 @@ class MCPConfigTransport(ClientTransport): """ def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True): - from fastmcp.utilities.mcp_config import mcp_config_to_servers_and_transports - if isinstance(config, dict): config = MCPConfig.from_dict(config) self.config = config + self.name_as_prefix = name_as_prefix + self._transports: list[ClientTransport] = [] - self._underlying_transports: list[ClientTransport] = [] - - # if there are no servers, raise an error - if len(self.config.mcpServers) == 0: + if not self.config.mcpServers: raise ValueError("No MCP servers defined in the config") - # if there's exactly one server, create a client for that server - elif len(self.config.mcpServers) == 1: + # For single server, create transport eagerly so it can be inspected + if len(self.config.mcpServers) == 1: self.transport = next(iter(self.config.mcpServers.values())).to_transport() - self._underlying_transports.append(self.transport) - - # otherwise create a composite client - else: - name = FastMCP.generate_name("MCPRouter") - self._composite_server = FastMCP[Any](name=name) - - for name, server, transport in mcp_config_to_servers_and_transports( - self.config - ): - self._underlying_transports.append(transport) - self._composite_server.mount( - server, namespace=name if name_as_prefix else None - ) - - self.transport = FastMCPTransport(mcp=self._composite_server) + self._transports.append(self.transport) @contextlib.asynccontextmanager async def connect_session( self, **session_kwargs: Unpack[SessionKwargs] ) -> AsyncIterator[ClientSession]: - async with self.transport.connect_session(**session_kwargs) as session: + # Single server - delegate directly to pre-created transport + if len(self.config.mcpServers) == 1: + async with self.transport.connect_session(**session_kwargs) as session: + yield session + return + + # Multiple servers - create composite with mounted proxies + # Close any previous transports from prior connections to avoid leaking + for t in self._transports: + await t.close() + self._transports = [] + timeout = session_kwargs.get("read_timeout_seconds") + composite = FastMCP[Any](name="MCPRouter") + + try: + for name, server_config in self.config.mcpServers.items(): + transport, proxy = self._create_proxy(name, server_config, timeout) + self._transports.append(transport) + composite.mount(proxy, namespace=name if self.name_as_prefix else None) + except Exception: + # Clean up any transports created before the failure + for t in self._transports: + await t.close() + self._transports = [] + raise + + async with FastMCPTransport(mcp=composite).connect_session( + **session_kwargs + ) as session: yield session + def _create_proxy( + self, + name: str, + config: MCPServerTypes, + timeout: datetime.timedelta | None, + ) -> tuple[ClientTransport, FastMCP[Any]]: + """Create underlying transport and proxy server for a single backend.""" + # Import here to avoid circular dependency + from fastmcp.server.providers.proxy import ProxyClient + + tool_transforms = None + include_tags = None + exclude_tags = None + + # Handle transforming servers - call base class to_transport() for underlying transport + if isinstance(config, TransformingStdioMCPServer): + transport = StdioMCPServer.to_transport(config) + tool_transforms = config.tools + include_tags = config.include_tags + exclude_tags = config.exclude_tags + elif isinstance(config, TransformingRemoteMCPServer): + transport = RemoteMCPServer.to_transport(config) + tool_transforms = config.tools + include_tags = config.include_tags + exclude_tags = config.exclude_tags + else: + transport = config.to_transport() + + client = ProxyClient(transport=transport, timeout=timeout) + proxy = FastMCP.as_proxy( + name=f"Proxy-{name}", + backend=client, + tool_transformations=tool_transforms, + include_tags=include_tags, + exclude_tags=exclude_tags, + ) + return transport, proxy + async def close(self): - for transport in self._underlying_transports: + for transport in self._transports: await transport.close() def __repr__(self) -> str: diff --git a/src/fastmcp/utilities/mcp_config.py b/src/fastmcp/utilities/mcp_config.py deleted file mode 100644 index 4c54131915..0000000000 --- a/src/fastmcp/utilities/mcp_config.py +++ /dev/null @@ -1,56 +0,0 @@ -from typing import Any - -from fastmcp.client.transports import ( - ClientTransport, - SSETransport, - StdioTransport, - StreamableHttpTransport, -) -from fastmcp.mcp_config import ( - MCPConfig, - MCPServerTypes, -) -from fastmcp.server.providers.proxy import FastMCPProxy, ProxyClient -from fastmcp.server.server import FastMCP - - -def mcp_config_to_servers_and_transports( - config: MCPConfig, -) -> list[tuple[str, FastMCP[Any], ClientTransport]]: - """A utility function to convert each entry of an MCP Config into a transport and server.""" - return [ - mcp_server_type_to_servers_and_transports(name, mcp_server) - for name, mcp_server in config.mcpServers.items() - ] - - -def mcp_server_type_to_servers_and_transports( - name: str, - mcp_server: MCPServerTypes, -) -> tuple[str, FastMCP[Any], ClientTransport]: - """A utility function to convert each entry of an MCP Config into a transport and server.""" - - from fastmcp.mcp_config import ( - TransformingRemoteMCPServer, - TransformingStdioMCPServer, - ) - - server: FastMCP[Any] - transport: ClientTransport - - client_name = ProxyClient.generate_name(f"MCP_{name}") - server_name = FastMCPProxy.generate_name(f"MCP_{name}") - - if isinstance(mcp_server, TransformingRemoteMCPServer | TransformingStdioMCPServer): - server, transport = mcp_server._to_server_and_underlying_transport( - server_name=server_name, client_name=client_name - ) - else: - transport = mcp_server.to_transport() - client: ProxyClient[StreamableHttpTransport | SSETransport | StdioTransport] = ( - ProxyClient(transport=transport, name=client_name) - ) - - server = FastMCP.as_proxy(name=server_name, backend=client) - - return name, server, transport diff --git a/tests/client/test_client.py b/tests/client/test_client.py index b0676561a0..36691e111c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1110,9 +1110,8 @@ def test_infer_composite_client(self): } transport = infer_transport(config) assert isinstance(transport, MCPConfigTransport) - assert isinstance(transport.transport, FastMCPTransport) - # 3 providers: LocalProvider (always first) + 2 mounted MCP servers - assert len(cast(FastMCP, transport.transport.server)._providers) == 3 + # Multi-server configs create composite server at connect time + assert len(transport.config.mcpServers) == 2 def test_infer_fastmcp_server(self, fastmcp_server): """FastMCP server instances should infer to FastMCPTransport.""" diff --git a/tests/test_mcp_config.py b/tests/test_mcp_config.py index f927b623ac..0c3fc3e1ad 100644 --- a/tests/test_mcp_config.py +++ b/tests/test_mcp_config.py @@ -6,12 +6,16 @@ import sys import tempfile from collections.abc import AsyncGenerator +from datetime import timedelta from pathlib import Path from typing import Any +from unittest.mock import AsyncMock, patch import psutil import pytest +from mcp.types import TextContent +from fastmcp import FastMCP from fastmcp.client.auth.bearer import BearerAuth from fastmcp.client.auth.oauth import OAuthClientProvider from fastmcp.client.client import Client @@ -774,6 +778,127 @@ async def elicitation_handler(message, response_type, params, ctx): assert result.data == 42 +async def test_multi_server_config_transport(tmp_path: Path): + """ + Tests that MCPConfigTransport properly handles multi-server configurations. + + Related to https://github.com/jlowin/fastmcp/issues/2802 - verifies the + refactored architecture creates composite servers correctly. + """ + server_script = inspect.cleandoc(""" + from fastmcp import FastMCP + + mcp = FastMCP() + + @mcp.tool + def greet(name: str) -> str: + return f"Hello, {name}!" + + if __name__ == '__main__': + mcp.run() + """) + + script_path = tmp_path / "greet_server.py" + script_path.write_text(server_script) + + config = { + "mcpServers": { + "server1": { + "command": "python", + "args": [str(script_path)], + }, + "server2": { + "command": "python", + "args": [str(script_path)], + }, + } + } + + # Create client with multiple servers + client = Client(config) + assert isinstance(client.transport, MCPConfigTransport) + + # Verify both servers are accessible via prefixed tool names + async with client: + tools = await client.list_tools() + tool_names = [t.name for t in tools] + assert "server1_greet" in tool_names + assert "server2_greet" in tool_names + + # Call tools on both servers + result1 = await client.call_tool("server1_greet", {"name": "World"}) + assert isinstance(result1.content[0], TextContent) + assert "Hello, World!" in result1.content[0].text + + result2 = await client.call_tool("server2_greet", {"name": "FastMCP"}) + assert isinstance(result2.content[0], TextContent) + assert "Hello, FastMCP!" in result2.content[0].text + + +async def test_multi_server_timeout_propagation(): + """Test that timeout is correctly propagated to proxy clients in multi-server configs.""" + # Create a config with multiple servers + config = MCPConfig( + mcpServers={ + "server1": StdioMCPServer(command="echo", args=["test"]), + "server2": StdioMCPServer(command="echo", args=["test"]), + } + ) + + transport = MCPConfigTransport(config) + timeout = timedelta(seconds=42) + + # Patch _create_proxy to verify timeout is passed correctly + with ( + patch("fastmcp.client.transports.FastMCP.as_proxy") as mock_as_proxy, + patch.object( + transport, "_create_proxy", wraps=transport._create_proxy + ) as mock_create_proxy, + ): + # Make as_proxy return a mock FastMCP + mock_proxy = FastMCP(name="MockProxy") + mock_as_proxy.return_value = mock_proxy + + # Mock connect_session on FastMCPTransport to avoid actual connection + with patch( + "fastmcp.client.transports.FastMCPTransport.connect_session" + ) as mock_connect: + mock_session = AsyncMock() + mock_connect.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_connect.return_value.__aexit__ = AsyncMock(return_value=None) + + async with transport.connect_session(read_timeout_seconds=timeout): + pass + + # Verify _create_proxy was called with the timeout for each server + assert mock_create_proxy.call_count == 2 + for call in mock_create_proxy.call_args_list: + _, kwargs = call.args, call.kwargs if call.kwargs else {} + # Third positional arg is timeout + call_timeout = call[0][2] if len(call[0]) > 2 else kwargs.get("timeout") + assert call_timeout == timeout, ( + f"Expected timeout {timeout}, got {call_timeout}" + ) + + +async def test_single_server_config_transport(): + """Test that single-server configs delegate directly without creating a composite.""" + config = MCPConfig( + mcpServers={ + "only_server": StdioMCPServer(command="echo", args=["test"]), + } + ) + + transport = MCPConfigTransport(config) + + # Single server should have transport created eagerly (not at connect time) + assert hasattr(transport, "transport") + assert isinstance(transport.transport, StdioTransport) + + # _transports should already contain the single transport + assert len(transport._transports) == 1 + + def sample_tool_fn(arg1: int, arg2: str) -> str: return f"Hello, world! {arg1} {arg2}"