diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 43fe54fdfb7..75f0642b950 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -1908,7 +1908,14 @@ async def pre_call_tool_check( user_api_key_auth: Optional[UserAPIKeyAuth], proxy_logging_obj: ProxyLogging, server: MCPServer, - ): + ) -> Dict[str, Any]: + """ + Run pre-call checks and guardrail hooks for an MCP tool call. + + Returns a dict that may contain: + - "arguments": hook-modified tool arguments (only if changed) + - "extra_headers": headers injected by pre_mcp_call guardrail hooks + """ ## check if the tool is allowed or banned for the given server if not self.check_allowed_or_banned_tools(name, server): raise HTTPException( @@ -1969,6 +1976,7 @@ async def pre_call_tool_check( mcp_request_obj, pre_hook_kwargs ) + hook_result: Dict[str, Any] = {} try: # Use standard pre_call_hook modified_data = await proxy_logging_obj.pre_call_hook( @@ -1984,7 +1992,9 @@ async def pre_call_tool_check( ) ) if modified_kwargs.get("arguments") != arguments: - arguments = modified_kwargs["arguments"] + hook_result["arguments"] = modified_kwargs["arguments"] + if modified_kwargs.get("extra_headers"): + hook_result["extra_headers"] = modified_kwargs["extra_headers"] except ( BlockedPiiEntityError, @@ -1995,6 +2005,8 @@ async def pre_call_tool_check( verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}") raise e + return hook_result + def _create_during_hook_task( self, name: str, @@ -2047,6 +2059,7 @@ async def _call_regular_mcp_tool( raw_headers: Optional[Dict[str, str]], proxy_logging_obj: Optional[ProxyLogging], host_progress_callback: Optional[Callable] = None, + hook_extra_headers: Optional[Dict[str, str]] = None, ) -> CallToolResult: """ Call a regular MCP tool using the MCP client. @@ -2061,6 +2074,9 @@ async def _call_regular_mcp_tool( oauth2_headers: Optional OAuth2 headers raw_headers: Optional raw headers from the request proxy_logging_obj: Optional ProxyLogging object for hook integration + host_progress_callback: Optional callback for progress updates + hook_extra_headers: Optional headers injected by pre_mcp_call guardrail + hooks. Merged last (highest priority) into outbound request headers. Returns: CallToolResult from the MCP server @@ -2116,6 +2132,11 @@ async def _call_regular_mcp_tool( extra_headers = {} extra_headers.update(mcp_server.static_headers) + if hook_extra_headers: + if extra_headers is None: + extra_headers = {} + extra_headers.update(hook_extra_headers) + stdio_env = self._build_stdio_env(mcp_server, raw_headers) client = await self._create_mcp_client( @@ -2201,8 +2222,9 @@ async def call_tool( # Allow validation and modification of tool calls before execution # Using standard pre_call_hook ######################################################### + hook_result: Dict[str, Any] = {} if proxy_logging_obj: - await self.pre_call_tool_check( + hook_result = await self.pre_call_tool_check( name=name, arguments=arguments, server_name=server_name, @@ -2210,6 +2232,23 @@ async def call_tool( proxy_logging_obj=proxy_logging_obj, server=mcp_server, ) + if "arguments" in hook_result: + arguments = hook_result["arguments"] + + # OpenAPI-backed servers cannot forward hook-injected headers — reject early + # before scheduling any background tasks to avoid orphaned asyncio.Tasks. + if mcp_server.spec_path and hook_result.get("extra_headers"): + raise HTTPException( + status_code=400, + detail={ + "error": ( + "pre_mcp_call hook returned extra_headers for an " + "OpenAPI-backed MCP server, which does not support " + "hook header injection. Use a regular MCP server " + "(SSE/HTTP transport) for hook header support." + ) + }, + ) # Prepare tasks for during hooks tasks = [] @@ -2247,6 +2286,7 @@ async def call_tool( raw_headers=raw_headers, proxy_logging_obj=proxy_logging_obj, host_progress_callback=host_progress_callback, + hook_extra_headers=hook_result.get("extra_headers"), ) # For OpenAPI tools, await outside the client context diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index ecbd7314cd7..55cb1c9c43c 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2471,6 +2471,7 @@ class UserAPIKeyAuth( Any ] = None # Expanded created_by user when expand=user is used end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None + jwt_claims: Optional[Dict] = None model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 376048e7a13..451ed56339d 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -700,6 +700,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ) if valid_token is not None: api_key = valid_token.token or "" + valid_token.jwt_claims = jwt_claims do_standard_jwt_auth = False # Fall through to virtual key checks @@ -729,6 +730,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_membership: Optional[LiteLLM_TeamMembership] = result.get( "team_membership", None ) + jwt_claims: Optional[dict] = result.get("jwt_claims", None) global_proxy_spend = await get_global_proxy_spend( litellm_proxy_admin_name=litellm_proxy_admin_name, @@ -757,6 +759,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 org_id=org_id, end_user_id=end_user_id, parent_otel_span=parent_otel_span, + jwt_claims=jwt_claims, ) valid_token = UserAPIKeyAuth( @@ -803,6 +806,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 team_metadata=( team_object.metadata if team_object is not None else None ), + jwt_claims=jwt_claims, ) # Check if model has zero cost - if so, skip all budget checks diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 01a0f55aac7..8536704766b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -824,17 +824,22 @@ def _convert_mcp_hook_response_to_kwargs( ) -> dict: """ Helper function to convert pre_call_hook response back to kwargs for MCP usage. + + Supports: + - modified_arguments: Override tool call arguments + - extra_headers: Inject custom headers into the outbound MCP request """ if not response_data: return original_kwargs - # Apply any argument modifications from the hook response modified_kwargs = original_kwargs.copy() - # If the response contains modified arguments, apply them if response_data.get("modified_arguments"): modified_kwargs["arguments"] = response_data["modified_arguments"] + if response_data.get("extra_headers"): + modified_kwargs["extra_headers"] = response_data["extra_headers"] + return modified_kwargs async def process_pre_call_hook_response(self, response, data, call_type): diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py new file mode 100644 index 00000000000..8362193bb0b --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_hook_extra_headers.py @@ -0,0 +1,701 @@ +""" +Tests for pre_mcp_call guardrail hook header mutation support. + +Validates that: +1. _convert_mcp_hook_response_to_kwargs extracts extra_headers from hook response +2. pre_call_tool_check returns hook-provided extra_headers AND modified arguments +3. call_tool flows hook headers and modified arguments downstream +4. Hook-provided headers take highest priority (merge after static_headers) +5. OpenAPI-backed servers raise HTTPException when hook headers are present +6. JWT claims are propagated in both standard and virtual-key fast paths +7. Backward compatibility: hooks without extra_headers continue to work +""" + +import asyncio +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import HTTPException + +from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.utils import ProxyLogging +from litellm.types.mcp import MCPAuth, MCPTransport +from litellm.types.mcp_server.mcp_server_manager import MCPServer + + +class TestConvertMcpHookResponseToKwargs: + """Tests for ProxyLogging._convert_mcp_hook_response_to_kwargs""" + + def setup_method(self): + self.proxy_logging = ProxyLogging(user_api_key_cache=MagicMock()) + + def test_returns_original_kwargs_when_response_is_none(self): + original = {"arguments": {"key": "val"}, "name": "tool"} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + None, original + ) + assert result == original + + def test_returns_original_kwargs_when_response_is_empty_dict(self): + original = {"arguments": {"key": "val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs({}, original) + assert result == original + + def test_extracts_modified_arguments(self): + original = {"arguments": {"old": "value"}} + response = {"modified_arguments": {"new": "value"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + + def test_extracts_extra_headers(self): + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {"Authorization": "Bearer signed-jwt"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["extra_headers"] == {"Authorization": "Bearer signed-jwt"} + + def test_extracts_both_arguments_and_headers(self): + original = {"arguments": {"old": "value"}} + response = { + "modified_arguments": {"new": "value"}, + "extra_headers": {"X-Custom": "header-val"}, + } + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert result["arguments"] == {"new": "value"} + assert result["extra_headers"] == {"X-Custom": "header-val"} + + def test_no_extra_headers_key_preserves_original(self): + """Backward compat: hooks that only return modified_arguments still work.""" + original = {"arguments": {"key": "val"}} + response = {"modified_arguments": {"key": "new_val"}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + assert result["arguments"] == {"key": "new_val"} + + def test_empty_extra_headers_not_set(self): + """Empty dict for extra_headers is falsy and should not be set.""" + original = {"arguments": {"key": "val"}} + response = {"extra_headers": {}} + result = self.proxy_logging._convert_mcp_hook_response_to_kwargs( + response, original + ) + assert "extra_headers" not in result + + +class TestPreCallToolCheckReturnsHeaders: + """Tests that pre_call_tool_check returns hook-provided headers.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_has_no_headers(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": {"key": "val"}} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_extra_headers_from_hook(self): + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt", "X-Trace-Id": "abc123"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"extra_headers": hook_headers} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": {"key": "val"}, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["extra_headers"] == hook_headers + + @pytest.mark.asyncio + async def test_returns_empty_dict_when_hook_returns_none(self): + manager = MCPServerManager() + server = self._make_server() + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value=None) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "val"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result == {} + + @pytest.mark.asyncio + async def test_returns_modified_arguments_from_hook(self): + """Modified arguments from the hook must be returned so the caller can use them.""" + manager = MCPServerManager() + server = self._make_server() + + original_args = {"key": "original"} + modified_args = {"key": "modified", "extra": "added"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock( + return_value={"modified_arguments": modified_args} + ) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments=original_args, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + + @pytest.mark.asyncio + async def test_returns_both_modified_arguments_and_headers(self): + """Hook can modify both arguments and inject headers simultaneously.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified"} + hook_headers = {"Authorization": "Bearer jwt"} + + proxy_logging = MagicMock(spec=ProxyLogging) + proxy_logging._create_mcp_request_object_from_kwargs = MagicMock( + return_value=MagicMock() + ) + proxy_logging._convert_mcp_to_llm_format = MagicMock( + return_value={"model": "fake"} + ) + proxy_logging.pre_call_hook = AsyncMock(return_value={"dummy": True}) + proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock( + return_value={"arguments": modified_args, "extra_headers": hook_headers} + ) + + with patch.object(manager, "check_allowed_or_banned_tools", return_value=True): + with patch.object( + manager, + "check_tool_permission_for_key_team", + new_callable=AsyncMock, + ): + with patch.object(manager, "validate_allowed_params"): + result = await manager.pre_call_tool_check( + name="test_tool", + arguments={"key": "original"}, + server_name="test_server", + user_api_key_auth=None, + proxy_logging_obj=proxy_logging, + server=server, + ) + + assert result["arguments"] == modified_args + assert result["extra_headers"] == hook_headers + + +class TestCallToolFlowsHookHeaders: + """Tests that call_tool passes hook_extra_headers to _call_regular_mcp_tool.""" + + def _make_server(self, name="test_server"): + return MCPServer( + server_id="test-id", + name=name, + server_name=name, + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + ) + + @pytest.mark.asyncio + async def test_hook_headers_passed_to_call_regular_mcp_tool(self): + """Verify that hook_extra_headers kwarg is forwarded.""" + manager = MCPServerManager() + server = self._make_server() + + hook_headers = {"Authorization": "Bearer signed-jwt"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"extra_headers": hook_headers}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") == hook_headers + + @pytest.mark.asyncio + async def test_no_hook_headers_when_no_proxy_logging(self): + """Without proxy_logging_obj, no pre_call_tool_check runs.""" + manager = MCPServerManager() + server = self._make_server() + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "val"}, + proxy_logging_obj=None, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("hook_extra_headers") is None + + @pytest.mark.asyncio + async def test_modified_arguments_passed_to_downstream(self): + """Hook-modified arguments must be used for the actual tool call.""" + manager = MCPServerManager() + server = self._make_server() + + modified_args = {"key": "modified_by_hook"} + + with patch.object( + manager, + "_get_mcp_server_from_tool_name", + return_value=server, + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"arguments": modified_args}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_regular_mcp_tool", + new_callable=AsyncMock, + return_value=MagicMock(), + ) as mock_call: + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="test_server", + name="test_tool", + arguments={"key": "original"}, + proxy_logging_obj=proxy_logging, + ) + + mock_call.assert_called_once() + call_kwargs = mock_call.call_args + assert call_kwargs.kwargs.get("arguments") == modified_args + + @pytest.mark.asyncio + async def test_openapi_server_raises_on_hook_headers(self): + """OpenAPI-backed servers should raise HTTPException when hook injects headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={"extra_headers": {"Authorization": "Bearer jwt"}}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + proxy_logging = MagicMock(spec=ProxyLogging) + + with pytest.raises(HTTPException) as exc_info: + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + + assert exc_info.value.status_code == 400 + assert "does not support hook header injection" in str( + exc_info.value.detail + ) + + @pytest.mark.asyncio + async def test_openapi_server_no_error_without_hook_headers(self): + """No exception when OpenAPI server has no hook-injected headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="openapi_server", + server_name="openapi_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + spec_path="/path/to/spec.yaml", + ) + + with patch.object( + manager, "_get_mcp_server_from_tool_name", return_value=server + ): + with patch.object( + manager, + "pre_call_tool_check", + new_callable=AsyncMock, + return_value={}, + ): + with patch.object( + manager, + "_create_during_hook_task", + return_value=asyncio.create_task(asyncio.sleep(0)), + ): + with patch.object( + manager, + "_call_openapi_tool_handler", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + proxy_logging = MagicMock(spec=ProxyLogging) + + await manager.call_tool( + server_name="openapi_server", + name="test_tool", + arguments={}, + proxy_logging_obj=proxy_logging, + ) + + +class TestHookHeaderMergePriority: + """Tests that hook-provided headers have highest priority in _call_regular_mcp_tool.""" + + def _make_server( + self, + static_headers: Optional[Dict[str, str]] = None, + extra_headers_config: Optional[list] = None, + ): + return MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + static_headers=static_headers, + extra_headers=extra_headers_config, + ) + + @pytest.mark.asyncio + async def test_hook_headers_override_static_headers(self): + """Hook headers should take precedence over static_headers.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"Authorization": "Bearer static-token", "X-Static": "yes"} + ) + + hook_headers = {"Authorization": "Bearer hook-signed-jwt"} + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=hook_headers, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-signed-jwt" + assert headers["X-Static"] == "yes" + + @pytest.mark.asyncio + async def test_no_hook_headers_preserves_existing_behavior(self): + """When hook_extra_headers is None, existing header logic is unchanged.""" + manager = MCPServerManager() + server = self._make_server( + static_headers={"X-Static": "static-value"} + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers=None, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers=None, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers == {"X-Static": "static-value"} + + @pytest.mark.asyncio + async def test_hook_headers_merge_with_oauth2(self): + """Hook headers merge on top of OAuth2 headers.""" + manager = MCPServerManager() + server = MCPServer( + server_id="test-id", + name="Test Server", + server_name="test_server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + ) + + captured_extra_headers: Dict[str, Any] = {} + + async def fake_create_mcp_client( + server, mcp_auth_header=None, extra_headers=None, stdio_env=None + ): + captured_extra_headers["value"] = extra_headers + mock_client = MagicMock() + mock_client.call_tool = AsyncMock(return_value=MagicMock()) + return mock_client + + with patch.object( + manager, "_create_mcp_client", side_effect=fake_create_mcp_client + ): + with patch.object(manager, "_build_stdio_env", return_value=None): + try: + await manager._call_regular_mcp_tool( + mcp_server=server, + original_tool_name="test_tool", + arguments={"key": "val"}, + tasks=[], + mcp_auth_header=None, + mcp_server_auth_headers=None, + oauth2_headers={ + "Authorization": "Bearer oauth2-token", + "X-OAuth": "yes", + }, + raw_headers=None, + proxy_logging_obj=None, + hook_extra_headers={ + "Authorization": "Bearer hook-jwt", + "X-Trace-Id": "trace-123", + }, + ) + except Exception: + pass + + headers = captured_extra_headers.get("value", {}) + assert headers["Authorization"] == "Bearer hook-jwt" + assert headers["X-OAuth"] == "yes" + assert headers["X-Trace-Id"] == "trace-123" + + +class TestUserAPIKeyAuthJwtClaims: + """Tests that UserAPIKeyAuth correctly carries jwt_claims.""" + + def test_jwt_claims_field_defaults_to_none(self): + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + def test_jwt_claims_field_accepts_dict(self): + claims = {"sub": "user-123", "iss": "litellm", "exp": 9999999999} + auth = UserAPIKeyAuth(api_key="test-key", jwt_claims=claims) + assert auth.jwt_claims == claims + assert auth.jwt_claims["sub"] == "user-123" + + def test_jwt_claims_backward_compatible_without_field(self): + """Existing code that doesn't pass jwt_claims should still work.""" + auth = UserAPIKeyAuth( + api_key="test-key", + user_id="user-1", + team_id="team-1", + ) + assert auth.jwt_claims is None + assert auth.user_id == "user-1" + + def test_jwt_claims_set_after_construction(self): + """Virtual-key fast path sets jwt_claims after the object is created.""" + auth = UserAPIKeyAuth(api_key="test-key") + assert auth.jwt_claims is None + + claims = {"sub": "user-456", "iss": "okta", "groups": ["admin"]} + auth.jwt_claims = claims + assert auth.jwt_claims == claims + assert auth.jwt_claims["groups"] == ["admin"]