diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index 3e8f9bc337b..5e21ff9754f 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -209,6 +209,8 @@ def _get_auth_headers(self) -> dict: headers["X-API-Key"] = self._mcp_auth_value elif self.auth_type == MCPAuth.authorization: headers["Authorization"] = self._mcp_auth_value + elif self.auth_type == MCPAuth.oauth2: + headers["Authorization"] = f"Bearer {self._mcp_auth_value}" elif isinstance(self._mcp_auth_value, dict): headers.update(self._mcp_auth_value) diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index 7e70b5baae4..786cfbfb008 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -1,11 +1,12 @@ from typing import Dict, List, Optional, Set, Tuple +from fastapi import HTTPException from starlette.datastructures import Headers from starlette.requests import Request from starlette.types import Scope from litellm._logging import verbose_logger -from litellm.proxy._types import LiteLLM_TeamTable, SpecialHeaders, UserAPIKeyAuth +from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, SpecialHeaders, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -63,6 +64,13 @@ async def process_mcp_request( HTTPException: If headers are invalid or missing required headers """ headers = MCPRequestHandler._safe_get_headers_from_scope(scope) + + # Check if there is an explicit LiteLLM API key (primary header) + has_explicit_litellm_key = ( + headers.get(MCPRequestHandler.LITELLM_API_KEY_HEADER_NAME_PRIMARY) + is not None + ) + litellm_api_key = ( MCPRequestHandler.get_litellm_api_key_from_headers(headers) or "" ) @@ -106,16 +114,38 @@ async def mock_body(): request.body = mock_body # type: ignore if ".well-known" in str(request.url): # public routes validated_user_api_key_auth = UserAPIKeyAuth() - # elif litellm_api_key == "": - # from fastapi import HTTPException - - # raise HTTPException( - # status_code=401, - # detail="LiteLLM API key is missing. Please add it or use OAuth authentication.", - # headers={ - # "WWW-Authenticate": f'Bearer resource_metadata=f"{request.base_url}/.well-known/oauth-protected-resource"', - # }, - # ) + elif has_explicit_litellm_key: + # Explicit x-litellm-api-key provided - always validate normally + validated_user_api_key_auth = await user_api_key_auth( + api_key=litellm_api_key, request=request + ) + elif oauth2_headers: + # No x-litellm-api-key, but Authorization header present. + # Could be a LiteLLM key (backward compat) OR an OAuth2 token + # from an upstream MCP provider (e.g. Atlassian). + # Try LiteLLM auth first; on auth failure, treat as OAuth2 passthrough. + try: + validated_user_api_key_auth = await user_api_key_auth( + api_key=litellm_api_key, request=request + ) + except HTTPException as e: + if e.status_code in (401, 403): + verbose_logger.debug( + "MCP OAuth2: Authorization header is not a valid LiteLLM key, " + "treating as OAuth2 token passthrough" + ) + validated_user_api_key_auth = UserAPIKeyAuth() + else: + raise + except ProxyException as e: + if str(e.code) in ("401", "403"): + verbose_logger.debug( + "MCP OAuth2: Authorization header is not a valid LiteLLM key, " + "treating as OAuth2 token passthrough" + ) + validated_user_api_key_auth = UserAPIKeyAuth() + else: + raise else: validated_user_api_key_auth = await user_api_key_auth( api_key=litellm_api_key, request=request diff --git a/tests/mcp_tests/test_mcp_client_unit.py b/tests/mcp_tests/test_mcp_client_unit.py index 9533e56bcc8..c70d0c42cd8 100644 --- a/tests/mcp_tests/test_mcp_client_unit.py +++ b/tests/mcp_tests/test_mcp_client_unit.py @@ -77,6 +77,27 @@ def test_get_auth_headers(self): "Authorization": "Token custom_token", } + # OAuth2 + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.oauth2, + auth_value="oauth2-access-token-xyz", + ) + headers = client._get_auth_headers() + assert headers == { + "Authorization": "Bearer oauth2-access-token-xyz", + } + + # OAuth2 with extra_headers (per-user flow overrides auth_value) + client = MCPClient( + "http://example.com", + auth_type=MCPAuth.oauth2, + auth_value="static-server-token", + extra_headers={"Authorization": "Bearer per-user-token"}, + ) + headers = client._get_auth_headers() + assert headers["Authorization"] == "Bearer per-user-token" + # No auth client = MCPClient("http://example.com") headers = client._get_auth_headers() diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index e1e4b3a8b6d..68afe784988 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -544,6 +544,235 @@ async def test_header_extraction(self, headers, expected_result): assert mcp_server_auth_headers == {} +@pytest.mark.asyncio +class TestMCPOAuth2AuthFlow: + """Test suite for OAuth2 authentication flow in MCP requests. + + Tests the fix for the 'Capabilities: none' bug where OAuth2 tokens + from upstream MCP providers (e.g., Atlassian) were mistakenly validated + as LiteLLM API keys, causing auth failures and empty tool listings. + """ + + async def test_oauth2_token_in_authorization_header_fallback(self): + """ + When only Authorization header is present with a non-LiteLLM OAuth2 token, + auth should fall back to permissive mode (OAuth2 passthrough). + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/atlassian_mcp", + "headers": [ + (b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"), + ], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ): + ( + auth_result, + mcp_auth_header, + mcp_servers, + mcp_server_auth_headers, + oauth2_headers, + raw_headers, + ) = await MCPRequestHandler.process_mcp_request(scope) + + # Should succeed with default UserAPIKeyAuth (OAuth2 fallback) + assert auth_result is not None + assert isinstance(auth_result, UserAPIKeyAuth) + # OAuth2 headers should contain the token for upstream forwarding + assert ( + oauth2_headers.get("Authorization") + == "Bearer atlassian-oauth2-access-token-xyz" + ) + + async def test_explicit_litellm_key_with_oauth2_authorization(self): + """ + When both x-litellm-api-key AND Authorization header are present, + LiteLLM key should be used for auth and Authorization preserved for OAuth2. + """ + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/atlassian_mcp", + "headers": [ + (b"x-litellm-api-key", b"sk-litellm-valid-key"), + (b"authorization", b"Bearer atlassian-oauth2-token"), + ], + } + + async def mock_user_api_key_auth(api_key, request): + return UserAPIKeyAuth(api_key=api_key, user_id="test-user") + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth, + ) as mock_auth: + ( + auth_result, + mcp_auth_header, + mcp_servers, + mcp_server_auth_headers, + oauth2_headers, + raw_headers, + ) = await MCPRequestHandler.process_mcp_request(scope) + + # LiteLLM key should be used for auth + mock_auth.assert_called_once() + call_args = mock_auth.call_args + assert call_args.kwargs["api_key"] == "sk-litellm-valid-key" + + # OAuth2 headers should still contain the Authorization token + assert ( + oauth2_headers.get("Authorization") + == "Bearer atlassian-oauth2-token" + ) + + async def test_litellm_key_in_authorization_backward_compat(self): + """ + Backward compatibility: when only Authorization header is present + with a valid LiteLLM key (not OAuth2), auth should succeed normally. + """ + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/some_server", + "headers": [ + (b"authorization", b"Bearer sk-litellm-valid-key"), + ], + } + + async def mock_user_api_key_auth(api_key, request): + return UserAPIKeyAuth(api_key=api_key, user_id="test-user") + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth, + ) as mock_auth: + ( + auth_result, + _, + _, + _, + _, + _, + ) = await MCPRequestHandler.process_mcp_request(scope) + + # Should succeed with the LiteLLM key from Authorization header + assert auth_result.api_key == "Bearer sk-litellm-valid-key" + mock_auth.assert_called_once() + + async def test_non_auth_http_exception_still_raises(self): + """ + If user_api_key_auth raises a non-401/403 HTTPException (e.g., 500), + it should NOT be caught by the OAuth2 fallback. + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/some_server", + "headers": [ + (b"authorization", b"Bearer some-token"), + ], + } + + async def mock_user_api_key_auth_server_error(api_key, request): + raise HTTPException(status_code=500, detail="Internal server error") + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_server_error, + ): + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 500 + + async def test_proxy_exception_oauth2_fallback(self): + """ + user_api_key_auth raises ProxyException (not HTTPException) in production. + The OAuth2 fallback must catch ProxyException with code 401/403 too. + """ + from litellm.proxy._types import ProxyException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/atlassian_mcp", + "headers": [ + (b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"), + ], + } + + async def mock_user_api_key_auth_proxy_exception(api_key, request): + raise ProxyException( + message="Authentication Error: Invalid API key", + type="auth_error", + param="api_key", + code=401, + ) + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_proxy_exception, + ): + ( + auth_result, + mcp_auth_header, + mcp_servers, + mcp_server_auth_headers, + oauth2_headers, + raw_headers, + ) = await MCPRequestHandler.process_mcp_request(scope) + + # Should succeed with default UserAPIKeyAuth (OAuth2 fallback) + assert auth_result is not None + assert isinstance(auth_result, UserAPIKeyAuth) + assert ( + oauth2_headers.get("Authorization") + == "Bearer atlassian-oauth2-access-token-xyz" + ) + + async def test_proxy_exception_non_auth_still_raises(self): + """ + ProxyException with non-401/403 code should NOT be caught. + """ + from litellm.proxy._types import ProxyException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/some_server", + "headers": [ + (b"authorization", b"Bearer some-token"), + ], + } + + async def mock_user_api_key_auth_500(api_key, request): + raise ProxyException( + message="Internal error", + type="server_error", + param=None, + code=500, + ) + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_500, + ): + with pytest.raises(ProxyException): + await MCPRequestHandler.process_mcp_request(scope) + + class TestMCPCustomHeaderName: """Test suite for custom MCP authentication header name functionality"""