Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import get_server_root_path
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer

router = APIRouter(
Expand Down Expand Up @@ -125,6 +126,29 @@ def decode_state_hash(encrypted_state: str) -> dict:
return state_data


def _resolve_oauth2_server_for_root_endpoints(
client_ip: Optional[str] = None,
) -> Optional[MCPServer]:
"""
Resolve the MCP server for root-level OAuth endpoints (no server name in path).

When the MCP SDK hits root-level endpoints like /register, /authorize, /token
without a server name prefix, we try to find the right server automatically.
Returns the server if exactly one OAuth2 server is configured, else None.
"""
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)

registry = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip)
oauth2_servers = [
s for s in registry.values() if s.auth_type == MCPAuth.oauth2
]
if len(oauth2_servers) == 1:
return oauth2_servers[0]
return None


async def authorize_with_server(
request: Request,
mcp_server: MCPServer,
Expand Down Expand Up @@ -305,6 +329,8 @@ async def authorize(
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
lookup_name, client_ip=client_ip
)
if mcp_server is None and mcp_server_name is None:
mcp_server = _resolve_oauth2_server_for_root_endpoints()
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
return await authorize_with_server(
Expand Down Expand Up @@ -350,6 +376,8 @@ async def token_endpoint(
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
lookup_name, client_ip=client_ip
)
if mcp_server is None and mcp_server_name is None:
mcp_server = _resolve_oauth2_server_for_root_endpoints()
if mcp_server is None:
raise HTTPException(status_code=404, detail="MCP server not found")
return await exchange_token_with_server(
Expand Down Expand Up @@ -430,6 +458,13 @@ def _build_oauth_protected_resource_response(
)

request_base_url = get_request_base_url(request)

# When no server name provided, try to resolve the single OAuth2 server
if mcp_server_name is None:
resolved = _resolve_oauth2_server_for_root_endpoints()
if resolved:
mcp_server_name = resolved.server_name or resolved.name

mcp_server: Optional[MCPServer] = None
if mcp_server_name:
client_ip = IPAddressUtils.get_mcp_client_ip(request)
Expand Down Expand Up @@ -535,6 +570,12 @@ def _build_oauth_authorization_server_response(

request_base_url = get_request_base_url(request)

# When no server name provided, try to resolve the single OAuth2 server
if mcp_server_name is None:
resolved = _resolve_oauth2_server_for_root_endpoints()
if resolved:
mcp_server_name = resolved.server_name or resolved.name

authorization_endpoint = (
f"{request_base_url}/{mcp_server_name}/authorize"
if mcp_server_name
Expand Down Expand Up @@ -640,6 +681,19 @@ async def register_client(request: Request, mcp_server_name: Optional[str] = Non
"redirect_uris": [f"{request_base_url}/callback"],
}
if not mcp_server_name:
resolved = _resolve_oauth2_server_for_root_endpoints()
if resolved:
return await register_client_with_server(
request=request,
mcp_server=resolved,
client_name=data.get("client_name", ""),
grant_types=data.get("grant_types", []),
response_types=data.get("response_types", []),
token_endpoint_auth_method=data.get(
"token_endpoint_auth_method", ""
),
fallback_client_id=resolved.server_name or resolved.name,
)
return dummy_return

client_ip = IPAddressUtils.get_mcp_client_ip(request)
Expand Down
5 changes: 4 additions & 1 deletion litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
get_request_base_url,
)
from litellm.proxy._experimental.mcp_server.utils import (
LITELLM_MCP_SERVER_DESCRIPTION,
LITELLM_MCP_SERVER_NAME,
Expand Down Expand Up @@ -1972,7 +1975,7 @@ async def handle_streamable_http_mcp(
)
if server and server.auth_type == MCPAuth.oauth2 and not oauth2_headers:
request = StarletteRequest(scope)
base_url = str(request.base_url).rstrip("/")
base_url = get_request_base_url(request)

authorization_uri = (
f"Bearer authorization_uri="
Expand Down
Loading
Loading