Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 82 additions & 77 deletions litellm/experimental_mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import asyncio
import base64
from typing import Awaitable, Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import httpx
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
Expand Down Expand Up @@ -74,97 +74,102 @@ def __init__(
if auth_value:
self.update_auth_value(auth_value)

async def run_with_session(
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
) -> TSessionResult:
"""Open a session, run the provided coroutine, and clean up."""
transport_ctx = None
def _create_transport_context(
self,
) -> Tuple[Any, Optional[httpx.AsyncClient]]:
"""
Create the appropriate transport context based on transport type.

Returns:
Tuple of (transport_context, http_client).
http_client is only set for HTTP transport and needs cleanup.
"""
http_client: Optional[httpx.AsyncClient] = None
transport = None
session_ctx = None

try:
if self.transport_type == MCPTransport.stdio:
if not self.stdio_config:
raise ValueError("stdio_config is required for stdio transport")

server_params = StdioServerParameters(
command=self.stdio_config.get("command", ""),
args=self.stdio_config.get("args", []),
env=self.stdio_config.get("env", {}),
)
transport_ctx = stdio_client(server_params)
elif self.transport_type == MCPTransport.sse:
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
transport_ctx = sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
httpx_client_factory=httpx_client_factory,
)
else:
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
verbose_logger.debug(
"litellm headers for streamable_http_client: %s", headers
)
http_client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(self.timeout),
)
transport_ctx = streamable_http_client(
url=self.server_url,
http_client=http_client,
)
if self.transport_type == MCPTransport.stdio:
if not self.stdio_config:
raise ValueError("stdio_config is required for stdio transport")
server_params = StdioServerParameters(
command=self.stdio_config.get("command", ""),
args=self.stdio_config.get("args", []),
env=self.stdio_config.get("env", {}),
)
return stdio_client(server_params), None

if self.transport_type == MCPTransport.sse:
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
return sse_client(
url=self.server_url,
timeout=self.timeout,
headers=headers,
httpx_client_factory=httpx_client_factory,
), None

if transport_ctx is None:
raise RuntimeError("Failed to create transport context")
# HTTP transport (default)
headers = self._get_auth_headers()
httpx_client_factory = self._create_httpx_client_factory()
verbose_logger.debug(
"litellm headers for streamable_http_client: %s", headers
)
http_client = httpx_client_factory(
headers=headers,
timeout=httpx.Timeout(self.timeout),
)
transport_ctx = streamable_http_client(
url=self.server_url,
http_client=http_client,
)
return transport_ctx, http_client

async def _execute_session_operation(
self,
transport_ctx: Any,
operation: Callable[[ClientSession], Awaitable[TSessionResult]],
) -> TSessionResult:
"""
Execute an operation within a transport and session context.

# Enter transport context
transport = await transport_ctx.__aenter__()
Handles entering/exiting contexts and running the operation.
"""
transport = await transport_ctx.__aenter__()
try:
read_stream, write_stream = transport[0], transport[1]
session_ctx = ClientSession(read_stream, write_stream)
session = await session_ctx.__aenter__()
try:
read_stream, write_stream = transport[0], transport[1]
session_ctx = ClientSession(read_stream, write_stream)

# Enter session context
session = await session_ctx.__aenter__()
try:
await session.initialize()
result = await operation(session)
return result
finally:
# Ensure session context is properly exited
if session_ctx is not None:
try:
await session_ctx.__aexit__(None, None, None)
except Exception as e:
verbose_logger.debug(
f"Error during session context exit: {e}"
)
await session.initialize()
return await operation(session)
finally:
# Ensure transport context is properly exited
if transport_ctx is not None:
try:
await transport_ctx.__aexit__(None, None, None)
except Exception as e:
verbose_logger.debug(
f"Error during transport context exit: {e}"
)
try:
await session_ctx.__aexit__(None, None, None)
except BaseException as e:
verbose_logger.debug(f"Error during session context exit: {e}")
finally:
try:
await transport_ctx.__aexit__(None, None, None)
except BaseException as e:
verbose_logger.debug(f"Error during transport context exit: {e}")

async def run_with_session(
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
) -> TSessionResult:
"""Open a session, run the provided coroutine, and clean up."""
http_client: Optional[httpx.AsyncClient] = None
try:
transport_ctx, http_client = self._create_transport_context()
return await self._execute_session_operation(transport_ctx, operation)
except Exception:
verbose_logger.warning(
"MCP client run_with_session failed for %s", self.server_url or "stdio"
)
raise
finally:
# Always clean up http_client if it was created
if http_client is not None:
try:
await http_client.aclose()
except Exception as e:
verbose_logger.debug(
f"Error during http_client cleanup: {e}"
)
except BaseException as e:
verbose_logger.debug(f"Error during http_client cleanup: {e}")

def update_auth_value(self, mcp_auth_value: Union[str, Dict[str, str]]):
"""
Expand Down
Loading
Loading