diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
index 8b052dd0da1..bdf4cc312d9 100644
--- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
+++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
@@ -16,6 +16,7 @@
 )
 from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
 from litellm.proxy.utils import get_server_root_path
+from litellm.types.mcp import MCPAuth
 from litellm.types.mcp_server.mcp_server_manager import MCPServer
 
 router = APIRouter(
@@ -125,6 +126,37 @@ def decode_state_hash(encrypted_state: str) -> dict:
     return state_data
 
 
+def _resolve_oauth2_server_for_root_endpoints(
+    client_ip: Optional[str] = None,
+) -> Optional[MCPServer]:
+    """
+    Resolve the MCP server for root-level OAuth endpoints (no server name in path).
+
+    When the MCP SDK hits root-level endpoints like /register, /authorize, /token
+    without a server name prefix, we try to find the right server automatically.
+
+    Args:
+        client_ip: IP address of the calling client. It is passed to
+            get_filtered_registry(), so callers should supply the same value
+            they pass to get_mcp_server_by_name() for named lookups.
+
+    Returns:
+        The server if exactly one OAuth2 server is in the filtered registry,
+        else None (no OAuth2 server, or more than one so the choice is ambiguous).
+    """
+    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+        global_mcp_server_manager,
+    )
+
+    registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip)
+    oauth2_servers = [
+        s for s in registry.values() if s.auth_type == MCPAuth.oauth2
+    ]
+    if len(oauth2_servers) == 1:
+        return oauth2_servers[0]
+    return None
+
+
 async def authorize_with_server(
     request: Request,
     mcp_server: MCPServer,
@@ -305,6 +337,9 @@ async def authorize(
     mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
         lookup_name, client_ip=client_ip
     )
+    if mcp_server is None and mcp_server_name is None:
+        # Root-level call: pass client_ip so the fallback lookup is filtered too.
+        mcp_server = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
     if mcp_server is None:
         raise HTTPException(status_code=404, detail="MCP server not found")
     return await authorize_with_server(
@@ -350,6 +385,9 @@ async def token_endpoint(
     mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
         lookup_name, client_ip=client_ip
     )
+    if mcp_server is None and mcp_server_name is None:
+        # Root-level call: pass client_ip so the fallback lookup is filtered too.
+        mcp_server = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
     if mcp_server is None:
         raise HTTPException(status_code=404, detail="MCP server not found")
     return await exchange_token_with_server(
@@ -430,6 +468,16 @@ def _build_oauth_protected_resource_response(
     )
 
     request_base_url = get_request_base_url(request)
+
+    # When no server name provided, try to resolve the single OAuth2 server.
+    # The client IP is forwarded so the registry lookup is filtered by it.
+    if mcp_server_name is None:
+        resolved = _resolve_oauth2_server_for_root_endpoints(
+            client_ip=IPAddressUtils.get_mcp_client_ip(request)
+        )
+        if resolved:
+            mcp_server_name = resolved.server_name or resolved.name
+
     mcp_server: Optional[MCPServer] = None
     if mcp_server_name:
         client_ip = IPAddressUtils.get_mcp_client_ip(request)
@@ -535,6 +583,15 @@ def _build_oauth_authorization_server_response(
 
     request_base_url = get_request_base_url(request)
 
+    # When no server name provided, try to resolve the single OAuth2 server.
+    # The client IP is forwarded so the registry lookup is filtered by it.
+    if mcp_server_name is None:
+        resolved = _resolve_oauth2_server_for_root_endpoints(
+            client_ip=IPAddressUtils.get_mcp_client_ip(request)
+        )
+        if resolved:
+            mcp_server_name = resolved.server_name or resolved.name
+
     authorization_endpoint = (
         f"{request_base_url}/{mcp_server_name}/authorize"
         if mcp_server_name
@@ -640,6 +697,22 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non
         "redirect_uris": [f"{request_base_url}/callback"],
     }
     if not mcp_server_name:
+        # Forward the client IP so the registry lookup is filtered by it.
+        resolved = _resolve_oauth2_server_for_root_endpoints(
+            client_ip=IPAddressUtils.get_mcp_client_ip(request)
+        )
+        if resolved:
+            return await register_client_with_server(
+                request=request,
+                mcp_server=resolved,
+                client_name=data.get("client_name", ""),
+                grant_types=data.get("grant_types", []),
+                response_types=data.get("response_types", []),
+                token_endpoint_auth_method=data.get(
+                    "token_endpoint_auth_method", ""
+                ),
+                fallback_client_id=resolved.server_name or resolved.name,
+            )
         return dummy_return
 
     client_ip = IPAddressUtils.get_mcp_client_ip(request)
diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py
index 890c4ae8fb2..58cd8c99e7b 100644
--- a/litellm/proxy/_experimental/mcp_server/server.py
+++ b/litellm/proxy/_experimental/mcp_server/server.py
@@ -31,6 +31,9 @@
 from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
     MCPRequestHandler,
 )
+from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+    get_request_base_url,
+)
 from litellm.proxy._experimental.mcp_server.utils import (
     LITELLM_MCP_SERVER_DESCRIPTION,
     LITELLM_MCP_SERVER_NAME,
@@ -1972,7 +1975,7 @@ async def handle_streamable_http_mcp(
                 )
                 if server and server.auth_type == MCPAuth.oauth2 and not oauth2_headers:
                     request = StarletteRequest(scope)
-                    base_url = str(request.base_url).rstrip("/")
+                    base_url = get_request_base_url(request)
 
                     authorization_uri = (
                         f"Bearer authorization_uri="
diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py
index e4b4d3ba189..faabe40f2dc 100644
--- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py
+++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py
@@ -2,6 +2,8 @@
 import pytest
 from unittest.mock import AsyncMock, MagicMock, patch
 
+from fastapi import HTTPException
+
 
 # Fixture to mock IP address check for all MCP tests
 # This prevents tests from failing due to IP-based access control
@@ -260,10 +262,16 @@ async def test_register_client_without_mcp_server_name_returns_dummy():
         from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
             register_client,
         )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
         from fastapi import Request
     except ImportError:
         pytest.skip("MCP discoverable endpoints not available")
 
+    # Clear registry to ensure no OAuth2 servers exist (otherwise resolver would find one)
+    global_mcp_server_manager.registry.clear()
+
     mock_request = MagicMock(spec=Request)
     mock_request.base_url = "https://proxy.litellm.example/"
     mock_request.headers = {}
@@ -680,10 +688,16 @@ async def test_register_client_respects_x_forwarded_proto():
         from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
             register_client,
         )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
         from fastapi import Request
     except ImportError:
         pytest.skip("MCP discoverable endpoints not available")
 
+    # Clear registry to ensure no OAuth2 servers exist (otherwise resolver would find one)
+    global_mcp_server_manager.registry.clear()
+
     # Mock request with http base_url but X-Forwarded-Proto: https
     mock_request = MagicMock(spec=Request)
     mock_request.base_url = "http://proxy.litellm.example/"  # HTTP
@@ -1017,3 +1031,263 @@ def mock_get(header_name, default=None):
         f"X-Forwarded-Host={x_forwarded_host}, "
         f"X-Forwarded-Port={x_forwarded_port}"
     )
+
+
+# -------------------------------------------------------------------
+# Tests for root-level OAuth endpoint resolution (no server name)
+# -------------------------------------------------------------------
+
+
+def _create_oauth2_server(
+    server_id="test_oauth_server",
+    name="test_oauth",
+    server_name="test_oauth",
+    alias="test_oauth",
+    client_id="test_client_id",
+    client_secret="test_client_secret",
+):
+    """Helper to create a mock OAuth2 MCPServer."""
+    from litellm.types.mcp import MCPAuth
+    from litellm.types.mcp_server.mcp_server_manager import MCPServer
+    from litellm.proxy._types import MCPTransport
+
+    return MCPServer(
+        server_id=server_id,
+        name=name,
+        server_name=server_name,
+        alias=alias,
+        transport=MCPTransport.http,
+        auth_type=MCPAuth.oauth2,
+        client_id=client_id,
+        client_secret=client_secret,
+        authorization_url="https://provider.com/oauth/authorize",
+        token_url="https://provider.com/oauth/token",
+        scopes=["read", "write"],
+    )
+
+
+@pytest.mark.asyncio
+async def test_authorize_root_resolves_single_oauth2_server():
+    """When /authorize is hit without server name and exactly 1 OAuth2 server exists, resolve it."""
+    try:
+        from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+            authorize,
+        )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
+        from fastapi import Request
+    except ImportError:
+        pytest.skip("MCP discoverable endpoints not available")
+
+    global_mcp_server_manager.registry.clear()
+    oauth2_server = _create_oauth2_server()
+    global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
+
+    mock_request = MagicMock(spec=Request)
+    mock_request.base_url = "https://llm.example.com/"
+    mock_request.headers = {}
+
+    try:
+        with patch(
+            "litellm.proxy._experimental.mcp_server.discoverable_endpoints.encrypt_value_helper"
+        ) as mock_encrypt:
+            mock_encrypt.return_value = "mocked_encrypted_state"
+
+            # Call /authorize WITHOUT mcp_server_name, with dummy_client as client_id
+            response = await authorize(
+                request=mock_request,
+                client_id="dummy_client",
+                mcp_server_name=None,
+                redirect_uri="http://localhost:62646/callback",
+                state="test_state",
+            )
+
+            # Should resolve to the single OAuth2 server and redirect
+            assert response.status_code == 307
+            location = response.headers["location"]
+            assert "https://provider.com/oauth/authorize" in location
+            assert "client_id=test_client_id" in location
+    finally:
+        global_mcp_server_manager.registry.clear()
+
+
+@pytest.mark.asyncio
+async def test_authorize_root_fails_with_multiple_oauth2_servers():
+    """When /authorize is hit without server name and multiple OAuth2 servers exist, return 404."""
+    try:
+        from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+            authorize,
+        )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
+        from fastapi import Request
+    except ImportError:
+        pytest.skip("MCP discoverable endpoints not available")
+
+    global_mcp_server_manager.registry.clear()
+    server1 = _create_oauth2_server(
+        server_id="server1", name="server1", server_name="server1", alias="server1"
+    )
+    server2 = _create_oauth2_server(
+        server_id="server2", name="server2", server_name="server2", alias="server2"
+    )
+    global_mcp_server_manager.registry[server1.server_id] = server1
+    global_mcp_server_manager.registry[server2.server_id] = server2
+
+    mock_request = MagicMock(spec=Request)
+    mock_request.base_url = "https://llm.example.com/"
+    mock_request.headers = {}
+
+    try:
+        with pytest.raises(HTTPException) as exc_info:
+            await authorize(
+                request=mock_request,
+                client_id="dummy_client",
+                mcp_server_name=None,
+                redirect_uri="http://localhost:62646/callback",
+                state="test_state",
+            )
+        assert exc_info.value.status_code == 404
+        assert "MCP server not found" in str(exc_info.value.detail)
+    finally:
+        global_mcp_server_manager.registry.clear()
+
+
+@pytest.mark.asyncio
+async def test_token_root_resolves_single_oauth2_server():
+    """When /token is hit without server name and exactly 1 OAuth2 server exists, resolve it."""
+    try:
+        from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+            token_endpoint,
+        )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
+        from fastapi import Request
+    except ImportError:
+        pytest.skip("MCP discoverable endpoints not available")
+
+    global_mcp_server_manager.registry.clear()
+    oauth2_server = _create_oauth2_server()
+    global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
+
+    mock_request = MagicMock(spec=Request)
+    mock_request.base_url = "https://llm.example.com/"
+    mock_request.headers = {}
+
+    mock_response = MagicMock()
+    mock_response.json.return_value = {
+        "access_token": "ya29.test_token",
+        "token_type": "Bearer",
+        "expires_in": 3599,
+    }
+    mock_response.raise_for_status = MagicMock()
+
+    mock_async_client = MagicMock()
+    mock_async_client.post = AsyncMock(return_value=mock_response)
+
+    try:
+        with patch(
+            "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client"
+        ) as mock_get_client:
+            mock_get_client.return_value = mock_async_client
+
+            # Call /token WITHOUT mcp_server_name
+            response = await token_endpoint(
+                request=mock_request,
+                grant_type="authorization_code",
+                code="test_auth_code",
+                redirect_uri="http://localhost:62646/callback",
+                client_id="dummy_client",
+                mcp_server_name=None,
+                client_secret=None,
+                code_verifier="test_verifier",
+            )
+
+            # Should resolve and exchange token with the upstream server
+            import json
+
+            token_data = json.loads(response.body)
+            assert token_data["access_token"] == "ya29.test_token"
+
+            # Verify it called the correct upstream token URL
+            call_args = mock_async_client.post.call_args
+            assert call_args.args[0] == "https://provider.com/oauth/token"
+    finally:
+        global_mcp_server_manager.registry.clear()
+
+
+@pytest.mark.asyncio
+async def test_register_root_resolves_single_oauth2_server():
+    """When /register is hit without server name and exactly 1 OAuth2 server exists, resolve it."""
+    try:
+        from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+            register_client,
+        )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
+        from fastapi import Request
+    except ImportError:
+        pytest.skip("MCP discoverable endpoints not available")
+
+    global_mcp_server_manager.registry.clear()
+    oauth2_server = _create_oauth2_server()
+    global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
+
+    mock_request = MagicMock(spec=Request)
+    mock_request.base_url = "https://llm.example.com/"
+    mock_request.headers = {}
+
+    try:
+        with patch(
+            "litellm.proxy._experimental.mcp_server.discoverable_endpoints._read_request_body",
+            new=AsyncMock(return_value={}),
+        ):
+            result = await register_client(request=mock_request, mcp_server_name=None)
+
+        # Should resolve to the single server and return its name as client_id
+        assert result["client_id"] == "test_oauth"
+        assert "redirect_uris" in result
+    finally:
+        global_mcp_server_manager.registry.clear()
+
+
+@pytest.mark.asyncio
+async def test_discovery_root_includes_server_name_prefix():
+    """When root discovery is hit and exactly 1 OAuth2 server exists, include server name in URLs."""
+    try:
+        from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
+            _build_oauth_authorization_server_response,
+        )
+        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
+            global_mcp_server_manager,
+        )
+        from fastapi import Request
+    except ImportError:
+        pytest.skip("MCP discoverable endpoints not available")
+
+    global_mcp_server_manager.registry.clear()
+    oauth2_server = _create_oauth2_server()
+    global_mcp_server_manager.registry[oauth2_server.server_id] = oauth2_server
+
+    mock_request = MagicMock(spec=Request)
+    mock_request.base_url = "https://llm.example.com/"
+    mock_request.headers = {}
+
+    try:
+        # Call with mcp_server_name=None (root discovery)
+        response = _build_oauth_authorization_server_response(
+            request=mock_request,
+            mcp_server_name=None,
+        )
+
+        # Should resolve to the single server and include its name in endpoint URLs
+        assert "/test_oauth/authorize" in response["authorization_endpoint"]
+        assert "/test_oauth/token" in response["token_endpoint"]
+        assert "/test_oauth/register" in response["registration_endpoint"]
+        assert response["scopes_supported"] == ["read", "write"]
+    finally:
+        global_mcp_server_manager.registry.clear()