diff --git a/docs/my-website/docs/mcp.md b/docs/my-website/docs/mcp.md index d63b55ee29e..84d10c25931 100644 --- a/docs/my-website/docs/mcp.md +++ b/docs/my-website/docs/mcp.md @@ -506,7 +506,14 @@ Your OpenAPI specification should follow standard OpenAPI/Swagger conventions: - **Operation IDs**: Each operation should have a unique `operationId` (this becomes the tool name) - **Parameters**: Request parameters should be properly documented with types and descriptions -## MCP Oauth +## MCP OAuth + +LiteLLM supports OAuth 2.0 for MCP servers -- both interactive (PKCE) flows for user-facing clients and machine-to-machine (M2M) `client_credentials` for backend services. + +See the **[MCP OAuth guide](./mcp_oauth.md)** for setup instructions, sequence diagrams, and a test server. + +
+Detailed OAuth reference (click to expand) LiteLLM v 1.77.6 added support for OAuth 2.0 Client Credentials for MCP servers. @@ -588,6 +595,8 @@ sequenceDiagram See the official [MCP Authorization Flow](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#authorization-flow-steps) for additional reference. +
+ ## Forwarding Custom Headers to MCP Servers @@ -1486,7 +1495,7 @@ async with stdio_client(server_params) as (read, write): **Q: How do I use OAuth2 client_credentials (machine-to-machine) with MCP servers behind LiteLLM?** -At the moment LiteLLM only forwards whatever `Authorization` header/value you configure for the MCP server; it does not issue OAuth2 tokens by itself. If your MCP requires the Client Credentials grant, obtain the access token directly from the authorization server and set that bearer token as the MCP server’s Authorization header value. LiteLLM does not yet fetch or refresh those machine-to-machine tokens on your behalf, but we plan to add first-class client_credentials support in a future release so the proxy can manage those tokens automatically. +LiteLLM supports automatic token management for the `client_credentials` grant. Configure `client_id`, `client_secret`, and `token_url` on your MCP server and LiteLLM will fetch, cache, and refresh tokens automatically. See the [MCP OAuth M2M guide](./mcp_oauth.md#machine-to-machine-m2m-auth) for setup instructions. **Q: When I fetch an OAuth token from the LiteLLM UI, where is it stored?** diff --git a/docs/my-website/docs/mcp_oauth.md b/docs/my-website/docs/mcp_oauth.md new file mode 100644 index 00000000000..ed69408196f --- /dev/null +++ b/docs/my-website/docs/mcp_oauth.md @@ -0,0 +1,192 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# MCP OAuth + +LiteLLM supports two OAuth 2.0 flows for MCP servers: + +| Flow | Use Case | How It Works | +|------|----------|--------------| +| **Interactive (PKCE)** | User-facing apps (Claude Code, Cursor) | Browser-based consent, per-user tokens | +| **Machine-to-Machine (M2M)** | Backend services, CI/CD, automated agents | `client_credentials` grant, proxy-managed tokens | + +## Interactive OAuth (PKCE) + +For user-facing MCP clients (Claude Code, Cursor), LiteLLM supports the full OAuth 2.0 authorization code flow with PKCE. + +### Setup + +```yaml title="config.yaml" showLineNumbers +mcp_servers: + github_mcp: + url: "https://api.githubcopilot.com/mcp" + auth_type: oauth2 + client_id: os.environ/GITHUB_OAUTH_CLIENT_ID + client_secret: os.environ/GITHUB_OAUTH_CLIENT_SECRET +``` + +[**See Claude Code Tutorial**](./tutorials/claude_responses_api#connecting-mcp-servers) + +### How It Works + +```mermaid +sequenceDiagram + participant Browser as User-Agent (Browser) + participant Client as Client + participant LiteLLM as LiteLLM Proxy + participant MCP as MCP Server (Resource Server) + participant Auth as Authorization Server + + Note over Client,LiteLLM: Step 1 – Resource discovery + Client->>LiteLLM: GET /.well-known/oauth-protected-resource/{mcp_server_name}/mcp + LiteLLM->>Client: Return resource metadata + + Note over Client,LiteLLM: Step 2 – Authorization server discovery + Client->>LiteLLM: GET /.well-known/oauth-authorization-server/{mcp_server_name} + LiteLLM->>Client: Return authorization server metadata + + Note over Client,Auth: Step 3 – Dynamic client registration + Client->>LiteLLM: POST /{mcp_server_name}/register + LiteLLM->>Auth: Forward registration request + Auth->>LiteLLM: Issue client credentials + LiteLLM->>Client: Return client credentials + + Note over Client,Browser: Step 4 – User authorization (PKCE) + Client->>Browser: Open authorization URL + code_challenge + resource + Browser->>Auth: Authorization request + Note over Auth: User authorizes + Auth->>Browser: Redirect with authorization code + Browser->>LiteLLM: Callback to LiteLLM with code + LiteLLM->>Browser: Redirect back with authorization code + Browser->>Client: Callback with authorization code + + Note over Client,Auth: Step 5 – Token exchange + Client->>LiteLLM: Token request + code_verifier + resource + LiteLLM->>Auth: Forward token request + Auth->>LiteLLM: Access (and refresh) token + LiteLLM->>Client: Return tokens + + Note over Client,MCP: Step 6 – Authenticated MCP call + Client->>LiteLLM: MCP request with access token + LiteLLM API key + LiteLLM->>MCP: MCP request with Bearer token + MCP-->>LiteLLM: MCP response + LiteLLM-->>Client: Return MCP response +``` + +**Participants** + +- **Client** -- The MCP-capable AI agent (e.g., Claude Code, Cursor, or another IDE/agent) that initiates OAuth discovery, authorization, and tool invocations on behalf of the user. +- **LiteLLM Proxy** -- Mediates all OAuth discovery, registration, token exchange, and MCP traffic while protecting stored credentials. +- **Authorization Server** -- Issues OAuth 2.0 tokens via dynamic client registration, PKCE authorization, and token endpoints. +- **MCP Server (Resource Server)** -- The protected MCP endpoint that receives LiteLLM's authenticated JSON-RPC requests. +- **User-Agent (Browser)** -- Temporarily involved so the end user can grant consent during the authorization step. + +**Flow Steps** + +1. **Resource Discovery**: The client fetches MCP resource metadata from LiteLLM's `.well-known/oauth-protected-resource` endpoint to understand scopes and capabilities. +2. **Authorization Server Discovery**: The client retrieves the OAuth server metadata (token endpoint, authorization endpoint, supported PKCE methods) through LiteLLM's `.well-known/oauth-authorization-server` endpoint. +3. **Dynamic Client Registration**: The client registers through LiteLLM, which forwards the request to the authorization server (RFC 7591). If the provider doesn't support dynamic registration, you can pre-store `client_id`/`client_secret` in LiteLLM (e.g., GitHub MCP) and the flow proceeds the same way. +4. **User Authorization**: The client launches a browser session (with code challenge and resource hints). The user approves access, the authorization server sends the code through LiteLLM back to the client. +5. **Token Exchange**: The client calls LiteLLM with the authorization code, code verifier, and resource. LiteLLM exchanges them with the authorization server and returns the issued access/refresh tokens. +6. **MCP Invocation**: With a valid token, the client sends the MCP JSON-RPC request (plus LiteLLM API key) to LiteLLM, which forwards it to the MCP server and relays the tool response. + +See the official [MCP Authorization Flow](https://modelcontextprotocol.io/specification/2025-06-18/basic/authorization#authorization-flow-steps) for additional reference. + +## Machine-to-Machine (M2M) Auth + +LiteLLM automatically fetches, caches, and refreshes OAuth2 tokens using the `client_credentials` grant. No manual token management required. + +### Setup + + + + +```yaml title="config.yaml" showLineNumbers +mcp_servers: + my_mcp_server: + url: "https://my-mcp-server.com/mcp" + auth_type: oauth2 + client_id: os.environ/MCP_CLIENT_ID + client_secret: os.environ/MCP_CLIENT_SECRET + token_url: "https://auth.example.com/oauth/token" + scopes: ["mcp:read", "mcp:write"] # optional +``` + + + + +Navigate to **MCP Servers → Add Server → Authentication → OAuth**, then fill in `client_id`, `client_secret`, and `token_url`. + + + + +### How It Works + +1. On first MCP request, LiteLLM POSTs to `token_url` with `grant_type=client_credentials` +2. The access token is cached in-memory with TTL = `expires_in - 60s` +3. Subsequent requests reuse the cached token +4. When the token expires, LiteLLM fetches a new one automatically + +```mermaid +sequenceDiagram + participant Client as Client + participant LiteLLM as LiteLLM Proxy + participant Auth as Authorization Server + participant MCP as MCP Server + + Client->>LiteLLM: MCP request + LiteLLM API key + LiteLLM->>Auth: POST /oauth/token (client_credentials) + Auth->>LiteLLM: access_token (expires_in: 3600) + LiteLLM->>MCP: MCP request + Bearer token + MCP-->>LiteLLM: MCP response + LiteLLM-->>Client: MCP response + + Note over LiteLLM: Token cached for subsequent requests + Client->>LiteLLM: Next MCP request + LiteLLM->>MCP: MCP request + cached Bearer token + MCP-->>LiteLLM: MCP response + LiteLLM-->>Client: MCP response +``` + +### Test with Mock Server + +Use [BerriAI/mock-oauth2-mcp-server](https://github.com/BerriAI/mock-oauth2-mcp-server) to test locally: + +```bash title="Terminal 1 - Start mock server" showLineNumbers +pip install fastapi uvicorn +python mock_oauth2_mcp_server.py # starts on :8765 +``` + +```yaml title="config.yaml" showLineNumbers +mcp_servers: + test_oauth2: + url: "http://localhost:8765/mcp" + auth_type: oauth2 + client_id: "test-client" + client_secret: "test-secret" + token_url: "http://localhost:8765/oauth/token" +``` + +```bash title="Terminal 2 - Start proxy and test" showLineNumbers +litellm --config config.yaml --port 4000 + +# List tools +curl http://localhost:4000/mcp-rest/tools/list \ + -H "Authorization: Bearer sk-1234" + +# Call a tool +curl http://localhost:4000/mcp-rest/tools/call \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer sk-1234" \ + -d '{"name": "echo", "arguments": {"message": "hello"}}' +``` + +### Config Reference + +| Field | Required | Description | +|-------|----------|-------------| +| `auth_type` | Yes | Must be `oauth2` | +| `client_id` | Yes | OAuth2 client ID. Supports `os.environ/VAR_NAME` | +| `client_secret` | Yes | OAuth2 client secret. Supports `os.environ/VAR_NAME` | +| `token_url` | Yes | Token endpoint URL | +| `scopes` | No | List of scopes to request | diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 0d95bf7e545..2c3dfb2b863 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -563,6 +563,7 @@ const sidebars = { items: [ "mcp", "mcp_usage", + "mcp_oauth", "mcp_public_internet", "mcp_semantic_filter", "mcp_control", diff --git a/litellm/constants.py b/litellm/constants.py index f5f2c7a49e4..9c25cf77906 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -83,6 +83,20 @@ os.getenv("MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH", 150) ) +# MCP OAuth2 Client Credentials Defaults +MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS = int( + os.getenv("MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS", "60") +) +MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE = int( + os.getenv("MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE", "200") +) +MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL = int( + os.getenv("MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL", "3600") +) +MCP_OAUTH2_TOKEN_CACHE_MIN_TTL = int( + os.getenv("MCP_OAUTH2_TOKEN_CACHE_MIN_TTL", "10") +) + LITELLM_UI_ALLOW_HEADERS = [ "x-litellm-semantic-filter", "x-litellm-semantic-filter-tools", diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index e7174f943b2..532aea249bf 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -36,6 +36,7 @@ from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( MCPRequestHandler, ) +from litellm.proxy._experimental.mcp_server.oauth2_token_cache import resolve_mcp_auth from litellm.proxy._experimental.mcp_server.utils import ( MCP_TOOL_PREFIX_SEPARATOR, add_server_prefix_to_name, @@ -833,7 +834,7 @@ def _build_stdio_env( return resolved_env - def _create_mcp_client( + async def _create_mcp_client( self, server: MCPServer, mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None, @@ -843,13 +844,22 @@ def _create_mcp_client( """ Create an MCPClient instance for the given server. + Auth resolution (single place for all auth logic): + 1. ``mcp_auth_header`` — per-request/per-user override + 2. OAuth2 client_credentials token — auto-fetched and cached + 3. ``server.authentication_token`` — static token from config/DB + Args: - server (MCPServer): The server configuration - mcp_auth_header: MCP auth header to be passed to the MCP server. This is optional and will be used if provided. + server: The server configuration. + mcp_auth_header: Optional per-request auth override. + extra_headers: Additional headers to forward. + stdio_env: Environment variables for stdio transport. Returns: - MCPClient: Configured MCP client instance + Configured MCP client instance. """ + auth_value = await resolve_mcp_auth(server, mcp_auth_header) + transport = server.transport or MCPTransport.sse # Handle stdio transport @@ -868,7 +878,7 @@ def _create_mcp_client( server_url="", # Not used for stdio transport_type=transport, auth_type=server.auth_type, - auth_value=mcp_auth_header or server.authentication_token, + auth_value=auth_value, timeout=60.0, stdio_config=stdio_config, extra_headers=extra_headers, @@ -880,7 +890,7 @@ def _create_mcp_client( server_url=server_url, transport_type=transport, auth_type=server.auth_type, - auth_value=mcp_auth_header or server.authentication_token, + auth_value=auth_value, timeout=60.0, extra_headers=extra_headers, ) @@ -920,7 +930,7 @@ async def _get_tools_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -980,7 +990,7 @@ async def get_prompts_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -1024,7 +1034,7 @@ async def get_resources_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -1068,7 +1078,7 @@ async def get_resource_templates_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -1109,7 +1119,7 @@ async def read_resource_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -1139,7 +1149,7 @@ async def get_prompt_from_server( stdio_env = self._build_stdio_env(server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=mcp_auth_header, extra_headers=extra_headers, @@ -1943,7 +1953,7 @@ async def _call_regular_mcp_tool( stdio_env = self._build_stdio_env(mcp_server, raw_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=mcp_server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, @@ -2119,8 +2129,8 @@ async def _initialize_tool_name_to_mcp_server_name_mapping(self): Note: This now handles prefixed tool names """ for server in self.get_registry().values(): - if server.auth_type == MCPAuth.oauth2: - # Skip OAuth2 servers for now as they may require user-specific tokens + if server.needs_user_oauth_token: + # Skip OAuth2 servers that rely on user-provided tokens continue tools = await self._get_tools_from_server(server) for tool in tools: @@ -2414,7 +2424,7 @@ async def health_check_server( should_skip_health_check = False # Skip if auth_type is oauth2 - if server.auth_type == MCPAuth.oauth2: + if server.needs_user_oauth_token: should_skip_health_check = True # Skip if auth_type is not none and authentication_token is missing elif ( @@ -2429,7 +2439,7 @@ async def health_check_server( if server.static_headers: extra_headers.update(server.static_headers) - client = self._create_mcp_client( + client = await self._create_mcp_client( server=server, mcp_auth_header=None, extra_headers=extra_headers, diff --git a/litellm/proxy/_experimental/mcp_server/oauth2_token_cache.py b/litellm/proxy/_experimental/mcp_server/oauth2_token_cache.py new file mode 100644 index 00000000000..0de381ee1df --- /dev/null +++ b/litellm/proxy/_experimental/mcp_server/oauth2_token_cache.py @@ -0,0 +1,163 @@ +""" +OAuth2 client_credentials token cache for MCP servers. + +Automatically fetches and refreshes access tokens for MCP servers configured +with ``client_id``, ``client_secret``, and ``token_url``. +""" + +import asyncio +from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union + +import httpx + +from litellm._logging import verbose_logger +from litellm.caching.in_memory_cache import InMemoryCache +from litellm.constants import ( + MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL, + MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE, + MCP_OAUTH2_TOKEN_CACHE_MIN_TTL, + MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS, +) +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client +from litellm.types.llms.custom_http import httpxSpecialProvider + +if TYPE_CHECKING: + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +class MCPOAuth2TokenCache(InMemoryCache): + """ + In-memory cache for OAuth2 client_credentials tokens, keyed by server_id. + + Inherits from ``InMemoryCache`` for TTL-based storage and eviction. + Adds per-server ``asyncio.Lock`` to prevent duplicate concurrent fetches. + """ + + def __init__(self) -> None: + super().__init__( + max_size_in_memory=MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE, + default_ttl=MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL, + ) + self._locks: Dict[str, asyncio.Lock] = {} + + def _get_lock(self, server_id: str) -> asyncio.Lock: + return self._locks.setdefault(server_id, asyncio.Lock()) + + async def async_get_token(self, server: "MCPServer") -> Optional[str]: + """Return a valid access token, fetching or refreshing as needed. + + Returns ``None`` when the server lacks client credentials config. + """ + if not server.has_client_credentials: + return None + + server_id = server.server_id + + # Fast path — cached token is still valid + cached = self.get_cache(server_id) + if cached is not None: + return cached + + # Slow path — acquire per-server lock then double-check + async with self._get_lock(server_id): + cached = self.get_cache(server_id) + if cached is not None: + return cached + + token, ttl = await self._fetch_token(server) + self.set_cache(server_id, token, ttl=ttl) + return token + + async def _fetch_token(self, server: "MCPServer") -> Tuple[str, int]: + """POST to ``token_url`` with ``grant_type=client_credentials``. + + Returns ``(access_token, ttl_seconds)`` where ttl accounts for the + expiry buffer so the cache entry expires before the real token does. + """ + client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP) + + if not server.client_id or not server.client_secret or not server.token_url: + raise ValueError( + f"MCP server '{server.server_id}' missing required OAuth2 fields: " + f"client_id={bool(server.client_id)}, " + f"client_secret={bool(server.client_secret)}, " + f"token_url={bool(server.token_url)}" + ) + + data: Dict[str, str] = { + "grant_type": "client_credentials", + "client_id": server.client_id, + "client_secret": server.client_secret, + } + if server.scopes: + data["scope"] = " ".join(server.scopes) + + verbose_logger.debug( + "Fetching OAuth2 client_credentials token for MCP server %s", + server.server_id, + ) + + try: + response = await client.post(server.token_url, data=data) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + raise ValueError( + f"OAuth2 token request for MCP server '{server.server_id}' " + f"failed with status {exc.response.status_code}" + ) from exc + + body = response.json() + + if not isinstance(body, dict): + raise ValueError( + f"OAuth2 token response for MCP server '{server.server_id}' " + f"returned non-object JSON (got {type(body).__name__})" + ) + + access_token = body.get("access_token") + if not access_token: + raise ValueError( + f"OAuth2 token response for MCP server '{server.server_id}' " + f"missing 'access_token'" + ) + + # Safely parse expires_in — providers may return null or non-numeric values + raw_expires_in = body.get("expires_in") + try: + expires_in = int(raw_expires_in) if raw_expires_in is not None else MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL + except (TypeError, ValueError): + expires_in = MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL + + ttl = max(expires_in - MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS, MCP_OAUTH2_TOKEN_CACHE_MIN_TTL) + + verbose_logger.info( + "Fetched OAuth2 token for MCP server %s (expires in %ds)", + server.server_id, + expires_in, + ) + return access_token, ttl + + def invalidate(self, server_id: str) -> None: + """Remove a cached token (e.g. after a 401).""" + self.delete_cache(server_id) + + +mcp_oauth2_token_cache = MCPOAuth2TokenCache() + + +async def resolve_mcp_auth( + server: "MCPServer", + mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None, +) -> Optional[Union[str, Dict[str, str]]]: + """Resolve the auth value for an MCP server. + + Priority: + 1. ``mcp_auth_header`` — per-request/per-user override + 2. OAuth2 client_credentials token — auto-fetched and cached + 3. ``server.authentication_token`` — static token from config/DB + """ + if mcp_auth_header: + return mcp_auth_header + if server.has_client_credentials: + return await mcp_oauth2_token_cache.async_get_token(server) + return server.authentication_token diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 2fe40bb197e..65d3a2caa16 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -541,7 +541,7 @@ async def _execute_with_mcp_client( static_headers=request.static_headers, ) - client = global_mcp_server_manager._create_mcp_client( + client = await global_mcp_server_manager._create_mcp_client( server=server_model, mcp_auth_header=mcp_auth_header, extra_headers=merged_headers, diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 4e6c1810255..2cd385c5bf6 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, ConfigDict from litellm.proxy._types import MCPAuthType, MCPTransportType +from litellm.types.mcp import MCPAuth # MCPInfo now allows arbitrary additional fields for custom metadata MCPInfo = Dict[str, Any] @@ -54,3 +55,13 @@ class MCPServer(BaseModel): available_on_public_internet: bool = False updated_at: Optional[datetime] = None model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def has_client_credentials(self) -> bool: + """True if this server has OAuth2 client_credentials config (client_id, client_secret, token_url).""" + return bool(self.client_id and self.client_secret and self.token_url) + + @property + def needs_user_oauth_token(self) -> bool: + """True if this is an OAuth2 server that relies on per-user tokens (no client_credentials).""" + return self.auth_type == MCPAuth.oauth2 and not self.has_client_credentials diff --git a/tests/mcp_tests/test_mcp_auth_priority.py b/tests/mcp_tests/test_mcp_auth_priority.py index bc217a7903b..ad6e9438edd 100644 --- a/tests/mcp_tests/test_mcp_auth_priority.py +++ b/tests/mcp_tests/test_mcp_auth_priority.py @@ -12,7 +12,8 @@ from litellm.types.mcp_server.mcp_server_manager import MCPServer -def test_mcp_server_works_without_config_auth_value(): +@pytest.mark.asyncio +async def test_mcp_server_works_without_config_auth_value(): """ Test that MCP servers work without auth_value in config when headers are provided. This validates that auth_value is truly optional in config.yaml. @@ -32,7 +33,7 @@ def test_mcp_server_works_without_config_auth_value(): manager = MCPServerManager() # Test that it works with only header auth - client = manager._create_mcp_client( + client = await manager._create_mcp_client( server=server_without_config_auth, mcp_auth_header="Bearer token_from_header_only", ) @@ -58,7 +59,7 @@ async def test_mcp_server_config_auth_value_header_used(token_key): await manager.load_servers_from_config(config) server = next(iter(manager.config_mcp_servers.values())) - client = manager._create_mcp_client(server) + client = await manager._create_mcp_client(server) headers = client._get_auth_headers() assert headers["Authorization"] == "Bearer example_token" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 9aa7de89eff..07c3dfcc763 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -886,7 +886,7 @@ async def test_oauth2_headers_passed_to_mcp_client(): # This will capture the arguments passed to _create_mcp_client captured_client_args = {} - def mock_create_mcp_client( + async def mock_create_mcp_client( server, mcp_auth_header=None, extra_headers=None, diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index f241b2aa0d5..abb8dd49159 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -1,5 +1,5 @@ -import json import importlib +import json import logging import os import sys @@ -21,8 +21,8 @@ Prompt, ResourceTemplate, TextResourceContents, - Tool as MCPTool, ) +from mcp.types import Tool as MCPTool from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, @@ -92,7 +92,7 @@ async def test_add_update_server_stdio(self): assert added_server.args == ["-m", "server"] assert added_server.env == {"DEBUG": "1", "TEST": "1"} - def test_create_mcp_client_stdio(self): + async def test_create_mcp_client_stdio(self): """Test creating MCP client for stdio transport""" manager = MCPServerManager() @@ -106,7 +106,7 @@ def test_create_mcp_client_stdio(self): env={"NODE_ENV": "test"}, ) - client = manager._create_mcp_client(stdio_server) + client = await manager._create_mcp_client(stdio_server) assert client.transport_type == MCPTransport.stdio assert client.stdio_config is not None @@ -404,14 +404,14 @@ async def test_call_regular_mcp_tool_case_insensitive_extra_headers(self): ) captured_extra_headers = None - def capture_create_mcp_client( + async def capture_create_mcp_client( server, mcp_auth_header, extra_headers, stdio_env ): # pragma: no cover - helper nonlocal captured_extra_headers captured_extra_headers = extra_headers return mock_client - manager._create_mcp_client = MagicMock(side_effect=capture_create_mcp_client) + manager._create_mcp_client = AsyncMock(side_effect=capture_create_mcp_client) result = await manager._call_regular_mcp_tool( mcp_server=server, @@ -446,7 +446,7 @@ async def test_get_prompts_from_server_success(self): mock_client = AsyncMock() mock_client.list_prompts = AsyncMock(return_value=[mock_prompt]) - with patch.object(manager, "_create_mcp_client", return_value=mock_client): + with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client): prompts = await manager.get_prompts_from_server(server, add_prefix=True) mock_client.list_prompts.assert_awaited_once() @@ -474,7 +474,7 @@ async def test_get_prompt_from_server_success(self): mock_client = AsyncMock() mock_client.get_prompt = AsyncMock(return_value=mock_result) - with patch.object(manager, "_create_mcp_client", return_value=mock_client): + with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client): result = await manager.get_prompt_from_server( server=server, prompt_name="hello", @@ -507,7 +507,7 @@ async def test_get_resources_from_server_success(self): mock_client.list_resources = AsyncMock(return_value=mock_resources) prefixed_resources = [Resource(name="alias-server-file", uri="https://example.com/file")] - with patch.object(manager, "_create_mcp_client", return_value=mock_client) as mock_create_client, patch.object( + with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object( manager, "_create_prefixed_resources", return_value=prefixed_resources, @@ -556,7 +556,7 @@ async def test_get_resource_templates_from_server_success(self): ) ] - with patch.object(manager, "_create_mcp_client", return_value=mock_client) as mock_create_client, patch.object( + with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client, patch.object( manager, "_create_prefixed_resource_templates", return_value=prefixed_templates, @@ -604,7 +604,7 @@ async def test_read_resource_from_server_success(self): ) mock_client.read_resource = AsyncMock(return_value=read_result) - with patch.object(manager, "_create_mcp_client", return_value=mock_client) as mock_create_client: + with patch.object(manager, "_create_mcp_client", new_callable=AsyncMock, return_value=mock_client) as mock_create_client: result = await manager.read_resource_from_server( server=server, url="https://example.com/resource", @@ -816,7 +816,7 @@ async def test_health_check_server_healthy(self): # Mock successful client.run_with_session mock_client = AsyncMock() mock_client.run_with_session = AsyncMock(return_value="ok") - manager._create_mcp_client = MagicMock(return_value=mock_client) + manager._create_mcp_client = AsyncMock(return_value=mock_client) # Perform health check result = await manager.health_check_server("test-server") @@ -850,7 +850,7 @@ async def test_health_check_server_unhealthy(self): mock_client.run_with_session = AsyncMock( side_effect=Exception("Connection timeout") ) - manager._create_mcp_client = MagicMock(return_value=mock_client) + manager._create_mcp_client = AsyncMock(return_value=mock_client) # Perform health check result = await manager.health_check_server("test-server") @@ -898,7 +898,7 @@ async def test_health_check_server_oauth2_skips_check(self): manager.get_mcp_server_by_id = MagicMock(return_value=server) # _create_mcp_client should not be called for OAuth2 servers - manager._create_mcp_client = MagicMock() + manager._create_mcp_client = AsyncMock() # Perform health check result = await manager.health_check_server("oauth2-server") @@ -931,7 +931,7 @@ async def test_health_check_server_no_token_skips_check(self): manager.get_mcp_server_by_id = MagicMock(return_value=server) # _create_mcp_client should not be called - manager._create_mcp_client = MagicMock() + manager._create_mcp_client = AsyncMock() # Perform health check result = await manager.health_check_server("no-token-server") @@ -971,12 +971,12 @@ async def test_health_check_server_with_static_headers(self): # Capture the extra_headers passed to _create_mcp_client captured_extra_headers = None - def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env): + async def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env): nonlocal captured_extra_headers captured_extra_headers = extra_headers return mock_client - manager._create_mcp_client = MagicMock(side_effect=capture_create_mcp_client) + manager._create_mcp_client = AsyncMock(side_effect=capture_create_mcp_client) # Perform health check result = await manager.health_check_server("test-server") @@ -1310,7 +1310,7 @@ async def test_get_tools_from_server_add_prefix(self): ) # Mock client creation and fetching tools - manager._create_mcp_client = MagicMock(return_value=object()) + manager._create_mcp_client = AsyncMock(return_value=object()) # Tools returned upstream (unprefixed from provider) upstream_tool = MCPTool( @@ -1895,7 +1895,7 @@ async def mock_call_tool(params, host_progress_callback=None): mock_client.call_tool.side_effect = mock_call_tool # Mock _create_mcp_client to return our mock client - manager._create_mcp_client = MagicMock(return_value=mock_client) + manager._create_mcp_client = AsyncMock(return_value=mock_client) # Mock user auth with no restrictions user_api_key_auth = MagicMock() diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_oauth2_token_cache.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_oauth2_token_cache.py new file mode 100644 index 00000000000..55735dca98e --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_oauth2_token_cache.py @@ -0,0 +1,157 @@ +""" +Core tests for MCP OAuth2 machine-to-machine (client_credentials) token management. + +Covers the critical path: resolve_mcp_auth(), token caching, auth priority, +fallback to static token, and the skip-condition property. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from litellm.proxy._experimental.mcp_server.oauth2_token_cache import ( + MCPOAuth2TokenCache, + resolve_mcp_auth, +) +from litellm.proxy._types import MCPTransport +from litellm.types.mcp import MCPAuth +from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +def _server(**overrides) -> MCPServer: + defaults = dict( + server_id="srv-1", + name="test", + url="https://mcp.example.com/mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="cid", + client_secret="csec", + token_url="https://auth.example.com/token", + ) + defaults.update(overrides) + return MCPServer(**defaults) + + +def _token_response(token="tok-abc", expires_in=3600): + resp = MagicMock() + resp.json.return_value = { + "access_token": token, + "token_type": "bearer", + "expires_in": expires_in, + } + resp.raise_for_status = MagicMock() + return resp + + +@pytest.mark.asyncio +async def test_resolve_mcp_auth_fetches_oauth2_token(): + """resolve_mcp_auth fetches a token via client_credentials when the server has OAuth2 config.""" + server = _server() + mock_client = AsyncMock() + mock_client.post.return_value = _token_response("m2m-token-1") + + with patch( + "litellm.proxy._experimental.mcp_server.oauth2_token_cache.get_async_httpx_client", + return_value=mock_client, + ): + result = await resolve_mcp_auth(server) + + assert result == "m2m-token-1" + mock_client.post.assert_called_once() + post_data = mock_client.post.call_args[1]["data"] + assert post_data["grant_type"] == "client_credentials" + assert post_data["client_id"] == "cid" + assert post_data["client_secret"] == "csec" + + +@pytest.mark.asyncio +async def test_token_cached_across_calls(): + """Second resolve_mcp_auth call reuses the cached token — only 1 HTTP POST.""" + cache = MCPOAuth2TokenCache() + server = _server() + mock_client = AsyncMock() + mock_client.post.return_value = _token_response("cached-tok") + + with patch( + "litellm.proxy._experimental.mcp_server.oauth2_token_cache.get_async_httpx_client", + return_value=mock_client, + ), patch( + "litellm.proxy._experimental.mcp_server.oauth2_token_cache.mcp_oauth2_token_cache", + cache, + ): + t1 = await resolve_mcp_auth(server) + t2 = await resolve_mcp_auth(server) + + assert t1 == t2 == "cached-tok" + assert mock_client.post.call_count == 1 + + +@pytest.mark.asyncio +async def test_per_request_header_beats_oauth2(): + """An explicit mcp_auth_header takes priority over the OAuth2 token.""" + server = _server() + result = await resolve_mcp_auth(server, mcp_auth_header="Bearer user-tok") + assert result == "Bearer user-tok" + + +@pytest.mark.asyncio +async def test_falls_back_to_static_token(): + """When no client_credentials config, resolve_mcp_auth returns the static authentication_token.""" + server = _server( + client_id=None, + client_secret=None, + token_url=None, + authentication_token="static-tok-xyz", + ) + result = await resolve_mcp_auth(server) + assert result == "static-tok-xyz" + + +def test_needs_user_oauth_token_property(): + """needs_user_oauth_token is True only for OAuth2 servers WITHOUT client_credentials.""" + # OAuth2 with credentials → M2M, no user token needed + assert _server().needs_user_oauth_token is False + + # OAuth2 without credentials → needs per-user token + assert _server(client_id=None, client_secret=None, token_url=None).needs_user_oauth_token is True + + # Non-OAuth2 → never needs user OAuth token + assert _server(auth_type=MCPAuth.bearer_token).needs_user_oauth_token is False + + +@pytest.mark.asyncio +async def test_http_error_raises_value_error(): + """HTTP errors from the token endpoint are wrapped in a clear ValueError.""" + server = _server() + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Unauthorized", request=MagicMock(), response=mock_response, + ) + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch( + "litellm.proxy._experimental.mcp_server.oauth2_token_cache.get_async_httpx_client", + return_value=mock_client, + ), pytest.raises(ValueError, match="failed with status 401"): + await resolve_mcp_auth(server) + + +@pytest.mark.asyncio +async def test_non_dict_response_raises_value_error(): + """A non-dict JSON response raises a clear ValueError.""" + server = _server() + resp = MagicMock() + resp.json.return_value = ["not", "a", "dict"] + resp.raise_for_status = MagicMock() + mock_client = AsyncMock() + mock_client.post.return_value = resp + + with patch( + "litellm.proxy._experimental.mcp_server.oauth2_token_cache.get_async_httpx_client", + return_value=mock_client, + ), pytest.raises(ValueError, match="non-object JSON"): + await resolve_mcp_auth(server) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py index 3091f03a53d..f6c8115f7bc 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py @@ -76,7 +76,7 @@ def _route_has_dependency(route, dependency) -> bool: class TestExecuteWithMcpClient: @pytest.mark.asyncio async def test_redacts_stack_trace(self, monkeypatch): - def fake_create_client(*args, **kwargs): + async def fake_create_client(*args, **kwargs): return object() monkeypatch.setattr( @@ -114,7 +114,7 @@ async def test_forwards_static_headers(self, monkeypatch): def fake_build_stdio_env(server, raw_headers): return None - def fake_create_client(*args, **kwargs): + async def fake_create_client(*args, **kwargs): captured["extra_headers"] = kwargs.get("extra_headers") return object()