1515
1616from __future__ import annotations
1717
18+ import asyncio
19+ import json
1820import logging
1921from abc import ABC
2022from abc import abstractmethod
23+ from collections .abc import AsyncGenerator
2124from contextlib import AsyncExitStack
2225from contextlib import asynccontextmanager
23- from typing import AsyncGenerator
26+ from datetime import timedelta
2427
2528import httpx
2629
3336from nat .authentication .interfaces import AuthProviderBase
3437from nat .data_models .authentication import AuthReason
3538from nat .data_models .authentication import AuthRequest
39+ from nat .plugins .mcp .exception_handler import convert_to_mcp_error
40+ from nat .plugins .mcp .exception_handler import format_mcp_error
3641from nat .plugins .mcp .exception_handler import mcp_exception_handler
42+ from nat .plugins .mcp .exceptions import MCPError
3743from nat .plugins .mcp .exceptions import MCPToolNotFoundError
3844from nat .plugins .mcp .utils import model_from_mcp_schema
3945from nat .utils .type_utils import override
@@ -85,7 +91,6 @@ def _is_tool_call_request(self, request: httpx.Request) -> bool:
8591 try :
8692 # Check if the request body contains a tool call
8793 if request .content :
88- import json
8994 body = json .loads (request .content .decode ('utf-8' ))
9095 # Check if it's a JSON-RPC request with method "tools/call"
9196 if (isinstance (body , dict ) and body .get ("method" ) == "tools/call" ):
@@ -131,7 +136,14 @@ class MCPBaseClient(ABC):
131136 auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
132137 """
133138
134- def __init__ (self , transport : str = 'streamable-http' , auth_provider : AuthProviderBase | None = None ):
139+ def __init__ (self ,
140+ transport : str = 'streamable-http' ,
141+ auth_provider : AuthProviderBase | None = None ,
142+ tool_call_timeout : timedelta = timedelta (seconds = 5 ),
143+ reconnect_enabled : bool = True ,
144+ reconnect_max_attempts : int = 2 ,
145+ reconnect_initial_backoff : float = 0.5 ,
146+ reconnect_max_backoff : float = 50.0 ):
135147 self ._tools = None
136148 self ._transport = transport .lower ()
137149 if self ._transport not in ['sse' , 'stdio' , 'streamable-http' ]:
@@ -145,6 +157,15 @@ def __init__(self, transport: str = 'streamable-http', auth_provider: AuthProvid
145157 # Convert auth provider to AuthAdapter
146158 self ._httpx_auth = AuthAdapter (auth_provider ) if auth_provider else None
147159
160+ self ._tool_call_timeout = tool_call_timeout
161+
162+ # Reconnect configuration
163+ self ._reconnect_enabled = reconnect_enabled
164+ self ._reconnect_max_attempts = reconnect_max_attempts
165+ self ._reconnect_initial_backoff = reconnect_initial_backoff
166+ self ._reconnect_max_backoff = reconnect_max_backoff
167+ self ._reconnect_lock : asyncio .Lock = asyncio .Lock ()
168+
148169 @property
149170 def transport (self ) -> str :
150171 return self ._transport
@@ -164,13 +185,14 @@ async def __aenter__(self):
164185 return self
165186
166187 async def __aexit__ (self , exc_type , exc_value , traceback ):
167- if not self ._exit_stack :
168- raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
188+ if self ._exit_stack :
189+ # Close session
190+ await self ._exit_stack .aclose ()
191+ self ._session = None
192+ self ._exit_stack = None
169193
170- # Close session
171- await self ._exit_stack .aclose ()
172- self ._session = None
173- self ._exit_stack = None
194+ self ._connection_established = False
195+ self ._tools = None
174196
175197 @property
176198 def server_name (self ):
@@ -181,30 +203,89 @@ def server_name(self):
181203
182204 @abstractmethod
183205 @asynccontextmanager
184- async def connect_to_server (self ):
206+ async def connect_to_server (self ) -> AsyncGenerator [ ClientSession , None ] :
185207 """
186208 Establish a session with an MCP server within an async context
187209 """
188210 yield
189211
212+ async def _reconnect (self ):
213+ """
214+ Attempt to reconnect by tearing down and re-establishing the session.
215+ """
216+ async with self ._reconnect_lock :
217+ backoff = self ._reconnect_initial_backoff
218+ attempt = 0
219+ last_error : Exception | None = None
220+
221+ while attempt in range (0 , self ._reconnect_max_attempts ):
222+ attempt += 1
223+ try :
224+ # Close the existing stack and ClientSession
225+ if self ._exit_stack :
226+ await self ._exit_stack .aclose ()
227+ # Create a fresh stack and session
228+ self ._exit_stack = AsyncExitStack ()
229+ self ._session = await self ._exit_stack .enter_async_context (self .connect_to_server ())
230+
231+ self ._connection_established = True
232+ self ._tools = None
233+
234+ logger .info ("Reconnected to MCP server (%s) on attempt %d" , self .server_name , attempt )
235+ return
236+
237+ except Exception as e :
238+ last_error = e
239+ logger .warning ("Reconnect attempt %d failed for %s: %s" , attempt , self .server_name , e )
240+ await asyncio .sleep (min (backoff , self ._reconnect_max_backoff ))
241+ backoff = min (backoff * 2 , self ._reconnect_max_backoff )
242+
243+ # All attempts failed
244+ self ._connection_established = False
245+ if last_error :
246+ raise last_error
247+
248+ async def _with_reconnect (self , coro ):
249+ """
250+ Execute an awaited operation, reconnecting once on errors.
251+ """
252+ try :
253+ return await coro ()
254+ except Exception as e :
255+ if self ._reconnect_enabled :
256+ logger .warning ("MCP Client operation failed. Attempting reconnect: %s" , e )
257+ try :
258+ await self ._reconnect ()
259+ except Exception as reconnect_err :
260+ logger .error ("MCP Client reconnect attempt failed: %s" , reconnect_err )
261+ raise
262+ return await coro ()
263+ raise
264+
190265 async def get_tools (self ):
191266 """
192267 Retrieve a dictionary of all tools served by the MCP server.
193268 Uses unauthenticated session for discovery.
194269 """
195270
196- if not self ._session :
197- raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
271+ async def _get_tools ():
272+ session = self ._session
273+ return await session .list_tools ()
198274
199- response = await self ._session .list_tools ()
275+ try :
276+ response = await self ._with_reconnect (_get_tools )
277+ except Exception as e :
278+ logger .warning ("Failed to get tools: %s" , e )
279+ raise
200280
201281 return {
202282 tool .name :
203283 MCPToolClient (session = self ._session ,
204284 tool_name = tool .name ,
205285 tool_description = tool .description ,
206286 tool_input_schema = tool .inputSchema ,
207- parent_client = self )
287+ parent_client = self ,
288+ tool_call_timeout = self ._tool_call_timeout )
208289 for tool in response .tools
209290 }
210291
@@ -235,11 +316,12 @@ async def get_tool(self, tool_name: str) -> MCPToolClient:
235316
236317 @mcp_exception_handler
237318 async def call_tool (self , tool_name : str , tool_args : dict | None ):
238- if not self ._session :
239- raise RuntimeError ("MCPBaseClient not initialized. Use async with to initialize." )
240319
241- result = await self ._session .call_tool (tool_name , tool_args )
242- return result
320+ async def _call_tool ():
321+ session = self ._session
322+ return await session .call_tool (tool_name , tool_args , read_timeout_seconds = self ._tool_call_timeout )
323+
324+ return await self ._with_reconnect (_call_tool )
243325
244326
245327class MCPSSEClient (MCPBaseClient ):
@@ -250,8 +332,19 @@ class MCPSSEClient(MCPBaseClient):
250332 url (str): The url of the MCP server
251333 """
252334
253- def __init__ (self , url : str ):
254- super ().__init__ ("sse" )
335+ def __init__ (self ,
336+ url : str ,
337+ tool_call_timeout : timedelta = timedelta (seconds = 5 ),
338+ reconnect_enabled : bool = True ,
339+ reconnect_max_attempts : int = 2 ,
340+ reconnect_initial_backoff : float = 0.5 ,
341+ reconnect_max_backoff : float = 50.0 ):
342+ super ().__init__ ("sse" ,
343+ tool_call_timeout = tool_call_timeout ,
344+ reconnect_enabled = reconnect_enabled ,
345+ reconnect_max_attempts = reconnect_max_attempts ,
346+ reconnect_initial_backoff = reconnect_initial_backoff ,
347+ reconnect_max_backoff = reconnect_max_backoff )
255348 self ._url = url
256349
257350 @property
@@ -286,8 +379,21 @@ class MCPStdioClient(MCPBaseClient):
286379 env (dict[str, str] | None): Environment variables to set for the process
287380 """
288381
289- def __init__ (self , command : str , args : list [str ] | None = None , env : dict [str , str ] | None = None ):
290- super ().__init__ ("stdio" )
382+ def __init__ (self ,
383+ command : str ,
384+ args : list [str ] | None = None ,
385+ env : dict [str , str ] | None = None ,
386+ tool_call_timeout : timedelta = timedelta (seconds = 5 ),
387+ reconnect_enabled : bool = True ,
388+ reconnect_max_attempts : int = 2 ,
389+ reconnect_initial_backoff : float = 0.5 ,
390+ reconnect_max_backoff : float = 50.0 ):
391+ super ().__init__ ("stdio" ,
392+ tool_call_timeout = tool_call_timeout ,
393+ reconnect_enabled = reconnect_enabled ,
394+ reconnect_max_attempts = reconnect_max_attempts ,
395+ reconnect_initial_backoff = reconnect_initial_backoff ,
396+ reconnect_max_backoff = reconnect_max_backoff )
291397 self ._command = command
292398 self ._args = args
293399 self ._env = env
@@ -331,8 +437,21 @@ class MCPStreamableHTTPClient(MCPBaseClient):
331437 auth_provider (AuthProviderBase | None): Optional authentication provider for Bearer token injection
332438 """
333439
334- def __init__ (self , url : str , auth_provider : AuthProviderBase | None = None ):
335- super ().__init__ ("streamable-http" , auth_provider = auth_provider )
440+ def __init__ (self ,
441+ url : str ,
442+ auth_provider : AuthProviderBase | None = None ,
443+ tool_call_timeout : timedelta = timedelta (seconds = 5 ),
444+ reconnect_enabled : bool = True ,
445+ reconnect_max_attempts : int = 2 ,
446+ reconnect_initial_backoff : float = 0.5 ,
447+ reconnect_max_backoff : float = 50.0 ):
448+ super ().__init__ ("streamable-http" ,
449+ auth_provider = auth_provider ,
450+ tool_call_timeout = tool_call_timeout ,
451+ reconnect_enabled = reconnect_enabled ,
452+ reconnect_max_attempts = reconnect_max_attempts ,
453+ reconnect_initial_backoff = reconnect_initial_backoff ,
454+ reconnect_max_backoff = reconnect_max_backoff )
336455 self ._url = url
337456
338457 @property
@@ -371,15 +490,20 @@ class MCPToolClient:
371490
372491 def __init__ (self ,
373492 session : ClientSession ,
493+ parent_client : "MCPBaseClient" ,
374494 tool_name : str ,
375495 tool_description : str | None ,
376496 tool_input_schema : dict | None = None ,
377- parent_client : "MCPBaseClient | None" = None ):
497+ tool_call_timeout : timedelta = timedelta ( seconds = 5 ) ):
378498 self ._session = session
379499 self ._tool_name = tool_name
380500 self ._tool_description = tool_description
381501 self ._input_schema = (model_from_mcp_schema (self ._tool_name , tool_input_schema ) if tool_input_schema else None )
382502 self ._parent_client = parent_client
503+ self ._tool_call_timeout = tool_call_timeout
504+
505+ if self ._parent_client is None :
506+ raise RuntimeError ("MCPToolClient initialized without a parent client." )
383507
384508 @property
385509 def name (self ):
@@ -415,22 +539,25 @@ async def acall(self, tool_args: dict) -> str:
415539 Args:
416540 tool_args (dict[str, Any]): A dictionary of key value pairs to serve as inputs for the MCP tool.
417541 """
418- if self ._session is None :
419- raise RuntimeError ("No session available for tool call" )
420542 logger .info ("Calling tool %s with arguments %s" , self ._tool_name , tool_args )
421- result = await self ._session .call_tool (self ._tool_name , tool_args )
422-
423- output = []
424-
425- for res in result .content :
426- if isinstance (res , TextContent ):
427- output .append (res .text )
428- else :
429- # Log non-text content for now
430- logger .warning ("Got not-text output from %s of type %s" , self .name , type (res ))
431- result_str = "\n " .join (output )
432-
433- if result .isError :
434- raise RuntimeError (result_str )
543+ try :
544+ result = await self ._parent_client .call_tool (self ._tool_name , tool_args )
545+
546+ output = []
547+ for res in result .content :
548+ if isinstance (res , TextContent ):
549+ output .append (res .text )
550+ else :
551+ # Log non-text content for now
552+ logger .warning ("Got not-text output from %s of type %s" , self .name , type (res ))
553+ result_str = "\n " .join (output )
554+
555+ if result .isError :
556+ mcp_error : MCPError = convert_to_mcp_error (RuntimeError (result_str ), self ._parent_client .server_name )
557+ raise mcp_error
558+
559+ except MCPError as e :
560+ format_mcp_error (e , include_traceback = False )
561+ result_str = "MCPToolClient tool call failed: %s" % e .original_exception
435562
436563 return result_str
0 commit comments