99""" 
1010
1111import  contextlib 
12+ import  logging 
1213from  collections .abc  import  Callable 
1314from  datetime  import  timedelta 
15+ from  types  import  TracebackType 
1416from  typing  import  Any , TypeAlias 
1517
18+ import  anyio 
1619from  pydantic  import  BaseModel 
20+ from  typing_extensions  import  Self 
1721
1822import  mcp 
1923from  mcp  import  types 
@@ -67,11 +71,18 @@ class ClientSessionGroup:
6771    """Client for managing connections to multiple MCP servers. 
6872
6973    This class is responsible for encapsulating management of server connections. 
70-     It it  aggregates tools, resources, and prompts from all connected servers. 
74+     It aggregates tools, resources, and prompts from all connected servers. 
7175
7276    For auxiliary handlers, such as resource subscription, this is delegated to 
73-     the client and can be accessed via the session. For example: 
74-       mcp_session_group.get_session("server_name").subscribe_to_resource(...) 
77+     the client and can be accessed via the session. 
78+ 
79+     Example Usage: 
80+         name_fn = lambda name, server_info: f"{(server_info.name)}-{name}" 
81+         async with ClientSessionGroup(component_name_hook=name_fn) as group: 
82+             for server_params in server_params: 
83+                 group.connect_to_server(server_param) 
84+             ... 
85+ 
7586    """ 
7687
7788    class  _ComponentNames (BaseModel ):
@@ -90,6 +101,7 @@ class _ComponentNames(BaseModel):
90101    _sessions : dict [mcp .ClientSession , _ComponentNames ]
91102    _tool_to_session : dict [str , mcp .ClientSession ]
92103    _exit_stack : contextlib .AsyncExitStack 
104+     _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93105
94106    # Optional fn consuming (component_name, serverInfo) for custom names. 
95107    # This is provide a means to mitigate naming conflicts across servers. 
@@ -99,7 +111,7 @@ class _ComponentNames(BaseModel):
99111
100112    def  __init__ (
101113        self ,
102-         exit_stack : contextlib .AsyncExitStack  =   contextlib . AsyncExitStack () ,
114+         exit_stack : contextlib .AsyncExitStack  |   None   =   None ,
103115        component_name_hook : _ComponentNameHook  |  None  =  None ,
104116    ) ->  None :
105117        """Initializes the MCP client.""" 
@@ -110,9 +122,43 @@ def __init__(
110122
111123        self ._sessions  =  {}
112124        self ._tool_to_session  =  {}
113-         self ._exit_stack  =  exit_stack 
125+         if  exit_stack  is  None :
126+             self ._exit_stack  =  contextlib .AsyncExitStack ()
127+             self ._owns_exit_stack  =  True 
128+         else :
129+             self ._exit_stack  =  exit_stack 
130+             self ._owns_exit_stack  =  False 
131+         self ._session_exit_stacks  =  {}
114132        self ._component_name_hook  =  component_name_hook 
115133
134+     async  def  __aenter__ (self ) ->  Self :
135+         # Enter the exit stack only if we created it ourselves 
136+         if  self ._owns_exit_stack :
137+             await  self ._exit_stack .__aenter__ ()
138+         return  self 
139+ 
140+     async  def  __aexit__ (
141+         self ,
142+         _exc_type : type [BaseException ] |  None ,
143+         _exc_val : BaseException  |  None ,
144+         _exc_tb : TracebackType  |  None ,
145+     ) ->  bool  |  None :
146+         """Closes session exit stacks and main exit stack upon completion.""" 
147+ 
148+         # Concurrently close session stacks. 
149+         async  with  anyio .create_task_group () as  tg :
150+             for  exit_stack  in  self ._session_exit_stacks .values ():
151+                 tg .start_soon (exit_stack .aclose )
152+ 
153+         # Only close the main exit stack if we created it 
154+         if  self ._owns_exit_stack :
155+             await  self ._exit_stack .aclose ()
156+ 
157+     @property  
158+     def  sessions (self ) ->  list [mcp .ClientSession ]:
159+         """Returns the list of sessions being managed.""" 
160+         return  list (self ._sessions .keys ())
161+ 
116162    @property  
117163    def  prompts (self ) ->  dict [str , types .Prompt ]:
118164        """Returns the prompts as a dictionary of names to prompts.""" 
@@ -131,33 +177,45 @@ def tools(self) -> dict[str, types.Tool]:
131177    async  def  call_tool (self , name : str , args : dict [str , Any ]) ->  types .CallToolResult :
132178        """Executes a tool given its name and arguments.""" 
133179        session  =  self ._tool_to_session [name ]
134-         return  await  session .call_tool (name , args )
180+         session_tool_name  =  self .tools [name ].name 
181+         return  await  session .call_tool (session_tool_name , args )
135182
136-     def  disconnect_from_server (self , session : mcp .ClientSession ) ->  None :
183+     async   def  disconnect_from_server (self , session : mcp .ClientSession ) ->  None :
137184        """Disconnects from a single MCP server.""" 
138185
139-         if  session  not  in self ._sessions :
186+         session_known_for_components  =  session  in  self ._sessions 
187+         session_known_for_stack  =  session  in  self ._session_exit_stacks 
188+ 
189+         if  not  session_known_for_components  and  not  session_known_for_stack :
140190            raise  McpError (
141191                types .ErrorData (
142192                    code = types .INVALID_PARAMS ,
143-                     message = "Provided session is not being  managed." ,
193+                     message = "Provided session is not managed or already disconnected ." ,
144194                )
145195            )
146-         component_names  =  self ._sessions [session ]
147- 
148-         # Remove prompts associated with the session. 
149-         for  name  in  component_names .prompts :
150-             del  self ._prompts [name ]
151- 
152-         # Remove resources associated with the session. 
153-         for  name  in  component_names .resources :
154-             del  self ._resources [name ]
155- 
156-         # Remove tools associated with the session. 
157-         for  name  in  component_names .tools :
158-             del  self ._tools [name ]
159196
160-         del  self ._sessions [session ]
197+         if  session_known_for_components :
198+             component_names  =  self ._sessions .pop (session )  # Pop from _sessions tracking 
199+ 
200+             # Remove prompts associated with the session. 
201+             for  name  in  component_names .prompts :
202+                 if  name  in  self ._prompts :
203+                     del  self ._prompts [name ]
204+             # Remove resources associated with the session. 
205+             for  name  in  component_names .resources :
206+                 if  name  in  self ._resources :
207+                     del  self ._resources [name ]
208+             # Remove tools associated with the session. 
209+             for  name  in  component_names .tools :
210+                 if  name  in  self ._tools :
211+                     del  self ._tools [name ]
212+                 if  name  in  self ._tool_to_session :
213+                     del  self ._tool_to_session [name ]
214+ 
215+         # Clean up the session's resources via its dedicated exit stack 
216+         if  session_known_for_stack :
217+             session_stack_to_close  =  self ._session_exit_stacks .pop (session )
218+             await  session_stack_to_close .aclose ()
161219
162220    async  def  connect_to_server (
163221        self ,
@@ -181,47 +239,66 @@ async def connect_to_server(
181239        tool_to_session_temp : dict [str , mcp .ClientSession ] =  {}
182240
183241        # Query the server for its prompts and aggregate to list. 
184-         prompts  =  (await  session .list_prompts ()).prompts 
185-         for  prompt  in  prompts :
186-             name  =  self ._component_name (prompt .name , server_info )
187-             if  name  in  self ._prompts :
188-                 raise  McpError (
189-                     types .ErrorData (
190-                         code = types .INVALID_PARAMS ,
191-                         message = f"{ name }  ,
192-                     )
193-                 )
194-             prompts_temp [name ] =  prompt 
195-             component_names .prompts .add (name )
242+         try :
243+             prompts  =  (await  session .list_prompts ()).prompts 
244+             for  prompt  in  prompts :
245+                 name  =  self ._component_name (prompt .name , server_info )
246+                 prompts_temp [name ] =  prompt 
247+                 component_names .prompts .add (name )
248+         except  McpError  as  err :
249+             logging .warning (f"Could not fetch prompts: { err }  )
196250
197251        # Query the server for its resources and aggregate to list. 
198-         resources  =  (await  session .list_resources ()).resources 
199-         for  resource  in  resources :
200-             name  =  self ._component_name (resource .name , server_info )
201-             if  name  in  self ._resources :
202-                 raise  McpError (
203-                     types .ErrorData (
204-                         code = types .INVALID_PARAMS ,
205-                         message = f"{ name }  ,
206-                     )
207-                 )
208-             resources_temp [name ] =  resource 
209-             component_names .resources .add (name )
252+         try :
253+             resources  =  (await  session .list_resources ()).resources 
254+             for  resource  in  resources :
255+                 name  =  self ._component_name (resource .name , server_info )
256+                 resources_temp [name ] =  resource 
257+                 component_names .resources .add (name )
258+         except  McpError  as  err :
259+             logging .warning (f"Could not fetch resources: { err }  )
210260
211261        # Query the server for its tools and aggregate to list. 
212-         tools  =  (await  session .list_tools ()).tools 
213-         for  tool  in  tools :
214-             name  =  self ._component_name (tool .name , server_info )
215-             if  name  in  self ._tools :
216-                 raise  McpError (
217-                     types .ErrorData (
218-                         code = types .INVALID_PARAMS ,
219-                         message = f"{ name }  ,
220-                     )
262+         try :
263+             tools  =  (await  session .list_tools ()).tools 
264+             for  tool  in  tools :
265+                 name  =  self ._component_name (tool .name , server_info )
266+                 tools_temp [name ] =  tool 
267+                 tool_to_session_temp [name ] =  session 
268+                 component_names .tools .add (name )
269+         except  McpError  as  err :
270+             logging .warning (f"Could not fetch tools: { err }  )
271+ 
272+         # Clean up exit stack for session if we couldn't retrieve anything 
273+         # from the server. 
274+         if  not  any ((prompts_temp , resources_temp , tools_temp )):
275+             del  self ._session_exit_stacks [session ]
276+ 
277+         # Check for duplicates. 
278+         matching_prompts  =  prompts_temp .keys () &  self ._prompts .keys ()
279+         if  matching_prompts :
280+             raise  McpError (
281+                 types .ErrorData (
282+                     code = types .INVALID_PARAMS ,
283+                     message = f"{ matching_prompts }  ,
284+                 )
285+             )
286+         matching_resources  =  resources_temp .keys () &  self ._resources .keys ()
287+         if  matching_resources :
288+             raise  McpError (
289+                 types .ErrorData (
290+                     code = types .INVALID_PARAMS ,
291+                     message = f"{ matching_resources }  ,
292+                 )
293+             )
294+         matching_tools  =  tools_temp .keys () &  self ._tools .keys ()
295+         if  matching_tools :
296+             raise  McpError (
297+                 types .ErrorData (
298+                     code = types .INVALID_PARAMS ,
299+                     message = f"{ matching_tools }  ,
221300                )
222-             tools_temp [name ] =  tool 
223-             tool_to_session_temp [name ] =  session 
224-             component_names .tools .add (name )
301+             )
225302
226303        # Aggregate components. 
227304        self ._sessions [session ] =  component_names 
@@ -237,33 +314,48 @@ async def _establish_session(
237314    ) ->  tuple [types .Implementation , mcp .ClientSession ]:
238315        """Establish a client session to an MCP server.""" 
239316
240-         # Create read and write streams that facilitate io with the server. 
241-         if  isinstance (server_params , StdioServerParameters ):
242-             client  =  mcp .stdio_client (server_params )
243-             read , write  =  await  self ._exit_stack .enter_async_context (client )
244-         elif  isinstance (server_params , SseServerParameters ):
245-             client  =  sse_client (
246-                 url = server_params .url ,
247-                 headers = server_params .headers ,
248-                 timeout = server_params .timeout ,
249-                 sse_read_timeout = server_params .sse_read_timeout ,
250-             )
251-             read , write  =  await  self ._exit_stack .enter_async_context (client )
252-         else :
253-             client  =  streamablehttp_client (
254-                 url = server_params .url ,
255-                 headers = server_params .headers ,
256-                 timeout = server_params .timeout ,
257-                 sse_read_timeout = server_params .sse_read_timeout ,
258-                 terminate_on_close = server_params .terminate_on_close ,
259-             )
260-             read , write , _  =  await  self ._exit_stack .enter_async_context (client )
317+         session_stack  =  contextlib .AsyncExitStack ()
318+         try :
319+             # Create read and write streams that facilitate io with the server. 
320+             if  isinstance (server_params , StdioServerParameters ):
321+                 client  =  mcp .stdio_client (server_params )
322+                 read , write  =  await  session_stack .enter_async_context (client )
323+             elif  isinstance (server_params , SseServerParameters ):
324+                 client  =  sse_client (
325+                     url = server_params .url ,
326+                     headers = server_params .headers ,
327+                     timeout = server_params .timeout ,
328+                     sse_read_timeout = server_params .sse_read_timeout ,
329+                 )
330+                 read , write  =  await  session_stack .enter_async_context (client )
331+             else :
332+                 client  =  streamablehttp_client (
333+                     url = server_params .url ,
334+                     headers = server_params .headers ,
335+                     timeout = server_params .timeout ,
336+                     sse_read_timeout = server_params .sse_read_timeout ,
337+                     terminate_on_close = server_params .terminate_on_close ,
338+                 )
339+                 read , write , _  =  await  session_stack .enter_async_context (client )
261340
262-         session  =  await  self ._exit_stack .enter_async_context (
263-             mcp .ClientSession (read , write )
264-         )
265-         result  =  await  session .initialize ()
266-         return  result .serverInfo , session 
341+             session  =  await  session_stack .enter_async_context (
342+                 mcp .ClientSession (read , write )
343+             )
344+             result  =  await  session .initialize ()
345+ 
346+             # Session successfully initialized. 
347+             # Store its stack and register the stack with the main group stack. 
348+             self ._session_exit_stacks [session ] =  session_stack 
349+             # session_stack itself becomes a resource managed by the 
350+             # main _exit_stack. 
351+             await  self ._exit_stack .enter_async_context (session_stack )
352+ 
353+             return  result .serverInfo , session 
354+         except  Exception :
355+             # If anything during this setup fails, ensure the session-specific 
356+             # stack is closed. 
357+             await  session_stack .aclose ()
358+             raise 
267359
268360    def  _component_name (self , name : str , server_info : types .Implementation ) ->  str :
269361        if  self ._component_name_hook :
0 commit comments