Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litellm/experimental_mcp_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ def _get_auth_headers(self) -> dict:
headers["X-API-Key"] = self._mcp_auth_value
elif self.auth_type == MCPAuth.authorization:
headers["Authorization"] = self._mcp_auth_value
elif self.auth_type == MCPAuth.oauth2:
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
elif isinstance(self._mcp_auth_value, dict):
headers.update(self._mcp_auth_value)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Dict, List, Optional, Set, Tuple

from fastapi import HTTPException
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.types import Scope

from litellm._logging import verbose_logger
from litellm.proxy._types import LiteLLM_TeamTable, SpecialHeaders, UserAPIKeyAuth
from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, SpecialHeaders, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth


Expand Down Expand Up @@ -63,6 +64,13 @@ async def process_mcp_request(
HTTPException: If headers are invalid or missing required headers
"""
headers = MCPRequestHandler._safe_get_headers_from_scope(scope)

# Check if there is an explicit LiteLLM API key (primary header)
has_explicit_litellm_key = (
headers.get(MCPRequestHandler.LITELLM_API_KEY_HEADER_NAME_PRIMARY)
is not None
)

litellm_api_key = (
MCPRequestHandler.get_litellm_api_key_from_headers(headers) or ""
)
Expand Down Expand Up @@ -106,16 +114,38 @@ async def mock_body():
request.body = mock_body # type: ignore
if ".well-known" in str(request.url): # public routes
validated_user_api_key_auth = UserAPIKeyAuth()
# elif litellm_api_key == "":
# from fastapi import HTTPException

# raise HTTPException(
# status_code=401,
# detail="LiteLLM API key is missing. Please add it or use OAuth authentication.",
# headers={
# "WWW-Authenticate": f'Bearer resource_metadata=f"{request.base_url}/.well-known/oauth-protected-resource"',
# },
# )
elif has_explicit_litellm_key:
# Explicit x-litellm-api-key provided - always validate normally
validated_user_api_key_auth = await user_api_key_auth(
api_key=litellm_api_key, request=request
)
elif oauth2_headers:
# No x-litellm-api-key, but Authorization header present.
# Could be a LiteLLM key (backward compat) OR an OAuth2 token
# from an upstream MCP provider (e.g. Atlassian).
# Try LiteLLM auth first; on auth failure, treat as OAuth2 passthrough.
try:
validated_user_api_key_auth = await user_api_key_auth(
api_key=litellm_api_key, request=request
)
except HTTPException as e:
if e.status_code in (401, 403):
verbose_logger.debug(
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
"treating as OAuth2 token passthrough"
)
validated_user_api_key_auth = UserAPIKeyAuth()
else:
raise
except ProxyException as e:
if str(e.code) in ("401", "403"):
verbose_logger.debug(
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
"treating as OAuth2 token passthrough"
)
validated_user_api_key_auth = UserAPIKeyAuth()
else:
raise
else:
validated_user_api_key_auth = await user_api_key_auth(
api_key=litellm_api_key, request=request
Expand Down
21 changes: 21 additions & 0 deletions tests/mcp_tests/test_mcp_client_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,27 @@ def test_get_auth_headers(self):
"Authorization": "Token custom_token",
}

# OAuth2
client = MCPClient(
"http://example.com",
auth_type=MCPAuth.oauth2,
auth_value="oauth2-access-token-xyz",
)
headers = client._get_auth_headers()
assert headers == {
"Authorization": "Bearer oauth2-access-token-xyz",
}

# OAuth2 with extra_headers (per-user flow overrides auth_value)
client = MCPClient(
"http://example.com",
auth_type=MCPAuth.oauth2,
auth_value="static-server-token",
extra_headers={"Authorization": "Bearer per-user-token"},
)
headers = client._get_auth_headers()
assert headers["Authorization"] == "Bearer per-user-token"

# No auth
client = MCPClient("http://example.com")
headers = client._get_auth_headers()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,235 @@ async def test_header_extraction(self, headers, expected_result):
assert mcp_server_auth_headers == {}


@pytest.mark.asyncio
class TestMCPOAuth2AuthFlow:
"""Test suite for OAuth2 authentication flow in MCP requests.

Tests the fix for the 'Capabilities: none' bug where OAuth2 tokens
from upstream MCP providers (e.g., Atlassian) were mistakenly validated
as LiteLLM API keys, causing auth failures and empty tool listings.
"""

async def test_oauth2_token_in_authorization_header_fallback(self):
"""
When only Authorization header is present with a non-LiteLLM OAuth2 token,
auth should fall back to permissive mode (OAuth2 passthrough).
"""
from fastapi import HTTPException

scope = {
"type": "http",
"method": "POST",
"path": "/mcp/atlassian_mcp",
"headers": [
(b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"),
],
}

async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
):
(
auth_result,
mcp_auth_header,
mcp_servers,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = await MCPRequestHandler.process_mcp_request(scope)

# Should succeed with default UserAPIKeyAuth (OAuth2 fallback)
assert auth_result is not None
assert isinstance(auth_result, UserAPIKeyAuth)
# OAuth2 headers should contain the token for upstream forwarding
assert (
oauth2_headers.get("Authorization")
== "Bearer atlassian-oauth2-access-token-xyz"
)

async def test_explicit_litellm_key_with_oauth2_authorization(self):
"""
When both x-litellm-api-key AND Authorization header are present,
LiteLLM key should be used for auth and Authorization preserved for OAuth2.
"""
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/atlassian_mcp",
"headers": [
(b"x-litellm-api-key", b"sk-litellm-valid-key"),
(b"authorization", b"Bearer atlassian-oauth2-token"),
],
}

async def mock_user_api_key_auth(api_key, request):
return UserAPIKeyAuth(api_key=api_key, user_id="test-user")

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth,
) as mock_auth:
(
auth_result,
mcp_auth_header,
mcp_servers,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = await MCPRequestHandler.process_mcp_request(scope)

# LiteLLM key should be used for auth
mock_auth.assert_called_once()
call_args = mock_auth.call_args
assert call_args.kwargs["api_key"] == "sk-litellm-valid-key"

# OAuth2 headers should still contain the Authorization token
assert (
oauth2_headers.get("Authorization")
== "Bearer atlassian-oauth2-token"
)

async def test_litellm_key_in_authorization_backward_compat(self):
"""
Backward compatibility: when only Authorization header is present
with a valid LiteLLM key (not OAuth2), auth should succeed normally.
"""
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/some_server",
"headers": [
(b"authorization", b"Bearer sk-litellm-valid-key"),
],
}

async def mock_user_api_key_auth(api_key, request):
return UserAPIKeyAuth(api_key=api_key, user_id="test-user")

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth,
) as mock_auth:
(
auth_result,
_,
_,
_,
_,
_,
) = await MCPRequestHandler.process_mcp_request(scope)

# Should succeed with the LiteLLM key from Authorization header
assert auth_result.api_key == "Bearer sk-litellm-valid-key"
mock_auth.assert_called_once()

async def test_non_auth_http_exception_still_raises(self):
"""
If user_api_key_auth raises a non-401/403 HTTPException (e.g., 500),
it should NOT be caught by the OAuth2 fallback.
"""
from fastapi import HTTPException

scope = {
"type": "http",
"method": "POST",
"path": "/mcp/some_server",
"headers": [
(b"authorization", b"Bearer some-token"),
],
}

async def mock_user_api_key_auth_server_error(api_key, request):
raise HTTPException(status_code=500, detail="Internal server error")

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_server_error,
):
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 500

async def test_proxy_exception_oauth2_fallback(self):
"""
user_api_key_auth raises ProxyException (not HTTPException) in production.
The OAuth2 fallback must catch ProxyException with code 401/403 too.
"""
from litellm.proxy._types import ProxyException

scope = {
"type": "http",
"method": "POST",
"path": "/mcp/atlassian_mcp",
"headers": [
(b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"),
],
}

async def mock_user_api_key_auth_proxy_exception(api_key, request):
raise ProxyException(
message="Authentication Error: Invalid API key",
type="auth_error",
param="api_key",
code=401,
)

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_proxy_exception,
):
(
auth_result,
mcp_auth_header,
mcp_servers,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = await MCPRequestHandler.process_mcp_request(scope)

# Should succeed with default UserAPIKeyAuth (OAuth2 fallback)
assert auth_result is not None
assert isinstance(auth_result, UserAPIKeyAuth)
assert (
oauth2_headers.get("Authorization")
== "Bearer atlassian-oauth2-access-token-xyz"
)

async def test_proxy_exception_non_auth_still_raises(self):
"""
ProxyException with non-401/403 code should NOT be caught.
"""
from litellm.proxy._types import ProxyException

scope = {
"type": "http",
"method": "POST",
"path": "/mcp/some_server",
"headers": [
(b"authorization", b"Bearer some-token"),
],
}

async def mock_user_api_key_auth_500(api_key, request):
raise ProxyException(
message="Internal error",
type="server_error",
param=None,
code=500,
)

with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_500,
):
with pytest.raises(ProxyException):
await MCPRequestHandler.process_mcp_request(scope)


class TestMCPCustomHeaderName:
"""Test suite for custom MCP authentication header name functionality"""

Expand Down
Loading