diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 53dc6e512c5..e0217cd9e00 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -42,6 +42,7 @@ add_server_prefix_to_name, get_server_prefix, is_tool_name_prefixed, + merge_mcp_headers, normalize_server_name, split_server_prefix_from_name, validate_mcp_server_name, @@ -372,7 +373,7 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s server_prefix = get_server_prefix(server) # Build headers from server configuration - headers = {} + headers: Dict[str, str] = {} # Add authentication headers if configured if server.authentication_token: @@ -385,10 +386,15 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s elif server.auth_type == MCPAuth.basic: headers["Authorization"] = f"Basic {server.authentication_token}" - # Add any extra headers from server config - # Note: extra_headers is a List[str] of header names to forward, not a dict - # For OpenAPI tools, we'll just use the authentication headers - # If extra_headers were needed, they would be processed separately + # Add any static headers from server config. + # + # Note: `extra_headers` on MCPServer is a List[str] of header names to forward + # from the client request (not available in this OpenAPI tool generation step). + # `static_headers` is a dict of concrete headers to always send. + headers = merge_mcp_headers( + extra_headers=headers, + static_headers=server.static_headers, + ) or {} verbose_logger.debug( f"Using headers for OpenAPI tools (excluding sensitive values): " diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 48f7a8b0b7b..d93f852f22d 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -8,6 +8,7 @@ from litellm.proxy._experimental.mcp_server.ui_session_utils import ( build_effective_auth_contexts, ) +from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.types.mcp import MCPAuth @@ -438,16 +439,22 @@ async def _execute_with_mcp_client( command=request.command, args=request.args, env=request.env, + static_headers=request.static_headers, ) stdio_env = global_mcp_server_manager._build_stdio_env( server_model, raw_headers ) + merged_headers = merge_mcp_headers( + extra_headers=oauth2_headers, + static_headers=request.static_headers, + ) + client = global_mcp_server_manager._create_mcp_client( server=server_model, mcp_auth_header=mcp_auth_header, - extra_headers=oauth2_headers, + extra_headers=merged_headers, stdio_env=stdio_env, ) diff --git a/litellm/proxy/_experimental/mcp_server/utils.py b/litellm/proxy/_experimental/mcp_server/utils.py index d801b312aac..8189f212bcb 100644 --- a/litellm/proxy/_experimental/mcp_server/utils.py +++ b/litellm/proxy/_experimental/mcp_server/utils.py @@ -1,7 +1,7 @@ """ MCP Server Utilities """ -from typing import Tuple, Any +from typing import Any, Dict, Mapping, Optional, Tuple import os import importlib @@ -137,3 +137,31 @@ def validate_mcp_server_name( ) else: raise Exception(error_message) + + +def merge_mcp_headers( + *, + extra_headers: Optional[Mapping[str, str]] = None, + static_headers: Optional[Mapping[str, str]] = None, +) -> Optional[Dict[str, str]]: + """Merge outbound HTTP headers for MCP calls. + + This is used when calling out to external MCP servers (or OpenAPI-based MCP tools). + + Merge rules: + - Start with `extra_headers` (typically OAuth2-derived headers) + - Overlay `static_headers` (user-configured per MCP server) + + If both contain the same key, `static_headers` wins. This matches the existing + behavior in `MCPServerManager` where `server.static_headers` is applied after + any caller-provided headers. + """ + merged: Dict[str, str] = {} + + if extra_headers: + merged.update({str(k): str(v) for k, v in extra_headers.items()}) + + if static_headers: + merged.update({str(k): str(v) for k, v in static_headers.items()}) + + return merged or None diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 86abec31012..ecdc75ede52 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -1,3 +1,4 @@ +import json import importlib import logging import os @@ -989,6 +990,68 @@ def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env) assert result.status == "healthy" assert result.health_check_error is None + @pytest.mark.asyncio + async def test_register_openapi_tools_includes_static_headers(self, tmp_path): + """Ensure OpenAPI-to-MCP tool calls include server.static_headers (Issue #19341).""" + manager = MCPServerManager() + + spec_path = tmp_path / "openapi.json" + spec_path.write_text( + json.dumps( + { + "openapi": "3.0.0", + "info": {"title": "Demo", "version": "1.0.0"}, + "paths": { + "/health": { + "get": { + "operationId": "health_check", + "summary": "health", + } + } + }, + } + ) + ) + + server = MCPServer( + server_id="openapi-server", + name="openapi-server", + server_name="openapi-server", + url="https://example.com", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + static_headers={"Authorization": "STATIC token"}, + ) + + captured: dict = {} + + def fake_create_tool_function(path, method, operation, base_url, headers=None): + captured["headers"] = headers + + async def tool_func(**kwargs): + return "ok" + + return tool_func + + with patch( + "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function", + side_effect=fake_create_tool_function, + ), patch( + "litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema", + return_value={"type": "object", "properties": {}, "required": []}, + ), patch( + "litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool", + return_value=None, + ): + manager._register_openapi_tools( + spec_path=str(spec_path), + server=server, + base_url="https://example.com", + ) + + assert captured["headers"] is not None + assert captured["headers"]["Authorization"] == "STATIC token" + @pytest.mark.asyncio async def test_pre_call_tool_check_allowed_tools_list_allows_tool(self): """Test pre_call_tool_check allows tool when it's in allowed_tools list""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py index c026ea232b7..05f18b7054a 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py @@ -101,6 +101,59 @@ async def failing_operation(client): assert result["status"] == "error" assert "stack_trace" not in result + @pytest.mark.asyncio + async def test_forwards_static_headers(self, monkeypatch): + """Ensure static_headers are forwarded to the MCP client during test calls. + + This is required for `/mcp-rest/test/tools/list` (Issue #19341), where the UI + sends `static_headers` but the backend must forward them during + `session.initialize()` and tool discovery. + """ + captured: dict = {} + + def fake_build_stdio_env(server, raw_headers): + return None + + def fake_create_client(*args, **kwargs): + captured["extra_headers"] = kwargs.get("extra_headers") + return object() + + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, + "_build_stdio_env", + fake_build_stdio_env, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, + "_create_mcp_client", + fake_create_client, + raising=False, + ) + + async def ok_operation(client): + return {"status": "ok"} + + payload = NewMCPServerRequest( + server_name="example", + url="https://example.com", + auth_type=MCPAuth.none, + static_headers={"Authorization": "STATIC token"}, + ) + + result = await rest_endpoints._execute_with_mcp_client( + payload, + ok_operation, + oauth2_headers={"X-OAuth": "1"}, + raw_headers={"x-test": "y"}, + ) + + assert result["status"] == "ok" + assert captured["extra_headers"] == { + "X-OAuth": "1", + "Authorization": "STATIC token", + } + class TestTestConnection: def test_requires_auth_dependency(self):