diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index 5ad2dd5485..63433c5a46 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -4,7 +4,7 @@ import asyncio import base64 -from typing import Awaitable, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar, Union import httpx from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters @@ -74,52 +74,61 @@ def __init__( if auth_value: self.update_auth_value(auth_value) + def _create_transport_context( + self, + ) -> tuple[Any, Optional[httpx.AsyncClient]]: + """Create the appropriate transport context based on transport type.""" + http_client: Optional[httpx.AsyncClient] = None + + if self.transport_type == MCPTransport.stdio: + if not self.stdio_config: + raise ValueError("stdio_config is required for stdio transport") + + server_params = StdioServerParameters( + command=self.stdio_config.get("command", ""), + args=self.stdio_config.get("args", []), + env=self.stdio_config.get("env", {}), + ) + transport_ctx = stdio_client(server_params) + elif self.transport_type == MCPTransport.sse: + headers = self._get_auth_headers() + httpx_client_factory = self._create_httpx_client_factory() + transport_ctx = sse_client( + url=self.server_url, + timeout=self.timeout, + headers=headers, + httpx_client_factory=httpx_client_factory, + ) + else: + headers = self._get_auth_headers() + httpx_client_factory = self._create_httpx_client_factory() + verbose_logger.debug( + "litellm headers for streamable_http_client: %s", headers + ) + http_client = httpx_client_factory( + headers=headers, + timeout=httpx.Timeout(self.timeout), + ) + transport_ctx = streamable_http_client( + url=self.server_url, + http_client=http_client, + ) + + if transport_ctx is None: + raise RuntimeError("Failed to create transport context") + + return transport_ctx, http_client + async def run_with_session( self, operation: Callable[[ClientSession], Awaitable[TSessionResult]] ) -> TSessionResult: """Open a session, run the provided coroutine, and clean up.""" transport_ctx = None http_client: Optional[httpx.AsyncClient] = None - transport = None session_ctx = None try: - if self.transport_type == MCPTransport.stdio: - if not self.stdio_config: - raise ValueError("stdio_config is required for stdio transport") - - server_params = StdioServerParameters( - command=self.stdio_config.get("command", ""), - args=self.stdio_config.get("args", []), - env=self.stdio_config.get("env", {}), - ) - transport_ctx = stdio_client(server_params) - elif self.transport_type == MCPTransport.sse: - headers = self._get_auth_headers() - httpx_client_factory = self._create_httpx_client_factory() - transport_ctx = sse_client( - url=self.server_url, - timeout=self.timeout, - headers=headers, - httpx_client_factory=httpx_client_factory, - ) - else: - headers = self._get_auth_headers() - httpx_client_factory = self._create_httpx_client_factory() - verbose_logger.debug( - "litellm headers for streamable_http_client: %s", headers - ) - http_client = httpx_client_factory( - headers=headers, - timeout=httpx.Timeout(self.timeout), - ) - transport_ctx = streamable_http_client( - url=self.server_url, - http_client=http_client, - ) - - if transport_ctx is None: - raise RuntimeError("Failed to create transport context") + transport_ctx, http_client = self._create_transport_context() # Enter transport context transport = await transport_ctx.__aenter__()