Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 67 additions & 17 deletions src/mcpadapt/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ def __init__(
adapter: ToolAdapter,
connect_timeout: int = 30,
client_session_timeout_seconds: float | timedelta | None = 5,
fail_fast: bool = True,
on_connection_error: Callable[[Any, Exception], None] | None = None,
):
"""
Manage the MCP server / client lifecycle and expose tools adapted with the adapter.
Expand All @@ -197,9 +199,14 @@ def __init__(
adapter (ToolAdapter): Adapter to use to convert MCP tools call into agentic framework tools.
connect_timeout (int): Connection timeout in seconds to the mcp server (default is 30s).
client_session_timeout_seconds: Timeout for MCP ClientSession calls
fail_fast (bool): If True, any connection failure will cause the entire adapter to fail.
If False, failed connections are skipped and only successful connections are used.
Default is True to maintain backward compatibility.
on_connection_error: Optional callback function called when a connection fails.
Receives (server_params, exception) as arguments.

Raises:
TimeoutError: When the connection to the mcp server time out.
TimeoutError: When the connection to the mcp server time out and fail_fast=True.
"""

if isinstance(serverparams, list):
Expand All @@ -208,6 +215,11 @@ def __init__(
self.serverparams = [serverparams]

self.adapter = adapter
self.fail_fast = fail_fast
self.on_connection_error = on_connection_error

# Track failed connections for transparency
self.failed_connections: list[tuple[Any, Exception]] = []

# session and tools get set by the async loop during initialization.
self.sessions: list[ClientSession] = []
Expand All @@ -229,13 +241,31 @@ def _run_loop(self):

async def setup():
async with AsyncExitStack() as stack:
connections = [
await stack.enter_async_context(
mcptools(params, self.client_session_timeout_seconds)
)
for params in self.serverparams
]
self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)]
connections = []

# Try to connect to each server individually for better fault tolerance
for params in self.serverparams:
try:
connection = await stack.enter_async_context(
mcptools(params, self.client_session_timeout_seconds)
)
connections.append(connection)
except Exception as e:
self.failed_connections.append((params, e))

if self.on_connection_error:
self.on_connection_error(params, e)

if self.fail_fast:
raise
else:
pass

if not connections and not self.fail_fast:
self.sessions, self.mcp_tools = [], []
elif connections:
self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)]

self.ready.set() # Signal initialization is complete
await asyncio.Event().wait() # Keep session alive until stopped

Expand All @@ -257,9 +287,13 @@ def tools(self) -> list[Any]:
see :meth:`atools`.

"""
if not self.sessions:
if not self.sessions and not self.failed_connections:
raise RuntimeError("Session not initialized")

if not self.sessions:
# Only failed connections, no successful ones
return []

def _sync_call_tool(
session, name: str, arguments: dict | None = None
) -> mcp.types.CallToolResult:
Expand Down Expand Up @@ -337,14 +371,30 @@ async def atools(self) -> list[Any]:
async def __aenter__(self) -> list[Any]:
self._ctxmanager = AsyncExitStack()

connections = [
await self._ctxmanager.enter_async_context(
mcptools(params, self.client_session_timeout_seconds)
)
for params in self.serverparams
]

self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)]
connections = []

# Try to connect to each server individually for better fault tolerance
for params in self.serverparams:
try:
connection = await self._ctxmanager.enter_async_context(
mcptools(params, self.client_session_timeout_seconds)
)
connections.append(connection)
except Exception as e:
self.failed_connections.append((params, e))

if self.on_connection_error:
self.on_connection_error(params, e)

if self.fail_fast:
raise
else:
pass

if not connections and not self.fail_fast:
self.sessions, self.mcp_tools = [], []
elif connections:
self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)]

return await self.atools()

Expand Down
Loading