99"""
1010
1111import contextlib
12+ import logging
1213from collections .abc import Callable
1314from datetime import timedelta
15+ from types import TracebackType
1416from typing import Any , TypeAlias
1517
18+ import anyio
1619from pydantic import BaseModel
20+ from typing_extensions import Self
1721
1822import mcp
1923from mcp import types
@@ -67,11 +71,18 @@ class ClientSessionGroup:
6771 """Client for managing connections to multiple MCP servers.
6872
6973 This class is responsible for encapsulating management of server connections.
70- It it aggregates tools, resources, and prompts from all connected servers.
74+ It aggregates tools, resources, and prompts from all connected servers.
7175
7276 For auxiliary handlers, such as resource subscription, this is delegated to
73- the client and can be accessed via the session. For example:
74- mcp_session_group.get_session("server_name").subscribe_to_resource(...)
77+ the client and can be accessed via the session.
78+
79+ Example Usage:
80+ name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
81+ async with ClientSessionGroup(component_name_hook=name_fn) as group:
82+ for server_params in server_params:
83+ group.connect_to_server(server_param)
84+ ...
85+
7586 """
7687
7788 class _ComponentNames (BaseModel ):
@@ -90,6 +101,7 @@ class _ComponentNames(BaseModel):
90101 _sessions : dict [mcp .ClientSession , _ComponentNames ]
91102 _tool_to_session : dict [str , mcp .ClientSession ]
92103 _exit_stack : contextlib .AsyncExitStack
104+ _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93105
94106 # Optional fn consuming (component_name, serverInfo) for custom names.
95107 # This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +111,7 @@ class _ComponentNames(BaseModel):
99111
100112 def __init__ (
101113 self ,
102- exit_stack : contextlib .AsyncExitStack = contextlib . AsyncExitStack () ,
114+ exit_stack : contextlib .AsyncExitStack | None = None ,
103115 component_name_hook : _ComponentNameHook | None = None ,
104116 ) -> None :
105117 """Initializes the MCP client."""
@@ -110,9 +122,43 @@ def __init__(
110122
111123 self ._sessions = {}
112124 self ._tool_to_session = {}
113- self ._exit_stack = exit_stack
125+ if exit_stack is None :
126+ self ._exit_stack = contextlib .AsyncExitStack ()
127+ self ._owns_exit_stack = True
128+ else :
129+ self ._exit_stack = exit_stack
130+ self ._owns_exit_stack = False
131+ self ._session_exit_stacks = {}
114132 self ._component_name_hook = component_name_hook
115133
134+ async def __aenter__ (self ) -> Self :
135+ # Enter the exit stack only if we created it ourselves
136+ if self ._owns_exit_stack :
137+ await self ._exit_stack .__aenter__ ()
138+ return self
139+
140+ async def __aexit__ (
141+ self ,
142+ _exc_type : type [BaseException ] | None ,
143+ _exc_val : BaseException | None ,
144+ _exc_tb : TracebackType | None ,
145+ ) -> bool | None :
146+ """Closes session exit stacks and main exit stack upon completion."""
147+
148+ # Concurrently close session stacks.
149+ async with anyio .create_task_group () as tg :
150+ for exit_stack in self ._session_exit_stacks .values ():
151+ tg .start_soon (exit_stack .aclose )
152+
153+ # Only close the main exit stack if we created it
154+ if self ._owns_exit_stack :
155+ await self ._exit_stack .aclose ()
156+
157+ @property
158+ def sessions (self ) -> list [mcp .ClientSession ]:
159+ """Returns the list of sessions being managed."""
160+ return list (self ._sessions .keys ())
161+
116162 @property
117163 def prompts (self ) -> dict [str , types .Prompt ]:
118164 """Returns the prompts as a dictionary of names to prompts."""
@@ -131,42 +177,113 @@ def tools(self) -> dict[str, types.Tool]:
131177 async def call_tool (self , name : str , args : dict [str , Any ]) -> types .CallToolResult :
132178 """Executes a tool given its name and arguments."""
133179 session = self ._tool_to_session [name ]
134- return await session .call_tool (name , args )
180+ session_tool_name = self .tools [name ].name
181+ return await session .call_tool (session_tool_name , args )
135182
136- def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
183+ async def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
137184 """Disconnects from a single MCP server."""
138185
139- if session not in self ._sessions :
186+ session_known_for_components = session in self ._sessions
187+ session_known_for_stack = session in self ._session_exit_stacks
188+
189+ if not session_known_for_components and not session_known_for_stack :
140190 raise McpError (
141191 types .ErrorData (
142192 code = types .INVALID_PARAMS ,
143- message = "Provided session is not being managed." ,
193+ message = "Provided session is not managed or already disconnected ." ,
144194 )
145195 )
146- component_names = self ._sessions [session ]
147-
148- # Remove prompts associated with the session.
149- for name in component_names .prompts :
150- del self ._prompts [name ]
151196
152- # Remove resources associated with the session.
153- for name in component_names .resources :
154- del self ._resources [name ]
155-
156- # Remove tools associated with the session.
157- for name in component_names .tools :
158- del self ._tools [name ]
159-
160- del self ._sessions [session ]
197+ if session_known_for_components :
198+ component_names = self ._sessions .pop (session ) # Pop from _sessions tracking
199+
200+ # Remove prompts associated with the session.
201+ for name in component_names .prompts :
202+ if name in self ._prompts :
203+ del self ._prompts [name ]
204+ # Remove resources associated with the session.
205+ for name in component_names .resources :
206+ if name in self ._resources :
207+ del self ._resources [name ]
208+ # Remove tools associated with the session.
209+ for name in component_names .tools :
210+ if name in self ._tools :
211+ del self ._tools [name ]
212+ if name in self ._tool_to_session :
213+ del self ._tool_to_session [name ]
214+
215+ # Clean up the session's resources via its dedicated exit stack
216+ if session_known_for_stack :
217+ session_stack_to_close = self ._session_exit_stacks .pop (session )
218+ await session_stack_to_close .aclose ()
219+
220+ async def connect_with_session (
221+ self , server_info : types .Implementation , session : mcp .ClientSession
222+ ) -> mcp .ClientSession :
223+ """Connects to a single MCP server."""
224+ await self ._aggregate_components (server_info , session )
225+ return session
161226
162227 async def connect_to_server (
163228 self ,
164229 server_params : ServerParameters ,
165230 ) -> mcp .ClientSession :
166231 """Connects to a single MCP server."""
167-
168- # Establish server connection and create session.
169232 server_info , session = await self ._establish_session (server_params )
233+ return await self .connect_with_session (server_info , session )
234+
235+ async def _establish_session (
236+ self , server_params : ServerParameters
237+ ) -> tuple [types .Implementation , mcp .ClientSession ]:
238+ """Establish a client session to an MCP server."""
239+
240+ session_stack = contextlib .AsyncExitStack ()
241+ try :
242+ # Create read and write streams that facilitate io with the server.
243+ if isinstance (server_params , StdioServerParameters ):
244+ client = mcp .stdio_client (server_params )
245+ read , write = await session_stack .enter_async_context (client )
246+ elif isinstance (server_params , SseServerParameters ):
247+ client = sse_client (
248+ url = server_params .url ,
249+ headers = server_params .headers ,
250+ timeout = server_params .timeout ,
251+ sse_read_timeout = server_params .sse_read_timeout ,
252+ )
253+ read , write = await session_stack .enter_async_context (client )
254+ else :
255+ client = streamablehttp_client (
256+ url = server_params .url ,
257+ headers = server_params .headers ,
258+ timeout = server_params .timeout ,
259+ sse_read_timeout = server_params .sse_read_timeout ,
260+ terminate_on_close = server_params .terminate_on_close ,
261+ )
262+ read , write , _ = await session_stack .enter_async_context (client )
263+
264+ session = await session_stack .enter_async_context (
265+ mcp .ClientSession (read , write )
266+ )
267+ result = await session .initialize ()
268+
269+ # Session successfully initialized.
270+ # Store its stack and register the stack with the main group stack.
271+ self ._session_exit_stacks [session ] = session_stack
272+ # session_stack itself becomes a resource managed by the
273+ # main _exit_stack.
274+ await self ._exit_stack .enter_async_context (session_stack )
275+
276+ return result .serverInfo , session
277+ except Exception :
278+ # If anything during this setup fails, ensure the session-specific
279+ # stack is closed.
280+ await session_stack .aclose ()
281+ raise
282+
283+ async def _aggregate_components (
284+ self , server_info : types .Implementation , session : mcp .ClientSession
285+ ) -> None :
286+ """Aggregates prompts, resources, and tools from a given session."""
170287
171288 # Create a reverse index so we can find all prompts, resources, and
172289 # tools belonging to this session. Used for removing components from
@@ -181,47 +298,66 @@ async def connect_to_server(
181298 tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
182299
183300 # Query the server for its prompts and aggregate to list.
184- prompts = (await session .list_prompts ()).prompts
185- for prompt in prompts :
186- name = self ._component_name (prompt .name , server_info )
187- if name in self ._prompts :
188- raise McpError (
189- types .ErrorData (
190- code = types .INVALID_PARAMS ,
191- message = f"{ name } already exists in group prompts." ,
192- )
193- )
194- prompts_temp [name ] = prompt
195- component_names .prompts .add (name )
301+ try :
302+ prompts = (await session .list_prompts ()).prompts
303+ for prompt in prompts :
304+ name = self ._component_name (prompt .name , server_info )
305+ prompts_temp [name ] = prompt
306+ component_names .prompts .add (name )
307+ except McpError as err :
308+ logging .warning (f"Could not fetch prompts: { err } " )
196309
197310 # Query the server for its resources and aggregate to list.
198- resources = (await session .list_resources ()).resources
199- for resource in resources :
200- name = self ._component_name (resource .name , server_info )
201- if name in self ._resources :
202- raise McpError (
203- types .ErrorData (
204- code = types .INVALID_PARAMS ,
205- message = f"{ name } already exists in group resources." ,
206- )
207- )
208- resources_temp [name ] = resource
209- component_names .resources .add (name )
311+ try :
312+ resources = (await session .list_resources ()).resources
313+ for resource in resources :
314+ name = self ._component_name (resource .name , server_info )
315+ resources_temp [name ] = resource
316+ component_names .resources .add (name )
317+ except McpError as err :
318+ logging .warning (f"Could not fetch resources: { err } " )
210319
211320 # Query the server for its tools and aggregate to list.
212- tools = (await session .list_tools ()).tools
213- for tool in tools :
214- name = self ._component_name (tool .name , server_info )
215- if name in self ._tools :
216- raise McpError (
217- types .ErrorData (
218- code = types .INVALID_PARAMS ,
219- message = f"{ name } already exists in group tools." ,
220- )
321+ try :
322+ tools = (await session .list_tools ()).tools
323+ for tool in tools :
324+ name = self ._component_name (tool .name , server_info )
325+ tools_temp [name ] = tool
326+ tool_to_session_temp [name ] = session
327+ component_names .tools .add (name )
328+ except McpError as err :
329+ logging .warning (f"Could not fetch tools: { err } " )
330+
331+ # Clean up exit stack for session if we couldn't retrieve anything
332+ # from the server.
333+ if not any ((prompts_temp , resources_temp , tools_temp )):
334+ del self ._session_exit_stacks [session ]
335+
336+ # Check for duplicates.
337+ matching_prompts = prompts_temp .keys () & self ._prompts .keys ()
338+ if matching_prompts :
339+ raise McpError (
340+ types .ErrorData (
341+ code = types .INVALID_PARAMS ,
342+ message = f"{ matching_prompts } already exist in group prompts." ,
343+ )
344+ )
345+ matching_resources = resources_temp .keys () & self ._resources .keys ()
346+ if matching_resources :
347+ raise McpError (
348+ types .ErrorData (
349+ code = types .INVALID_PARAMS ,
350+ message = f"{ matching_resources } already exist in group resources." ,
351+ )
352+ )
353+ matching_tools = tools_temp .keys () & self ._tools .keys ()
354+ if matching_tools :
355+ raise McpError (
356+ types .ErrorData (
357+ code = types .INVALID_PARAMS ,
358+ message = f"{ matching_tools } already exist in group tools." ,
221359 )
222- tools_temp [name ] = tool
223- tool_to_session_temp [name ] = session
224- component_names .tools .add (name )
360+ )
225361
226362 # Aggregate components.
227363 self ._sessions [session ] = component_names
@@ -230,41 +366,6 @@ async def connect_to_server(
230366 self ._tools .update (tools_temp )
231367 self ._tool_to_session .update (tool_to_session_temp )
232368
233- return session
234-
235- async def _establish_session (
236- self , server_params : ServerParameters
237- ) -> tuple [types .Implementation , mcp .ClientSession ]:
238- """Establish a client session to an MCP server."""
239-
240- # Create read and write streams that facilitate io with the server.
241- if isinstance (server_params , StdioServerParameters ):
242- client = mcp .stdio_client (server_params )
243- read , write = await self ._exit_stack .enter_async_context (client )
244- elif isinstance (server_params , SseServerParameters ):
245- client = sse_client (
246- url = server_params .url ,
247- headers = server_params .headers ,
248- timeout = server_params .timeout ,
249- sse_read_timeout = server_params .sse_read_timeout ,
250- )
251- read , write = await self ._exit_stack .enter_async_context (client )
252- else :
253- client = streamablehttp_client (
254- url = server_params .url ,
255- headers = server_params .headers ,
256- timeout = server_params .timeout ,
257- sse_read_timeout = server_params .sse_read_timeout ,
258- terminate_on_close = server_params .terminate_on_close ,
259- )
260- read , write , _ = await self ._exit_stack .enter_async_context (client )
261-
262- session = await self ._exit_stack .enter_async_context (
263- mcp .ClientSession (read , write )
264- )
265- result = await session .initialize ()
266- return result .serverInfo , session
267-
268369 def _component_name (self , name : str , server_info : types .Implementation ) -> str :
269370 if self ._component_name_hook :
270371 return self ._component_name_hook (name , server_info )
0 commit comments