Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def _resolve_oauth2_server_for_root_endpoints(
"""
Resolve the MCP server for root-level OAuth endpoints (no server name in path).

When the MCP SDK hits root-level endpoints like /register, /authorize, /token
without a server name prefix, we try to find the right server automatically.
Returns the server if exactly one OAuth2 server is configured, else None.
When exactly one OAuth2 server is configured and there are no non-OAuth
servers, returns that server as a convenience for single-server setups.
Otherwise returns None to avoid polluting discovery responses for non-OAuth
servers. Clients should use server-specific paths (e.g. /{server_name}/authorize)
when multiple servers are configured.
"""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
Expand All @@ -144,7 +146,10 @@ def _resolve_oauth2_server_for_root_endpoints(
oauth2_servers = [
s for s in registry.values() if s.auth_type == MCPAuth.oauth2
]
if len(oauth2_servers) == 1:
non_oauth2_servers = [
s for s in registry.values() if s.auth_type != MCPAuth.oauth2
]
if len(oauth2_servers) == 1 and len(non_oauth2_servers) == 0:
return oauth2_servers[0]
return None

Expand Down
5 changes: 2 additions & 3 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2149,10 +2149,9 @@ def get_mcp_server_enabled() -> Dict[str, bool]:
return {"enabled": MCP_AVAILABLE}

# Mount the MCP handlers
app.mount("/", handle_streamable_http_mcp)
app.mount("/mcp", handle_streamable_http_mcp)
app.mount("/{mcp_server_name}/mcp", handle_streamable_http_mcp)
# /sse must be mounted before the "/" catch-all so it's matched first
app.mount("/sse", handle_sse_mcp)
app.mount("/", handle_streamable_http_mcp)
app.add_middleware(AuthContextMiddleware)

########################################################
Expand Down
91 changes: 67 additions & 24 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12917,7 +12917,6 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request):
global_mcp_server_manager,
)
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.types.mcp import MCPAuth

client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
Expand All @@ -12938,35 +12937,79 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request):
handle_streamable_http_mcp,
)

# Create a custom send function to capture the response
response_started = False
response_body = b""
response_status = 200
response_headers = []
# Stream the ASGI response instead of buffering it. This is critical for
# SSE (text/event-stream) responses used by the MCP Streamable HTTP
# transport — buffering would break incremental event delivery.
response_meta: dict = {}
headers_ready = asyncio.Event()
body_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
handler_error: list = []
body_terminated = False

async def custom_send(message):
nonlocal response_started, response_body, response_status, response_headers
async def streaming_send(message):
nonlocal body_terminated
if message["type"] == "http.response.start":
response_started = True
response_status = message["status"]
response_headers = message.get("headers", [])
response_meta["status"] = message["status"]
response_meta["headers"] = {
k.decode(): v.decode()
for k, v in message.get("headers", [])
}
headers_ready.set()
elif message["type"] == "http.response.body":
response_body += message.get("body", b"")
chunk = message.get("body", b"")
if chunk:
await body_queue.put(chunk)
if not message.get("more_body", False):
body_terminated = True
await body_queue.put(None) # sentinel

async def run_handler():
try:
await handle_streamable_http_mcp(
scope, receive=request.receive, send=streaming_send
)
except Exception as exc:
handler_error.append(exc)
finally:
# Ensure consumers aren't stuck waiting if the handler exits
# without sending a complete response.
headers_ready.set()
if not body_terminated:
await body_queue.put(None)

# Call the existing MCP handler
await handle_streamable_http_mcp(
scope, receive=request.receive, send=custom_send
)
handler_task = asyncio.create_task(run_handler())

# Wait until the ASGI handler sends http.response.start, or exits without
# sending it (run_handler sets headers_ready in its finally block either way).
await headers_ready.wait()

# Return the response
from starlette.responses import Response
# If the handler failed before sending http.response.start (no status was
# recorded), re-raise its exception instead of starting a response.
if handler_error and "status" not in response_meta:
await handler_task
raise handler_error[0]

headers_dict = {k.decode(): v.decode() for k, v in response_headers}
return Response(
content=response_body,
status_code=response_status,
headers=headers_dict,
media_type=headers_dict.get("content-type", "application/json"),
async def body_generator():
try:
while True:
chunk = await body_queue.get()
if chunk is None:
break
yield chunk
finally:
if not handler_task.done():
handler_task.cancel()
try:
await handler_task
except asyncio.CancelledError:
pass

headers = response_meta.get("headers", {})
media_type = headers.pop("content-type", "application/json")

return StreamingResponse(
content=body_generator(),
status_code=response_meta.get("status", 200),
headers=headers,
media_type=media_type,
)

except HTTPException as e:
Expand Down
Loading
Loading