Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/fastmcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ class Client(Generic[ClientTransportT]):
timeout: Optional timeout for requests (seconds or timedelta)
init_timeout: Optional timeout for initial connection (seconds or timedelta).
Set to 0 to disable. If None, uses the value in the FastMCP global settings.
concurrent_startup: If True and transport is an MCPConfig with multiple servers,
starts all servers concurrently instead of sequentially. Significantly speeds
up initialization when dealing with many servers. Default is False.

Examples:
```python
Expand All @@ -184,6 +187,17 @@ class Client(Generic[ClientTransportT]):

# Call a tool
result = await client.call_tool("my_tool", {"param": "value"})

# Connect to multiple MCP servers with concurrent startup
config = {
"mcpServers": {
"server1": {"command": "python", "args": ["server1.py"]},
"server2": {"command": "python", "args": ["server2.py"]},
}
}
client = Client(config, concurrent_startup=True)
async with client:
tools = await client.list_tools() # All servers started concurrently
```
"""

Expand Down Expand Up @@ -259,10 +273,15 @@ def __init__(
init_timeout: datetime.timedelta | float | int | None = None,
client_info: mcp.types.Implementation | None = None,
auth: httpx.Auth | Literal["oauth"] | str | None = None,
concurrent_startup: bool = False,
) -> None:
self.name = name or self.generate_name()

self.transport = cast(ClientTransportT, infer_transport(transport))
# Pass concurrent_startup to infer_transport for MCPConfig
self.transport = cast(
ClientTransportT,
infer_transport(transport, concurrent_startup=concurrent_startup),
)
if auth is not None:
self.transport._set_auth(auth)

Expand Down
35 changes: 31 additions & 4 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,14 +969,22 @@ class MCPConfigTransport(ClientTransport):
```
"""

def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True):
def __init__(
self,
config: MCPConfig | dict,
name_as_prefix: bool = True,
concurrent_startup: bool = False,
):
from fastmcp.utilities.mcp_config import mcp_config_to_servers_and_transports

if isinstance(config, dict):
config = MCPConfig.from_dict(config)
self.config = config
self.concurrent_startup = concurrent_startup

self._underlying_transports: list[ClientTransport] = []
self._server_names: list[str] = []
self._warmup_done = False

# if there are no servers, raise an error
if len(self.config.mcpServers) == 0:
Expand All @@ -986,6 +994,7 @@ def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True):
elif len(self.config.mcpServers) == 1:
self.transport = next(iter(self.config.mcpServers.values())).to_transport()
self._underlying_transports.append(self.transport)
self._server_names.append(next(iter(self.config.mcpServers.keys())))

# otherwise create a composite client
else:
Expand All @@ -996,6 +1005,7 @@ def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True):
self.config
):
self._underlying_transports.append(transport)
self._server_names.append(name)
self._composite_server.mount(
server, prefix=name if name_as_prefix else None
)
Expand All @@ -1006,6 +1016,17 @@ def __init__(self, config: MCPConfig | dict, name_as_prefix: bool = True):
async def connect_session(
self, **session_kwargs: Unpack[SessionKwargs]
) -> AsyncIterator[ClientSession]:
# Warm up transports concurrently on first connection
if self.concurrent_startup and not self._warmup_done:
from fastmcp.utilities.mcp_config import warm_up_mcp_config_transports

await warm_up_mcp_config_transports(
self._underlying_transports,
server_names=self._server_names,
show_startup_logs=True,
)
self._warmup_done = True

async with self.transport.connect_session(**session_kwargs) as session:
yield session

Expand Down Expand Up @@ -1064,6 +1085,7 @@ def infer_transport(
| MCPConfig
| dict[str, Any]
| str,
concurrent_startup: bool = False,
) -> ClientTransport:
"""
Infer the appropriate transport type from the given transport argument.
Expand All @@ -1087,6 +1109,10 @@ def infer_transport(
`servername_toolname` for tools and `protocol://servername/path` for resources.
If the MCPConfig contains only one server, a direct connection is established without prefixing.

Args:
concurrent_startup: If True and transport is an MCPConfig with multiple servers,
all servers will be started concurrently for faster initialization.

Examples:
```python
# Connect to a local Python script
Expand All @@ -1095,14 +1121,14 @@ def infer_transport(
# Connect to a remote server via HTTP
transport = infer_transport("http://example.com/mcp")

# Connect to multiple servers using MCPConfig
# Connect to multiple servers using MCPConfig with concurrent startup
config = {
"mcpServers": {
"weather": {"url": "http://weather.example.com/mcp"},
"calendar": {"url": "http://calendar.example.com/mcp"}
}
}
transport = infer_transport(config)
transport = infer_transport(config, concurrent_startup=True)
```
"""

