|  | 
|  | 1 | +""" | 
|  | 2 | +SessionGroup concurrently manages multiple MCP session connections. | 
|  | 3 | +
 | 
|  | 4 | +Tools, resources, and prompts are aggregated across servers. Servers may | 
|  | 5 | +be connected to or disconnected from at any point after initialization. | 
|  | 6 | +
 | 
|  | 7 | +This abstractions can handle naming collisions using a custom user-provided | 
|  | 8 | +hook. | 
|  | 9 | +""" | 
|  | 10 | + | 
|  | 11 | +import contextlib | 
|  | 12 | +import logging | 
|  | 13 | +from collections.abc import Callable | 
|  | 14 | +from datetime import timedelta | 
|  | 15 | +from types import TracebackType | 
|  | 16 | +from typing import Any, TypeAlias | 
|  | 17 | + | 
|  | 18 | +import anyio | 
|  | 19 | +from pydantic import BaseModel | 
|  | 20 | +from typing_extensions import Self | 
|  | 21 | + | 
|  | 22 | +import mcp | 
|  | 23 | +from mcp import types | 
|  | 24 | +from mcp.client.sse import sse_client | 
|  | 25 | +from mcp.client.stdio import StdioServerParameters | 
|  | 26 | +from mcp.client.streamable_http import streamablehttp_client | 
|  | 27 | +from mcp.shared.exceptions import McpError | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +class SseServerParameters(BaseModel): | 
|  | 31 | +    """Parameters for intializing a sse_client.""" | 
|  | 32 | + | 
|  | 33 | +    # The endpoint URL. | 
|  | 34 | +    url: str | 
|  | 35 | + | 
|  | 36 | +    # Optional headers to include in requests. | 
|  | 37 | +    headers: dict[str, Any] | None = None | 
|  | 38 | + | 
|  | 39 | +    # HTTP timeout for regular operations. | 
|  | 40 | +    timeout: float = 5 | 
|  | 41 | + | 
|  | 42 | +    # Timeout for SSE read operations. | 
|  | 43 | +    sse_read_timeout: float = 60 * 5 | 
|  | 44 | + | 
|  | 45 | + | 
|  | 46 | +class StreamableHttpParameters(BaseModel): | 
|  | 47 | +    """Parameters for intializing a streamablehttp_client.""" | 
|  | 48 | + | 
|  | 49 | +    # The endpoint URL. | 
|  | 50 | +    url: str | 
|  | 51 | + | 
|  | 52 | +    # Optional headers to include in requests. | 
|  | 53 | +    headers: dict[str, Any] | None = None | 
|  | 54 | + | 
|  | 55 | +    # HTTP timeout for regular operations. | 
|  | 56 | +    timeout: timedelta = timedelta(seconds=30) | 
|  | 57 | + | 
|  | 58 | +    # Timeout for SSE read operations. | 
|  | 59 | +    sse_read_timeout: timedelta = timedelta(seconds=60 * 5) | 
|  | 60 | + | 
|  | 61 | +    # Close the client session when the transport closes. | 
|  | 62 | +    terminate_on_close: bool = True | 
|  | 63 | + | 
|  | 64 | + | 
|  | 65 | +ServerParameters: TypeAlias = ( | 
|  | 66 | +    StdioServerParameters | SseServerParameters | StreamableHttpParameters | 
|  | 67 | +) | 
|  | 68 | + | 
|  | 69 | + | 
|  | 70 | +class ClientSessionGroup: | 
|  | 71 | +    """Client for managing connections to multiple MCP servers. | 
|  | 72 | +
 | 
|  | 73 | +    This class is responsible for encapsulating management of server connections. | 
|  | 74 | +    It aggregates tools, resources, and prompts from all connected servers. | 
|  | 75 | +
 | 
|  | 76 | +    For auxiliary handlers, such as resource subscription, this is delegated to | 
|  | 77 | +    the client and can be accessed via the session. | 
|  | 78 | +
 | 
|  | 79 | +    Example Usage: | 
|  | 80 | +        name_fn = lambda name, server_info: f"{(server_info.name)}-{name}" | 
|  | 81 | +        async with ClientSessionGroup(component_name_hook=name_fn) as group: | 
|  | 82 | +            for server_params in server_params: | 
|  | 83 | +                group.connect_to_server(server_param) | 
|  | 84 | +            ... | 
|  | 85 | +
 | 
|  | 86 | +    """ | 
|  | 87 | + | 
|  | 88 | +    class _ComponentNames(BaseModel): | 
|  | 89 | +        """Used for reverse index to find components.""" | 
|  | 90 | + | 
|  | 91 | +        prompts: set[str] = set() | 
|  | 92 | +        resources: set[str] = set() | 
|  | 93 | +        tools: set[str] = set() | 
|  | 94 | + | 
|  | 95 | +    # Standard MCP components. | 
|  | 96 | +    _prompts: dict[str, types.Prompt] | 
|  | 97 | +    _resources: dict[str, types.Resource] | 
|  | 98 | +    _tools: dict[str, types.Tool] | 
|  | 99 | + | 
|  | 100 | +    # Client-server connection management. | 
|  | 101 | +    _sessions: dict[mcp.ClientSession, _ComponentNames] | 
|  | 102 | +    _tool_to_session: dict[str, mcp.ClientSession] | 
|  | 103 | +    _exit_stack: contextlib.AsyncExitStack | 
|  | 104 | +    _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] | 
|  | 105 | + | 
|  | 106 | +    # Optional fn consuming (component_name, serverInfo) for custom names. | 
|  | 107 | +    # This is provide a means to mitigate naming conflicts across servers. | 
|  | 108 | +    # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" | 
|  | 109 | +    _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] | 
|  | 110 | +    _component_name_hook: _ComponentNameHook | None | 
|  | 111 | + | 
|  | 112 | +    def __init__( | 
|  | 113 | +        self, | 
|  | 114 | +        exit_stack: contextlib.AsyncExitStack | None = None, | 
|  | 115 | +        component_name_hook: _ComponentNameHook | None = None, | 
|  | 116 | +    ) -> None: | 
|  | 117 | +        """Initializes the MCP client.""" | 
|  | 118 | + | 
|  | 119 | +        self._tools = {} | 
|  | 120 | +        self._resources = {} | 
|  | 121 | +        self._prompts = {} | 
|  | 122 | + | 
|  | 123 | +        self._sessions = {} | 
|  | 124 | +        self._tool_to_session = {} | 
|  | 125 | +        if exit_stack is None: | 
|  | 126 | +            self._exit_stack = contextlib.AsyncExitStack() | 
|  | 127 | +            self._owns_exit_stack = True | 
|  | 128 | +        else: | 
|  | 129 | +            self._exit_stack = exit_stack | 
|  | 130 | +            self._owns_exit_stack = False | 
|  | 131 | +        self._session_exit_stacks = {} | 
|  | 132 | +        self._component_name_hook = component_name_hook | 
|  | 133 | + | 
|  | 134 | +    async def __aenter__(self) -> Self: | 
|  | 135 | +        # Enter the exit stack only if we created it ourselves | 
|  | 136 | +        if self._owns_exit_stack: | 
|  | 137 | +            await self._exit_stack.__aenter__() | 
|  | 138 | +        return self | 
|  | 139 | + | 
|  | 140 | +    async def __aexit__( | 
|  | 141 | +        self, | 
|  | 142 | +        _exc_type: type[BaseException] | None, | 
|  | 143 | +        _exc_val: BaseException | None, | 
|  | 144 | +        _exc_tb: TracebackType | None, | 
|  | 145 | +    ) -> bool | None: | 
|  | 146 | +        """Closes session exit stacks and main exit stack upon completion.""" | 
|  | 147 | + | 
|  | 148 | +        # Concurrently close session stacks. | 
|  | 149 | +        async with anyio.create_task_group() as tg: | 
|  | 150 | +            for exit_stack in self._session_exit_stacks.values(): | 
|  | 151 | +                tg.start_soon(exit_stack.aclose) | 
|  | 152 | + | 
|  | 153 | +        # Only close the main exit stack if we created it | 
|  | 154 | +        if self._owns_exit_stack: | 
|  | 155 | +            await self._exit_stack.aclose() | 
|  | 156 | + | 
|  | 157 | +    @property | 
|  | 158 | +    def sessions(self) -> list[mcp.ClientSession]: | 
|  | 159 | +        """Returns the list of sessions being managed.""" | 
|  | 160 | +        return list(self._sessions.keys()) | 
|  | 161 | + | 
|  | 162 | +    @property | 
|  | 163 | +    def prompts(self) -> dict[str, types.Prompt]: | 
|  | 164 | +        """Returns the prompts as a dictionary of names to prompts.""" | 
|  | 165 | +        return self._prompts | 
|  | 166 | + | 
|  | 167 | +    @property | 
|  | 168 | +    def resources(self) -> dict[str, types.Resource]: | 
|  | 169 | +        """Returns the resources as a dictionary of names to resources.""" | 
|  | 170 | +        return self._resources | 
|  | 171 | + | 
|  | 172 | +    @property | 
|  | 173 | +    def tools(self) -> dict[str, types.Tool]: | 
|  | 174 | +        """Returns the tools as a dictionary of names to tools.""" | 
|  | 175 | +        return self._tools | 
|  | 176 | + | 
|  | 177 | +    async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: | 
|  | 178 | +        """Executes a tool given its name and arguments.""" | 
|  | 179 | +        session = self._tool_to_session[name] | 
|  | 180 | +        session_tool_name = self.tools[name].name | 
|  | 181 | +        return await session.call_tool(session_tool_name, args) | 
|  | 182 | + | 
|  | 183 | +    async def disconnect_from_server(self, session: mcp.ClientSession) -> None: | 
|  | 184 | +        """Disconnects from a single MCP server.""" | 
|  | 185 | + | 
|  | 186 | +        session_known_for_components = session in self._sessions | 
|  | 187 | +        session_known_for_stack = session in self._session_exit_stacks | 
|  | 188 | + | 
|  | 189 | +        if not session_known_for_components and not session_known_for_stack: | 
|  | 190 | +            raise McpError( | 
|  | 191 | +                types.ErrorData( | 
|  | 192 | +                    code=types.INVALID_PARAMS, | 
|  | 193 | +                    message="Provided session is not managed or already disconnected.", | 
|  | 194 | +                ) | 
|  | 195 | +            ) | 
|  | 196 | + | 
|  | 197 | +        if session_known_for_components: | 
|  | 198 | +            component_names = self._sessions.pop(session)  # Pop from _sessions tracking | 
|  | 199 | + | 
|  | 200 | +            # Remove prompts associated with the session. | 
|  | 201 | +            for name in component_names.prompts: | 
|  | 202 | +                if name in self._prompts: | 
|  | 203 | +                    del self._prompts[name] | 
|  | 204 | +            # Remove resources associated with the session. | 
|  | 205 | +            for name in component_names.resources: | 
|  | 206 | +                if name in self._resources: | 
|  | 207 | +                    del self._resources[name] | 
|  | 208 | +            # Remove tools associated with the session. | 
|  | 209 | +            for name in component_names.tools: | 
|  | 210 | +                if name in self._tools: | 
|  | 211 | +                    del self._tools[name] | 
|  | 212 | +                if name in self._tool_to_session: | 
|  | 213 | +                    del self._tool_to_session[name] | 
|  | 214 | + | 
|  | 215 | +        # Clean up the session's resources via its dedicated exit stack | 
|  | 216 | +        if session_known_for_stack: | 
|  | 217 | +            session_stack_to_close = self._session_exit_stacks.pop(session) | 
|  | 218 | +            await session_stack_to_close.aclose() | 
|  | 219 | + | 
|  | 220 | +    async def connect_with_session( | 
|  | 221 | +        self, server_info: types.Implementation, session: mcp.ClientSession | 
|  | 222 | +    ) -> mcp.ClientSession: | 
|  | 223 | +        """Connects to a single MCP server.""" | 
|  | 224 | +        await self._aggregate_components(server_info, session) | 
|  | 225 | +        return session | 
|  | 226 | + | 
|  | 227 | +    async def connect_to_server( | 
|  | 228 | +        self, | 
|  | 229 | +        server_params: ServerParameters, | 
|  | 230 | +    ) -> mcp.ClientSession: | 
|  | 231 | +        """Connects to a single MCP server.""" | 
|  | 232 | +        server_info, session = await self._establish_session(server_params) | 
|  | 233 | +        return await self.connect_with_session(server_info, session) | 
|  | 234 | + | 
|  | 235 | +    async def _establish_session( | 
|  | 236 | +        self, server_params: ServerParameters | 
|  | 237 | +    ) -> tuple[types.Implementation, mcp.ClientSession]: | 
|  | 238 | +        """Establish a client session to an MCP server.""" | 
|  | 239 | + | 
|  | 240 | +        session_stack = contextlib.AsyncExitStack() | 
|  | 241 | +        try: | 
|  | 242 | +            # Create read and write streams that facilitate io with the server. | 
|  | 243 | +            if isinstance(server_params, StdioServerParameters): | 
|  | 244 | +                client = mcp.stdio_client(server_params) | 
|  | 245 | +                read, write = await session_stack.enter_async_context(client) | 
|  | 246 | +            elif isinstance(server_params, SseServerParameters): | 
|  | 247 | +                client = sse_client( | 
|  | 248 | +                    url=server_params.url, | 
|  | 249 | +                    headers=server_params.headers, | 
|  | 250 | +                    timeout=server_params.timeout, | 
|  | 251 | +                    sse_read_timeout=server_params.sse_read_timeout, | 
|  | 252 | +                ) | 
|  | 253 | +                read, write = await session_stack.enter_async_context(client) | 
|  | 254 | +            else: | 
|  | 255 | +                client = streamablehttp_client( | 
|  | 256 | +                    url=server_params.url, | 
|  | 257 | +                    headers=server_params.headers, | 
|  | 258 | +                    timeout=server_params.timeout, | 
|  | 259 | +                    sse_read_timeout=server_params.sse_read_timeout, | 
|  | 260 | +                    terminate_on_close=server_params.terminate_on_close, | 
|  | 261 | +                ) | 
|  | 262 | +                read, write, _ = await session_stack.enter_async_context(client) | 
|  | 263 | + | 
|  | 264 | +            session = await session_stack.enter_async_context( | 
|  | 265 | +                mcp.ClientSession(read, write) | 
|  | 266 | +            ) | 
|  | 267 | +            result = await session.initialize() | 
|  | 268 | + | 
|  | 269 | +            # Session successfully initialized. | 
|  | 270 | +            # Store its stack and register the stack with the main group stack. | 
|  | 271 | +            self._session_exit_stacks[session] = session_stack | 
|  | 272 | +            # session_stack itself becomes a resource managed by the | 
|  | 273 | +            # main _exit_stack. | 
|  | 274 | +            await self._exit_stack.enter_async_context(session_stack) | 
|  | 275 | + | 
|  | 276 | +            return result.serverInfo, session | 
|  | 277 | +        except Exception: | 
|  | 278 | +            # If anything during this setup fails, ensure the session-specific | 
|  | 279 | +            # stack is closed. | 
|  | 280 | +            await session_stack.aclose() | 
|  | 281 | +            raise | 
|  | 282 | + | 
|  | 283 | +    async def _aggregate_components( | 
|  | 284 | +        self, server_info: types.Implementation, session: mcp.ClientSession | 
|  | 285 | +    ) -> None: | 
|  | 286 | +        """Aggregates prompts, resources, and tools from a given session.""" | 
|  | 287 | + | 
|  | 288 | +        # Create a reverse index so we can find all prompts, resources, and | 
|  | 289 | +        # tools belonging to this session. Used for removing components from | 
|  | 290 | +        # the session group via self.disconnect_from_server. | 
|  | 291 | +        component_names = self._ComponentNames() | 
|  | 292 | + | 
|  | 293 | +        # Temporary components dicts. We do not want to modify the aggregate | 
|  | 294 | +        # lists in case of an intermediate failure. | 
|  | 295 | +        prompts_temp: dict[str, types.Prompt] = {} | 
|  | 296 | +        resources_temp: dict[str, types.Resource] = {} | 
|  | 297 | +        tools_temp: dict[str, types.Tool] = {} | 
|  | 298 | +        tool_to_session_temp: dict[str, mcp.ClientSession] = {} | 
|  | 299 | + | 
|  | 300 | +        # Query the server for its prompts and aggregate to list. | 
|  | 301 | +        try: | 
|  | 302 | +            prompts = (await session.list_prompts()).prompts | 
|  | 303 | +            for prompt in prompts: | 
|  | 304 | +                name = self._component_name(prompt.name, server_info) | 
|  | 305 | +                prompts_temp[name] = prompt | 
|  | 306 | +                component_names.prompts.add(name) | 
|  | 307 | +        except McpError as err: | 
|  | 308 | +            logging.warning(f"Could not fetch prompts: {err}") | 
|  | 309 | + | 
|  | 310 | +        # Query the server for its resources and aggregate to list. | 
|  | 311 | +        try: | 
|  | 312 | +            resources = (await session.list_resources()).resources | 
|  | 313 | +            for resource in resources: | 
|  | 314 | +                name = self._component_name(resource.name, server_info) | 
|  | 315 | +                resources_temp[name] = resource | 
|  | 316 | +                component_names.resources.add(name) | 
|  | 317 | +        except McpError as err: | 
|  | 318 | +            logging.warning(f"Could not fetch resources: {err}") | 
|  | 319 | + | 
|  | 320 | +        # Query the server for its tools and aggregate to list. | 
|  | 321 | +        try: | 
|  | 322 | +            tools = (await session.list_tools()).tools | 
|  | 323 | +            for tool in tools: | 
|  | 324 | +                name = self._component_name(tool.name, server_info) | 
|  | 325 | +                tools_temp[name] = tool | 
|  | 326 | +                tool_to_session_temp[name] = session | 
|  | 327 | +                component_names.tools.add(name) | 
|  | 328 | +        except McpError as err: | 
|  | 329 | +            logging.warning(f"Could not fetch tools: {err}") | 
|  | 330 | + | 
|  | 331 | +        # Clean up exit stack for session if we couldn't retrieve anything | 
|  | 332 | +        # from the server. | 
|  | 333 | +        if not any((prompts_temp, resources_temp, tools_temp)): | 
|  | 334 | +            del self._session_exit_stacks[session] | 
|  | 335 | + | 
|  | 336 | +        # Check for duplicates. | 
|  | 337 | +        matching_prompts = prompts_temp.keys() & self._prompts.keys() | 
|  | 338 | +        if matching_prompts: | 
|  | 339 | +            raise McpError( | 
|  | 340 | +                types.ErrorData( | 
|  | 341 | +                    code=types.INVALID_PARAMS, | 
|  | 342 | +                    message=f"{matching_prompts} already exist in group prompts.", | 
|  | 343 | +                ) | 
|  | 344 | +            ) | 
|  | 345 | +        matching_resources = resources_temp.keys() & self._resources.keys() | 
|  | 346 | +        if matching_resources: | 
|  | 347 | +            raise McpError( | 
|  | 348 | +                types.ErrorData( | 
|  | 349 | +                    code=types.INVALID_PARAMS, | 
|  | 350 | +                    message=f"{matching_resources} already exist in group resources.", | 
|  | 351 | +                ) | 
|  | 352 | +            ) | 
|  | 353 | +        matching_tools = tools_temp.keys() & self._tools.keys() | 
|  | 354 | +        if matching_tools: | 
|  | 355 | +            raise McpError( | 
|  | 356 | +                types.ErrorData( | 
|  | 357 | +                    code=types.INVALID_PARAMS, | 
|  | 358 | +                    message=f"{matching_tools} already exist in group tools.", | 
|  | 359 | +                ) | 
|  | 360 | +            ) | 
|  | 361 | + | 
|  | 362 | +        # Aggregate components. | 
|  | 363 | +        self._sessions[session] = component_names | 
|  | 364 | +        self._prompts.update(prompts_temp) | 
|  | 365 | +        self._resources.update(resources_temp) | 
|  | 366 | +        self._tools.update(tools_temp) | 
|  | 367 | +        self._tool_to_session.update(tool_to_session_temp) | 
|  | 368 | + | 
|  | 369 | +    def _component_name(self, name: str, server_info: types.Implementation) -> str: | 
|  | 370 | +        if self._component_name_hook: | 
|  | 371 | +            return self._component_name_hook(name, server_info) | 
|  | 372 | +        return name | 
0 commit comments