Skip to content

Commit ffdecac

Browse files
authored
Improve robustness of MCP client remote tool calling (#840)
Adds configurable tool-call timeouts and backoff-based reconnect logic to MCP clients, propagates these settings through transport constructors and client config, enforces parent linkage for tool clients, updates lifecycle and error-handling flows, and adds tests covering reconnect/backoff, concurrency, and timeout propagation. Mainly focused on fixing the following scenarios, the MCP client should not hang but return comprehensive information - MCP server is not available (MCP client should try to re-connect before returning error information) - MCP server relaunched but missing some tools previously registered with the MCP client, and MCP client is calling that missing tool - MCP server takes too long to execute a tool Closes AIQ-1942 ## By Submitting this PR I confirm: - I am familiar with the [Contributing Guidelines](https://github.com/NVIDIA/NeMo-Agent-Toolkit/blob/develop/docs/source/resources/contributing.md). - We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. - Any contribution which contains commits that are not Signed-Off will not be accepted. - When the PR is ready for review, new or existing tests cover these changes. - When the PR is ready for review, the documentation is up to date with these changes. ## Summary by CodeRabbit - New Features - Added configurable automatic reconnect (enable, max attempts, initial/max backoff) and per-call tool timeouts; settings propagate across transports and to tool clients. - Refactor - Improved lifecycle and connection orchestration with single-retry wrapper, reconnection backoff, session state handling, and clearer error formatting for tool calls. - Tests - Added extensive tests for reconnect scenarios, backoff timing, concurrency safety, timeout propagation, and lifecycle/state transitions. Authors: - Yuchen Zhang (https://github.com/yczhang-nv) Approvers: - Zhongxuan (Daniel) Wang (https://github.com/zhongxuanwang-nv) - Will Killian (https://github.com/willkill07) - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah) URL: #840
1 parent a3c7296 commit ffdecac

File tree

3 files changed

+598
-44
lines changed

3 files changed

+598
-44
lines changed

packages/nvidia_nat_mcp/src/nat/plugins/mcp/client_base.py

Lines changed: 168 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515

1616
from __future__ import annotations
1717

18+
import asyncio
19+
import json
1820
import logging
1921
from abc import ABC
2022
from abc import abstractmethod
23+
from collections.abc import AsyncGenerator
2124
from contextlib import AsyncExitStack
2225
from contextlib import asynccontextmanager
23-
from typing import AsyncGenerator
26+
from datetime import timedelta
2427

2528
import httpx
2629

@@ -33,7 +36,10 @@
3336
from nat.authentication.interfaces import AuthProviderBase
3437
from nat.data_models.authentication import AuthReason
3538
from nat.data_models.authentication import AuthRequest
39+
from nat.plugins.mcp.exception_handler import convert_to_mcp_error
40+
from nat.plugins.mcp.exception_handler import format_mcp_error
3641
from nat.plugins.mcp.exception_handler import mcp_exception_handler
42+
from nat.plugins.mcp.exceptions import MCPError
3743
from nat.plugins.mcp.exceptions import MCPToolNotFoundError
3844
from nat.plugins.mcp.utils import model_from_mcp_schema
3945
from nat.utils.type_utils import override
@@ -85,7 +91,6 @@ def _is_tool_call_request(self, request: httpx.Request) -> bool:
8591
try:
8692
# Check if the request body contains a tool call
8793
if request.content:
88-
import json
8994
body = json.loads(request.content.decode('utf-8'))
9095
# Check if it's a JSON-RPC request with method "tools/call"
9196
if (isinstance(body, dict) and body.get("method") == "tools/call"):
@@ -131,7 +136,14 @@ class MCPBaseClient(ABC):
131136
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
132137
"""
133138

134-
def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProviderBase | None = None):
139+
def __init__(self,
140+
transport: str = 'streamable-http',
141+
auth_provider: AuthProviderBase | None = None,
142+
tool_call_timeout: timedelta = timedelta(seconds=5),
143+
reconnect_enabled: bool = True,
144+
reconnect_max_attempts: int = 2,
145+
reconnect_initial_backoff: float = 0.5,
146+
reconnect_max_backoff: float = 50.0):
135147
self._tools = None
136148
self._transport = transport.lower()
137149
if self._transport not in ['sse', 'stdio', 'streamable-http']:
@@ -145,6 +157,15 @@ def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProvid
145157
# Convert auth provider to AuthAdapter
146158
self._httpx_auth = AuthAdapter(auth_provider) if auth_provider else None
147159

160+
self._tool_call_timeout = tool_call_timeout
161+
162+
# Reconnect configuration
163+
self._reconnect_enabled = reconnect_enabled
164+
self._reconnect_max_attempts = reconnect_max_attempts
165+
self._reconnect_initial_backoff = reconnect_initial_backoff
166+
self._reconnect_max_backoff = reconnect_max_backoff
167+
self._reconnect_lock: asyncio.Lock = asyncio.Lock()
168+
148169
@property
149170
def transport(self) -> str:
150171
return self._transport
@@ -164,13 +185,14 @@ async def __aenter__(self):
164185
return self
165186

166187
async def __aexit__(self, exc_type, exc_value, traceback):
167-
if not self._exit_stack:
168-
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
188+
if self._exit_stack:
189+
# Close session
190+
await self._exit_stack.aclose()
191+
self._session = None
192+
self._exit_stack = None
169193

170-
# Close session
171-
await self._exit_stack.aclose()
172-
self._session = None
173-
self._exit_stack = None
194+
self._connection_established = False
195+
self._tools = None
174196

175197
@property
176198
def server_name(self):
@@ -181,30 +203,89 @@ def server_name(self):
181203

182204
@abstractmethod
183205
@asynccontextmanager
184-
async def connect_to_server(self):
206+
async def connect_to_server(self) -> AsyncGenerator[ClientSession, None]:
185207
"""
186208
Establish a session with an MCP server within an async context
187209
"""
188210
yield
189211

212+
async def _reconnect(self):
213+
"""
214+
Attempt to reconnect by tearing down and re-establishing the session.
215+
"""
216+
async with self._reconnect_lock:
217+
backoff = self._reconnect_initial_backoff
218+
attempt = 0
219+
last_error: Exception | None = None
220+
221+
while attempt in range(0, self._reconnect_max_attempts):
222+
attempt += 1
223+
try:
224+
# Close the existing stack and ClientSession
225+
if self._exit_stack:
226+
await self._exit_stack.aclose()
227+
# Create a fresh stack and session
228+
self._exit_stack = AsyncExitStack()
229+
self._session = await self._exit_stack.enter_async_context(self.connect_to_server())
230+
231+
self._connection_established = True
232+
self._tools = None
233+
234+
logger.info("Reconnected to MCP server (%s) on attempt %d", self.server_name, attempt)
235+
return
236+
237+
except Exception as e:
238+
last_error = e
239+
logger.warning("Reconnect attempt %d failed for %s: %s", attempt, self.server_name, e)
240+
await asyncio.sleep(min(backoff, self._reconnect_max_backoff))
241+
backoff = min(backoff * 2, self._reconnect_max_backoff)
242+
243+
# All attempts failed
244+
self._connection_established = False
245+
if last_error:
246+
raise last_error
247+
248+
async def _with_reconnect(self, coro):
249+
"""
250+
Execute an awaited operation, reconnecting once on errors.
251+
"""
252+
try:
253+
return await coro()
254+
except Exception as e:
255+
if self._reconnect_enabled:
256+
logger.warning("MCP Client operation failed. Attempting reconnect: %s", e)
257+
try:
258+
await self._reconnect()
259+
except Exception as reconnect_err:
260+
logger.error("MCP Client reconnect attempt failed: %s", reconnect_err)
261+
raise
262+
return await coro()
263+
raise
264+
190265
async def get_tools(self):
191266
"""
192267
Retrieve a dictionary of all tools served by the MCP server.
193268
Uses unauthenticated session for discovery.
194269
"""
195270

196-
if not self._session:
197-
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
271+
async def _get_tools():
272+
session = self._session
273+
return await session.list_tools()
198274

199-
response = await self._session.list_tools()
275+
try:
276+
response = await self._with_reconnect(_get_tools)
277+
except Exception as e:
278+
logger.warning("Failed to get tools: %s", e)
279+
raise
200280

201281
return {
202282
tool.name:
203283
MCPToolClient(session=self._session,
204284
tool_name=tool.name,
205285
tool_description=tool.description,
206286
tool_input_schema=tool.inputSchema,
207-
parent_client=self)
287+
parent_client=self,
288+
tool_call_timeout=self._tool_call_timeout)
208289
for tool in response.tools
209290
}
210291

@@ -235,11 +316,12 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
235316

236317
@mcp_exception_handler
237318
async def call_tool(self, tool_name: str, tool_args: dict | None):
238-
if not self._session:
239-
raise RuntimeError("MCPBaseClient not initialized. Use async with to initialize.")
240319

241-
result = await self._session.call_tool(tool_name, tool_args)
242-
return result
320+
async def _call_tool():
321+
session = self._session
322+
return await session.call_tool(tool_name, tool_args, read_timeout_seconds=self._tool_call_timeout)
323+
324+
return await self._with_reconnect(_call_tool)
243325

244326

245327
class MCPSSEClient(MCPBaseClient):
@@ -250,8 +332,19 @@ class MCPSSEClient(MCPBaseClient):
250332
url (str): The url of the MCP server
251333
"""
252334

253-
def __init__(self, url: str):
254-
super().__init__("sse")
335+
def __init__(self,
336+
url: str,
337+
tool_call_timeout: timedelta = timedelta(seconds=5),
338+
reconnect_enabled: bool = True,
339+
reconnect_max_attempts: int = 2,
340+
reconnect_initial_backoff: float = 0.5,
341+
reconnect_max_backoff: float = 50.0):
342+
super().__init__("sse",
343+
tool_call_timeout=tool_call_timeout,
344+
reconnect_enabled=reconnect_enabled,
345+
reconnect_max_attempts=reconnect_max_attempts,
346+
reconnect_initial_backoff=reconnect_initial_backoff,
347+
reconnect_max_backoff=reconnect_max_backoff)
255348
self._url = url
256349

257350
@property
@@ -286,8 +379,21 @@ class MCPStdioClient(MCPBaseClient):
286379
env (dict[str, str] | None): Environment variables to set for the process
287380
"""
288381

289-
def __init__(self, command: str, args: list[str] | None = None, env: dict[str, str] | None = None):
290-
super().__init__("stdio")
382+
def __init__(self,
383+
command: str,
384+
args: list[str] | None = None,
385+
env: dict[str, str] | None = None,
386+
tool_call_timeout: timedelta = timedelta(seconds=5),
387+
reconnect_enabled: bool = True,
388+
reconnect_max_attempts: int = 2,
389+
reconnect_initial_backoff: float = 0.5,
390+
reconnect_max_backoff: float = 50.0):
391+
super().__init__("stdio",
392+
tool_call_timeout=tool_call_timeout,
393+
reconnect_enabled=reconnect_enabled,
394+
reconnect_max_attempts=reconnect_max_attempts,
395+
reconnect_initial_backoff=reconnect_initial_backoff,
396+
reconnect_max_backoff=reconnect_max_backoff)
291397
self._command = command
292398
self._args = args
293399
self._env = env
@@ -331,8 +437,21 @@ class MCPStreamableHTTPClient(MCPBaseClient):
331437
auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
332438
"""
333439

334-
def __init__(self, url: str, auth_provider: AuthProviderBase | None = None):
335-
super().__init__("streamable-http", auth_provider=auth_provider)
440+
def __init__(self,
441+
url: str,
442+
auth_provider: AuthProviderBase | None = None,
443+
tool_call_timeout: timedelta = timedelta(seconds=5),
444+
reconnect_enabled: bool = True,
445+
reconnect_max_attempts: int = 2,
446+
reconnect_initial_backoff: float = 0.5,
447+
reconnect_max_backoff: float = 50.0):
448+
super().__init__("streamable-http",
449+
auth_provider=auth_provider,
450+
tool_call_timeout=tool_call_timeout,
451+
reconnect_enabled=reconnect_enabled,
452+
reconnect_max_attempts=reconnect_max_attempts,
453+
reconnect_initial_backoff=reconnect_initial_backoff,
454+
reconnect_max_backoff=reconnect_max_backoff)
336455
self._url = url
337456

338457
@property
@@ -371,15 +490,20 @@ class MCPToolClient:
371490

372491
def __init__(self,
373492
session: ClientSession,
493+
parent_client: "MCPBaseClient",
374494
tool_name: str,
375495
tool_description: str | None,
376496
tool_input_schema: dict | None = None,
377-
parent_client: "MCPBaseClient | None" = None):
497+
tool_call_timeout: timedelta = timedelta(seconds=5)):
378498
self._session = session
379499
self._tool_name = tool_name
380500
self._tool_description = tool_description
381501
self._input_schema = (model_from_mcp_schema(self._tool_name, tool_input_schema) if tool_input_schema else None)
382502
self._parent_client = parent_client
503+
self._tool_call_timeout = tool_call_timeout
504+
505+
if self._parent_client is None:
506+
raise RuntimeError("MCPToolClient initialized without a parent client.")
383507