Expand Down Expand Up @@ -1140,7 +1166,8 @@ def infer_transport(
# if the transport is a config dict or MCPConfig
elif isinstance(transport, dict | MCPConfig):
inferred_transport = MCPConfigTransport(
config=cast(dict | MCPConfig, transport)
config=cast(dict | MCPConfig, transport),
concurrent_startup=concurrent_startup,
)

# the transport is an unknown type
Expand Down
126 changes: 121 additions & 5 deletions src/fastmcp/utilities/mcp_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
from typing import Any

from exceptiongroup import ExceptionGroup

from fastmcp.client.transports import (
ClientTransport,
SSETransport,
Expand All @@ -24,20 +27,134 @@ def mcp_config_to_servers_and_transports(
]


async def warm_up_mcp_config_transports(
transports: list[ClientTransport],
server_names: list[str] | None = None,
show_startup_logs: bool = True,
) -> None:
"""Pre-connect all transports concurrently.

Connects all stdio transports in parallel for faster initialization. When enabled,
captures and displays each server's startup logs sequentially for readability.

Args:
transports: List of ClientTransport objects to warm up
server_names: Optional list of server names for log formatting
show_startup_logs: If True, displays buffered startup logs sequentially
"""
import tempfile
from pathlib import Path

from fastmcp.client.transports import StdioTransport

if not transports:
return

server_names = server_names or [f"server_{i}" for i in range(len(transports))]

async def connect_with_log_capture(
transport: ClientTransport, name: str, index: int
) -> tuple[int, str, Exception | None]:
"""Connect a transport and capture its startup logs."""
if not isinstance(transport, StdioTransport):
return (index, "", None)

original_log_file = transport.log_file
temp_log_path = None

try:
if show_startup_logs:
with tempfile.NamedTemporaryFile(
mode="w", delete=False, suffix=f"_{name}.log"
) as f:
temp_log_path = Path(f.name)
transport.log_file = temp_log_path

await transport.connect()

logs = ""
if temp_log_path and temp_log_path.exists():
logs = temp_log_path.read_text()
temp_log_path.unlink()
return (index, logs, None)

except Exception as e:
logs = ""
if temp_log_path and temp_log_path.exists():
try:
logs = temp_log_path.read_text()
temp_log_path.unlink()
except Exception:
pass
return (index, logs, e)

finally:
transport.log_file = original_log_file

# Connect all transports concurrently
tasks = [
asyncio.create_task(connect_with_log_capture(t, name, i))
for i, (t, name) in enumerate(zip(transports, server_names, strict=False))
]
results = await asyncio.gather(*tasks, return_exceptions=False)

# Display logs sequentially
if show_startup_logs:
_display_startup_logs(results, server_names)

# Raise if any failed
errors = [error for _, _, error in results if error is not None]
if errors:
raise ExceptionGroup("Failed to start MCP servers", errors)


def _display_startup_logs(
results: list[tuple[int, str, Exception | None]], server_names: list[str]
) -> None:
"""Display captured startup logs in a readable format."""
import sys

for index, logs, error in sorted(results, key=lambda x: x[0]):
name = server_names[index]

if logs or error:
print(
f"\n{'=' * 60}\n[{name}] Startup\n{'=' * 60}",
file=sys.stderr,
flush=True,
)

if logs:
print(logs, file=sys.stderr, end="", flush=True)

status = "❌ ERROR: " + str(error) if error else "✓ Connected successfully"
print(f"\n{status}", file=sys.stderr, flush=True)

# Summary
total = len(results)
failed = sum(1 for _, _, error in results if error)
succeeded = total - failed

print(
f"\n{'=' * 60}\nStartup Summary: {succeeded}/{total} servers connected",
file=sys.stderr,
flush=True,
)
if failed > 0:
print(f"⚠️ {failed} server(s) failed to start", file=sys.stderr, flush=True)
print(f"{'=' * 60}\n", file=sys.stderr, flush=True)


def mcp_server_type_to_servers_and_transports(
name: str,
mcp_server: MCPServerTypes,
) -> tuple[str, FastMCP[Any], ClientTransport]:
"""A utility function to convert each entry of an MCP Config into a transport and server."""

from fastmcp.mcp_config import (
TransformingRemoteMCPServer,
TransformingStdioMCPServer,
)

server: FastMCP[Any]
transport: ClientTransport

client_name = ProxyClient.generate_name(f"MCP_{name}")
server_name = FastMCPProxy.generate_name(f"MCP_{name}")

Expand All @@ -50,7 +167,6 @@ def mcp_server_type_to_servers_and_transports(
client: ProxyClient[StreamableHttpTransport | SSETransport | StdioTransport] = (
ProxyClient(transport=transport, name=client_name)
)

server = FastMCP.as_proxy(name=server_name, backend=client)

return name, server, transport
Loading
Loading