From 8fcbc8d0115eaa74ceec75209374a80c0be1f30d Mon Sep 17 00:00:00 2001 From: Jack Venberg Date: Tue, 24 Feb 2026 23:49:44 -0800 Subject: [PATCH 1/3] fix(mcp): stream SSE responses, fix mount order, remove OAuth root pollution Three fixes for the MCP server proxy that prevented external clients (e.g. Claude Code) from properly connecting to upstream MCP servers: 1. dynamic_mcp_route now streams ASGI responses via StreamingResponse instead of buffering the entire body. This is critical for SSE (text/event-stream) used by MCP Streamable HTTP transport. 2. Reorder ASGI mounts so /sse is matched before the / catch-all. Remove broken /mcp and /{mcp_server_name}/mcp mounts that either mapped to wrong paths or used unsupported path parameters. 3. _resolve_oauth2_server_for_root_endpoints now always returns None. The previous auto-resolution polluted root OAuth discovery endpoints with metadata from an unrelated server when exactly one OAuth2 server was configured, breaking non-OAuth MCP servers. --- .../mcp_server/discoverable_endpoints.py | 17 +- .../proxy/_experimental/mcp_server/server.py | 5 +- litellm/proxy/proxy_server.py | 87 +++++--- .../mcp_server/test_discoverable_endpoints.py | 97 ++++----- .../test_dynamic_mcp_route_streaming.py | 191 ++++++++++++++++++ 5 files changed, 300 insertions(+), 97 deletions(-) create mode 100644 tests/test_litellm/proxy/_experimental/mcp_server/test_dynamic_mcp_route_streaming.py diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index b731bc7bc2f..ef922c6e0d5 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -132,20 +132,11 @@ def _resolve_oauth2_server_for_root_endpoints( """ Resolve the MCP server for root-level OAuth endpoints (no server name in path). - When the MCP SDK hits root-level endpoints like /register, /authorize, /token - without a server name prefix, we try to find the right server automatically. - Returns the server if exactly one OAuth2 server is configured, else None. + Always returns None. Root-level OAuth discovery endpoints should not + auto-resolve to an arbitrary server because doing so pollutes non-OAuth + servers' discovery responses when any single OAuth2 server is configured. + Clients should use server-specific paths instead (e.g. /{server_name}/authorize). """ - from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( - global_mcp_server_manager, - ) - - registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip) - oauth2_servers = [ - s for s in registry.values() if s.auth_type == MCPAuth.oauth2 - ] - if len(oauth2_servers) == 1: - return oauth2_servers[0] return None diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index e8877b4fff7..af2a210f07c 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -2149,10 +2149,9 @@ def get_mcp_server_enabled() -> Dict[str, bool]: return {"enabled": MCP_AVAILABLE} # Mount the MCP handlers - app.mount("/", handle_streamable_http_mcp) - app.mount("/mcp", handle_streamable_http_mcp) - app.mount("/{mcp_server_name}/mcp", handle_streamable_http_mcp) + # /sse must be mounted before the "/" catch-all so it's matched first app.mount("/sse", handle_sse_mcp) + app.mount("/", handle_streamable_http_mcp) app.add_middleware(AuthContextMiddleware) ######################################################## diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 607306f3806..d9c6a6b13d8 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -12917,7 +12917,6 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request): global_mcp_server_manager, ) from litellm.proxy.auth.ip_address_utils import IPAddressUtils - from litellm.types.mcp import MCPAuth client_ip = IPAddressUtils.get_mcp_client_ip(request) mcp_server = global_mcp_server_manager.get_mcp_server_by_name( @@ -12938,35 +12937,75 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request): handle_streamable_http_mcp, ) - # Create a custom send function to capture the response - response_started = False - response_body = b"" - response_status = 200 - response_headers = [] + # Stream the ASGI response instead of buffering it. This is critical for + # SSE (text/event-stream) responses used by the MCP Streamable HTTP + # transport — buffering would break incremental event delivery. + response_meta: dict = {} + headers_ready = asyncio.Event() + body_queue: asyncio.Queue[bytes | None] = asyncio.Queue() + handler_error: list = [] - async def custom_send(message): - nonlocal response_started, response_body, response_status, response_headers + async def streaming_send(message): if message["type"] == "http.response.start": - response_started = True - response_status = message["status"] - response_headers = message.get("headers", []) + response_meta["status"] = message["status"] + response_meta["headers"] = { + k.decode(): v.decode() + for k, v in message.get("headers", []) + } + headers_ready.set() elif message["type"] == "http.response.body": - response_body += message.get("body", b"") + chunk = message.get("body", b"") + if chunk: + await body_queue.put(chunk) + if not message.get("more_body", False): + await body_queue.put(None) # sentinel - # Call the existing MCP handler - await handle_streamable_http_mcp( - scope, receive=request.receive, send=custom_send - ) + async def run_handler(): + try: + await handle_streamable_http_mcp( + scope, receive=request.receive, send=streaming_send + ) + except Exception as exc: + handler_error.append(exc) + finally: + # Ensure consumers aren't stuck waiting if the handler exits + # without sending a complete response. + headers_ready.set() + await body_queue.put(None) + + handler_task = asyncio.create_task(run_handler()) + + # Wait for the ASGI handler to send http.response.start + await headers_ready.wait() - # Return the response - from starlette.responses import Response + # If the handler errored before sending headers, raise + if handler_error and "status" not in response_meta: + await handler_task + raise handler_error[0] - headers_dict = {k.decode(): v.decode() for k, v in response_headers} - return Response( - content=response_body, - status_code=response_status, - headers=headers_dict, - media_type=headers_dict.get("content-type", "application/json"), + async def body_generator(): + try: + while True: + chunk = await body_queue.get() + if chunk is None: + break + yield chunk + finally: + if not handler_task.done(): + handler_task.cancel() + try: + await handler_task + except asyncio.CancelledError: + pass + + headers = response_meta.get("headers", {}) + media_type = headers.pop("content-type", "application/json") + + return StreamingResponse( + content=body_generator(), + status_code=response_meta.get("status", 200), + headers=headers, + media_type=media_type, ) except HTTPException as e: diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py index 700ba86b108..7a8f6254a99 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -1258,8 +1258,12 @@ def _create_oauth2_server( @pytest.mark.asyncio -async def test_authorize_root_resolves_single_oauth2_server(): - """When /authorize is hit without server name and exactly 1 OAuth2 server exists, resolve it.""" +async def test_authorize_root_returns_404_without_server_name(): + """When /authorize is hit without server name, return 404 even if 1 OAuth2 server exists. + + Root auto-resolution was removed to prevent OAuth discovery pollution across + unrelated MCP servers. + """ try: from fastapi import Request @@ -1281,25 +1285,16 @@ async def test_authorize_root_resolves_single_oauth2_server(): mock_request.headers = {} try: - with patch( - "litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper" - ) as mock_encrypt: - mock_encrypt.return_value = "mocked_encrypted_state" - - # Call /authorize WITHOUT mcp_server_name, with dummy_client as client_id - response = await authorize( + with pytest.raises(HTTPException) as exc_info: + await authorize( request=mock_request, client_id="dummy_client", mcp_server_name=None, redirect_uri="http://localhost:62646/callback", state="test_state", ) - - # Should resolve to the single OAuth2 server and redirect - assert response.status_code == 307 - location = response.headers["location"] - assert "https://provider.com/oauth/authorize" in location - assert "client_id=test_client_id" in location + assert exc_info.value.status_code == 404 + assert "MCP server not found" in str(exc_info.value.detail) finally: global_mcp_server_manager.registry.clear() @@ -1349,8 +1344,12 @@ async def test_authorize_root_fails_with_multiple_oauth2_servers(): @pytest.mark.asyncio -async def test_token_root_resolves_single_oauth2_server(): - """When /token is hit without server name and exactly 1 OAuth2 server exists, resolve it.""" +async def test_token_root_returns_404_without_server_name(): + """When /token is hit without server name, return 404 even if 1 OAuth2 server exists. + + Root auto-resolution was removed to prevent OAuth discovery pollution across + unrelated MCP servers. + """ try: from fastapi import Request @@ -1371,25 +1370,9 @@ async def test_token_root_resolves_single_oauth2_server(): mock_request.base_url = "https://llm.example.com/" mock_request.headers = {} - mock_response = MagicMock() - mock_response.json.return_value = { - "access_token": "ya29.test_token", - "token_type": "Bearer", - "expires_in": 3599, - } - mock_response.raise_for_status = MagicMock() - - mock_async_client = MagicMock() - mock_async_client.post = AsyncMock(return_value=mock_response) - try: - with patch( - "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client" - ) as mock_get_client: - mock_get_client.return_value = mock_async_client - - # Call /token WITHOUT mcp_server_name - response = await token_endpoint( + with pytest.raises(HTTPException) as exc_info: + await token_endpoint( request=mock_request, grant_type="authorization_code", code="test_auth_code", @@ -1399,23 +1382,20 @@ async def test_token_root_resolves_single_oauth2_server(): client_secret=None, code_verifier="test_verifier", ) - - # Should resolve and exchange token with the upstream server - import json - - token_data = json.loads(response.body) - assert token_data["access_token"] == "ya29.test_token" - - # Verify it called the correct upstream token URL - call_args = mock_async_client.post.call_args - assert call_args.args[0] == "https://provider.com/oauth/token" + assert exc_info.value.status_code == 404 + assert "MCP server not found" in str(exc_info.value.detail) finally: global_mcp_server_manager.registry.clear() @pytest.mark.asyncio -async def test_register_root_resolves_single_oauth2_server(): - """When /register is hit without server name and exactly 1 OAuth2 server exists, resolve it.""" +async def test_register_root_returns_dummy_without_server_name(): + """When /register is hit without server name, return dummy registration even if 1 OAuth2 server exists. + + Root auto-resolution was removed to prevent OAuth discovery pollution across + unrelated MCP servers. The register endpoint returns a dummy response when + no server can be resolved. + """ try: from fastapi import Request @@ -1443,16 +1423,19 @@ async def test_register_root_resolves_single_oauth2_server(): ): result = await register_client(request=mock_request, mcp_server_name=None) - # Should resolve to the single server and return its name as client_id - assert result["client_id"] == "test_oauth" - assert "redirect_uris" in result + # Should return dummy registration since no server is resolved + assert result["client_id"] == "dummy_client" finally: global_mcp_server_manager.registry.clear() @pytest.mark.asyncio -async def test_discovery_root_includes_server_name_prefix(): - """When root discovery is hit and exactly 1 OAuth2 server exists, include server name in URLs.""" +async def test_discovery_root_returns_generic_urls_without_server_name(): + """When root discovery is hit without server name, return generic URLs without server prefix. + + Root auto-resolution was removed to prevent OAuth discovery pollution across + unrelated MCP servers. + """ try: from fastapi import Request @@ -1480,11 +1463,11 @@ async def test_discovery_root_includes_server_name_prefix(): mcp_server_name=None, ) - # Should resolve to the single server and include its name in endpoint URLs - assert "/test_oauth/authorize" in response["authorization_endpoint"] - assert "/test_oauth/token" in response["token_endpoint"] - assert "/test_oauth/register" in response["registration_endpoint"] - assert response["scopes_supported"] == ["read", "write"] + # Should NOT resolve to the single server — return generic URLs instead + assert response["authorization_endpoint"] == "https://llm.example.com/authorize" + assert response["token_endpoint"] == "https://llm.example.com/token" + assert response["registration_endpoint"] == "https://llm.example.com/register" + assert response["scopes_supported"] == [] finally: global_mcp_server_manager.registry.clear() diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_dynamic_mcp_route_streaming.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_dynamic_mcp_route_streaming.py new file mode 100644 index 00000000000..1688c4fd6e2 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_dynamic_mcp_route_streaming.py @@ -0,0 +1,191 @@ +"""Tests for dynamic_mcp_route streaming behavior and MCP mount ordering.""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException +from starlette.responses import StreamingResponse + + +@pytest.mark.asyncio +async def test_dynamic_mcp_route_streams_sse_response(): + """dynamic_mcp_route should return a StreamingResponse for SSE content.""" + try: + from litellm.proxy.proxy_server import dynamic_mcp_route + except ImportError: + pytest.skip("proxy_server not available") + + mock_request = MagicMock() + mock_request.scope = { + "type": "http", + "method": "POST", + "path": "/test_server/mcp", + "headers": [], + "query_string": b"", + } + mock_request.receive = AsyncMock() + + sse_chunks = [b"event: message\ndata: chunk1\n\n", b"event: message\ndata: chunk2\n\n"] + + async def fake_asgi_handler(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/event-stream"), + ], + } + ) + for i, chunk in enumerate(sse_chunks): + await send( + { + "type": "http.response.body", + "body": chunk, + "more_body": i < len(sse_chunks) - 1, + } + ) + + mock_server = MagicMock() + mock_server.server_name = "test_server" + + with ( + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager.get_mcp_server_by_name", + return_value=mock_server, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.handle_streamable_http_mcp", + side_effect=fake_asgi_handler, + ), + ): + response = await dynamic_mcp_route("test_server", mock_request) + + assert isinstance(response, StreamingResponse) + assert response.status_code == 200 + assert response.media_type == "text/event-stream" + + # Consume the body generator and verify chunks + collected = b"" + async for chunk in response.body_iterator: + collected += chunk + assert collected == b"".join(sse_chunks) + + +@pytest.mark.asyncio +async def test_dynamic_mcp_route_handles_non_streaming_response(): + """dynamic_mcp_route should also work for non-SSE (single-chunk) responses.""" + try: + from litellm.proxy.proxy_server import dynamic_mcp_route + except ImportError: + pytest.skip("proxy_server not available") + + mock_request = MagicMock() + mock_request.scope = { + "type": "http", + "method": "POST", + "path": "/test_server/mcp", + "headers": [], + "query_string": b"", + } + mock_request.receive = AsyncMock() + + body = b'{"jsonrpc":"2.0","result":{},"id":1}' + + async def fake_asgi_handler(scope, receive, send): + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + ], + } + ) + await send( + { + "type": "http.response.body", + "body": body, + "more_body": False, + } + ) + + mock_server = MagicMock() + mock_server.server_name = "test_server" + + with ( + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager.get_mcp_server_by_name", + return_value=mock_server, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.handle_streamable_http_mcp", + side_effect=fake_asgi_handler, + ), + ): + response = await dynamic_mcp_route("test_server", mock_request) + + assert isinstance(response, StreamingResponse) + assert response.status_code == 200 + assert response.media_type == "application/json" + + collected = b"" + async for chunk in response.body_iterator: + collected += chunk + assert collected == body + + +@pytest.mark.asyncio +async def test_dynamic_mcp_route_returns_404_for_unknown_server(): + """dynamic_mcp_route should raise 404 for a server that doesn't exist.""" + try: + from litellm.proxy.proxy_server import dynamic_mcp_route + except ImportError: + pytest.skip("proxy_server not available") + + mock_request = MagicMock() + mock_request.scope = { + "type": "http", + "method": "POST", + "path": "/nonexistent/mcp", + "headers": [], + "query_string": b"", + } + + with patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager.get_mcp_server_by_name", + return_value=None, + ): + with pytest.raises(HTTPException) as exc_info: + await dynamic_mcp_route("nonexistent", mock_request) + assert exc_info.value.status_code == 404 + + +def test_sse_mount_precedes_catch_all(): + """The /sse mount must appear before the / catch-all in the MCP sub-app.""" + try: + from litellm.proxy._experimental.mcp_server.server import app as mcp_app + except ImportError: + pytest.skip("MCP server module not available") + + from starlette.routing import Mount + + mounts = [r for r in mcp_app.routes if isinstance(r, Mount)] + mount_paths = [m.path for m in mounts] + + # Starlette normalizes "/" to "" for mount paths + catch_all = "" if "" in mount_paths else "/" + + # /sse must be present and before the catch-all + assert "/sse" in mount_paths, f"Missing /sse mount. Found: {mount_paths}" + assert catch_all in mount_paths, f"Missing catch-all mount. Found: {mount_paths}" + assert mount_paths.index("/sse") < mount_paths.index( + catch_all + ), f"/sse must precede catch-all. Order: {mount_paths}" + + # The broken /mcp and /{mcp_server_name}/mcp mounts should not exist + assert "/mcp" not in mount_paths, f"Unexpected /mcp mount found. Mounts: {mount_paths}" + assert "/{mcp_server_name}/mcp" not in mount_paths, ( + f"Unexpected /{{mcp_server_name}}/mcp mount found. Mounts: {mount_paths}" + ) From 77da3b1eac948f1f6e4aefadff59d6c68d7cb71a Mon Sep 17 00:00:00 2001 From: Jack Venberg Date: Wed, 25 Feb 2026 00:08:56 -0800 Subject: [PATCH 2/3] fix(mcp): refine OAuth root auto-resolution to preserve single-server convenience Instead of always returning None, auto-resolve to the single OAuth2 server only when no non-OAuth servers are configured. When non-OAuth servers are also present, return None to avoid polluting their discovery responses. This preserves the convenience behavior for single-OAuth-only setups while fixing the mixed-server case. Add 4 new tests for the mixed-server scenario (OAuth + non-OAuth) and restore original test expectations for the single-OAuth-only case. --- .../mcp_server/discoverable_endpoints.py | 22 +- .../mcp_server/test_discoverable_endpoints.py | 262 ++++++++++++++++-- 2 files changed, 250 insertions(+), 34 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index ef922c6e0d5..0c96528add6 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -132,11 +132,25 @@ def _resolve_oauth2_server_for_root_endpoints( """ Resolve the MCP server for root-level OAuth endpoints (no server name in path). - Always returns None. Root-level OAuth discovery endpoints should not - auto-resolve to an arbitrary server because doing so pollutes non-OAuth - servers' discovery responses when any single OAuth2 server is configured. - Clients should use server-specific paths instead (e.g. /{server_name}/authorize). + When exactly one OAuth2 server is configured and there are no non-OAuth + servers, returns that server as a convenience for single-server setups. + Otherwise returns None to avoid polluting discovery responses for non-OAuth + servers. Clients should use server-specific paths (e.g. /{server_name}/authorize) + when multiple servers are configured. """ + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip) + oauth2_servers = [ + s for s in registry.values() if s.auth_type == MCPAuth.oauth2 + ] + non_oauth2_servers = [ + s for s in registry.values() if s.auth_type != MCPAuth.oauth2 + ] + if len(oauth2_servers) == 1 and len(non_oauth2_servers) == 0: + return oauth2_servers[0] return None diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py index 7a8f6254a99..04e0b0050c2 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -1258,12 +1258,8 @@ def _create_oauth2_server( @pytest.mark.asyncio -async def test_authorize_root_returns_404_without_server_name(): - """When /authorize is hit without server name, return 404 even if 1 OAuth2 server exists. - - Root auto-resolution was removed to prevent OAuth discovery pollution across - unrelated MCP servers. - """ +async def test_authorize_root_resolves_single_oauth2_server(): + """When /authorize is hit without server name and exactly 1 OAuth2 server exists (no non-OAuth), resolve it.""" try: from fastapi import Request @@ -1285,16 +1281,23 @@ async def test_authorize_root_returns_404_without_server_name(): mock_request.headers = {} try: - with pytest.raises(HTTPException) as exc_info: - await authorize( + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper" + ) as mock_encrypt: + mock_encrypt.return_value = "mocked_encrypted_state" + + response = await authorize( request=mock_request, client_id="dummy_client", mcp_server_name=None, redirect_uri="http://localhost:62646/callback", state="test_state", ) - assert exc_info.value.status_code == 404 - assert "MCP server not found" in str(exc_info.value.detail) + + # Should resolve to the single OAuth2 server and redirect + assert response.status_code == 307 + location = response.headers["location"] + assert "https://provider.com/oauth/authorize" in location finally: global_mcp_server_manager.registry.clear() @@ -1344,15 +1347,218 @@ async def test_authorize_root_fails_with_multiple_oauth2_servers(): @pytest.mark.asyncio -async def test_token_root_returns_404_without_server_name(): - """When /token is hit without server name, return 404 even if 1 OAuth2 server exists. +async def test_token_root_resolves_single_oauth2_server(): + """When /token is hit without server name and exactly 1 OAuth2 server exists (no non-OAuth), resolve it.""" + try: + from fastapi import Request + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + token_endpoint, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + global_mcp_server_manager.registry.clear() + oauth2_server = _create_oauth2_server() + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://llm.example.com/" + mock_request.headers = {} + + mock_response = MagicMock() + mock_response.json.return_value = { + "access_token": "ya29.test_token", + "token_type": "Bearer", + "expires_in": 3599, + } + mock_response.raise_for_status = MagicMock() - Root auto-resolution was removed to prevent OAuth discovery pollution across - unrelated MCP servers. + mock_async_client = MagicMock() + mock_async_client.post = AsyncMock(return_value=mock_response) + + try: + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client" + ) as mock_get_client: + mock_get_client.return_value = mock_async_client + + response = await token_endpoint( + request=mock_request, + grant_type="authorization_code", + code="test_auth_code", + redirect_uri="http://localhost:62646/callback", + client_id="dummy_client", + mcp_server_name=None, + client_secret=None, + code_verifier="test_verifier", + ) + + import json + + token_data = json.loads(response.body) + assert token_data["access_token"] == "ya29.test_token" + + call_args = mock_async_client.post.call_args + assert call_args.args[0] == "https://provider.com/oauth/token" + finally: + global_mcp_server_manager.registry.clear() + + +@pytest.mark.asyncio +async def test_register_root_resolves_single_oauth2_server(): + """When /register is hit without server name and exactly 1 OAuth2 server exists (no non-OAuth), resolve it.""" + try: + from fastapi import Request + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + register_client, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + global_mcp_server_manager.registry.clear() + oauth2_server = _create_oauth2_server() + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://llm.example.com/" + mock_request.headers = {} + + try: + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints._read_request_body", + new=AsyncMock(return_value={}), + ): + result = await register_client(request=mock_request, mcp_server_name=None) + + # Should resolve to the single server and return its name as client_id + assert result["client_id"] == "test_oauth" + assert "redirect_uris" in result + finally: + global_mcp_server_manager.registry.clear() + + +@pytest.mark.asyncio +async def test_discovery_root_includes_server_name_prefix(): + """When root discovery is hit and exactly 1 OAuth2 server exists (no non-OAuth), include server name in URLs.""" + try: + from fastapi import Request + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + _build_oauth_authorization_server_response, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + global_mcp_server_manager.registry.clear() + oauth2_server = _create_oauth2_server() + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://llm.example.com/" + mock_request.headers = {} + + try: + # Call with mcp_server_name=None (root discovery) + response = _build_oauth_authorization_server_response( + request=mock_request, + mcp_server_name=None, + ) + + # Should resolve to the single server and include its name in endpoint URLs + assert "/test_oauth/authorize" in response["authorization_endpoint"] + assert "/test_oauth/token" in response["token_endpoint"] + assert "/test_oauth/register" in response["registration_endpoint"] + assert response["scopes_supported"] == ["read", "write"] + finally: + global_mcp_server_manager.registry.clear() + + +def _create_non_oauth_server( + server_id="test_plain_server", + name="test_plain", + server_name="test_plain", + alias="test_plain", +): + """Helper to create a mock non-OAuth MCPServer.""" + from litellm.proxy._types import MCPTransport + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + return MCPServer( + server_id=server_id, + name=name, + server_name=server_name, + alias=alias, + url="https://example.com/mcp", + transport=MCPTransport.http, + auth_type=None, + ) + + +# ------------------------------------------------------------------- +# Tests for root-level OAuth when non-OAuth servers are also present +# ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_authorize_root_returns_404_when_non_oauth_servers_present(): + """When /authorize is hit without server name and non-OAuth servers also exist, return 404. + + Auto-resolution is skipped to avoid polluting non-OAuth servers' discovery. """ try: from fastapi import Request + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + authorize, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + except ImportError: + pytest.skip("MCP discoverable endpoints not available") + + global_mcp_server_manager.registry.clear() + oauth2_server = _create_oauth2_server() + plain_server = _create_non_oauth_server() + global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + global_mcp_server_manager.registry[plain_server.server_id] = plain_server + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://llm.example.com/" + mock_request.headers = {} + + try: + with pytest.raises(HTTPException) as exc_info: + await authorize( + request=mock_request, + client_id="dummy_client", + mcp_server_name=None, + redirect_uri="http://localhost:62646/callback", + state="test_state", + ) + assert exc_info.value.status_code == 404 + assert "MCP server not found" in str(exc_info.value.detail) + finally: + global_mcp_server_manager.registry.clear() + + +@pytest.mark.asyncio +async def test_token_root_returns_404_when_non_oauth_servers_present(): + """When /token is hit without server name and non-OAuth servers also exist, return 404.""" + try: + from fastapi import Request + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( token_endpoint, ) @@ -1364,7 +1570,9 @@ async def test_token_root_returns_404_without_server_name(): global_mcp_server_manager.registry.clear() oauth2_server = _create_oauth2_server() + plain_server = _create_non_oauth_server() global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + global_mcp_server_manager.registry[plain_server.server_id] = plain_server mock_request = MagicMock(spec=Request) mock_request.base_url = "https://llm.example.com/" @@ -1389,13 +1597,8 @@ async def test_token_root_returns_404_without_server_name(): @pytest.mark.asyncio -async def test_register_root_returns_dummy_without_server_name(): - """When /register is hit without server name, return dummy registration even if 1 OAuth2 server exists. - - Root auto-resolution was removed to prevent OAuth discovery pollution across - unrelated MCP servers. The register endpoint returns a dummy response when - no server can be resolved. - """ +async def test_register_root_returns_dummy_when_non_oauth_servers_present(): + """When /register is hit without server name and non-OAuth servers also exist, return dummy.""" try: from fastapi import Request @@ -1410,7 +1613,9 @@ async def test_register_root_returns_dummy_without_server_name(): global_mcp_server_manager.registry.clear() oauth2_server = _create_oauth2_server() + plain_server = _create_non_oauth_server() global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + global_mcp_server_manager.registry[plain_server.server_id] = plain_server mock_request = MagicMock(spec=Request) mock_request.base_url = "https://llm.example.com/" @@ -1423,19 +1628,15 @@ async def test_register_root_returns_dummy_without_server_name(): ): result = await register_client(request=mock_request, mcp_server_name=None) - # Should return dummy registration since no server is resolved + # Should return dummy registration since auto-resolution is skipped assert result["client_id"] == "dummy_client" finally: global_mcp_server_manager.registry.clear() @pytest.mark.asyncio -async def test_discovery_root_returns_generic_urls_without_server_name(): - """When root discovery is hit without server name, return generic URLs without server prefix. - - Root auto-resolution was removed to prevent OAuth discovery pollution across - unrelated MCP servers. - """ +async def test_discovery_root_returns_generic_urls_when_non_oauth_servers_present(): + """When root discovery is hit and non-OAuth servers also exist, return generic URLs without server prefix.""" try: from fastapi import Request @@ -1450,20 +1651,21 @@ async def test_discovery_root_returns_generic_urls_without_server_name(): global_mcp_server_manager.registry.clear() oauth2_server = _create_oauth2_server() + plain_server = _create_non_oauth_server() global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server + global_mcp_server_manager.registry[plain_server.server_id] = plain_server mock_request = MagicMock(spec=Request) mock_request.base_url = "https://llm.example.com/" mock_request.headers = {} try: - # Call with mcp_server_name=None (root discovery) response = _build_oauth_authorization_server_response( request=mock_request, mcp_server_name=None, ) - # Should NOT resolve to the single server — return generic URLs instead + # Should NOT resolve — return generic URLs without server prefix assert response["authorization_endpoint"] == "https://llm.example.com/authorize" assert response["token_endpoint"] == "https://llm.example.com/token" assert response["registration_endpoint"] == "https://llm.example.com/register" From de74f859f2713f3220f4a29d60895829871d087a Mon Sep 17 00:00:00 2001 From: Jack Venberg Date: Wed, 25 Feb 2026 00:31:04 -0800 Subject: [PATCH 3/3] fix(mcp): guard against duplicate sentinel in streaming body queue Add body_terminated flag so the run_handler finally block only sends the None sentinel when streaming_send hasn't already sent one on the normal more_body=False path. --- litellm/proxy/proxy_server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d9c6a6b13d8..42a104fe3be 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -12944,8 +12944,10 @@ async def dynamic_mcp_route(mcp_server_name: str, request: Request): headers_ready = asyncio.Event() body_queue: asyncio.Queue[bytes | None] = asyncio.Queue() handler_error: list = [] + body_terminated = False async def streaming_send(message): + nonlocal body_terminated if message["type"] == "http.response.start": response_meta["status"] = message["status"] response_meta["headers"] = { @@ -12958,6 +12960,7 @@ async def streaming_send(message): if chunk: await body_queue.put(chunk) if not message.get("more_body", False): + body_terminated = True await body_queue.put(None) # sentinel async def run_handler(): @@ -12971,7 +12974,8 @@ async def run_handler(): # Ensure consumers aren't stuck waiting if the handler exits # without sending a complete response. headers_ready.set() - await body_queue.put(None) + if not body_terminated: + await body_queue.put(None) handler_task = asyncio.create_task(run_handler())