384508
@property
385509
def name(self):
@@ -415,22 +539,25 @@ async def acall(self, tool_args: dict) -> str:
415539
Args:
416540
tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
417541
"""
418-
if self._session is None:
419-
raise RuntimeError("No session available for tool call")
420542
logger.info("Calling tool %s with arguments %s", self._tool_name, tool_args)
421-
result = await self._session.call_tool(self._tool_name, tool_args)
422-
423-
output = []
424-
425-
for res in result.content:
426-
if isinstance(res, TextContent):
427-
output.append(res.text)
428-
else:
429-
# Log non-text content for now
430-
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
431-
result_str = "\n".join(output)
432-
433-
if result.isError:
434-
raise RuntimeError(result_str)
543+
try:
544+
result = await self._parent_client.call_tool(self._tool_name, tool_args)
545+
546+
output = []
547+
for res in result.content:
548+
if isinstance(res, TextContent):
549+
output.append(res.text)
550+
else:
551+
# Log non-text content for now
552+
logger.warning("Got not-text output from %s of type %s", self.name, type(res))
553+
result_str = "\n".join(output)
554+
555+
if result.isError:
556+
mcp_error: MCPError = convert_to_mcp_error(RuntimeError(result_str), self._parent_client.server_name)
557+
raise mcp_error
558+
559+
except MCPError as e:
560+
format_mcp_error(e, include_traceback=False)
561+
result_str = "MCPToolClient tool call failed: %s" % e.original_exception
435562

436563
return result_str

0 commit comments

Comments
 (0)