Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
add_server_prefix_to_name,
get_server_prefix,
is_tool_name_prefixed,
merge_mcp_headers,
normalize_server_name,
split_server_prefix_from_name,
validate_mcp_server_name,
Expand Down Expand Up @@ -372,7 +373,7 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s
server_prefix = get_server_prefix(server)

# Build headers from server configuration
headers = {}
headers: Dict[str, str] = {}

# Add authentication headers if configured
if server.authentication_token:
Expand All @@ -385,10 +386,15 @@ def _register_openapi_tools(self, spec_path: str, server: MCPServer, base_url: s
elif server.auth_type == MCPAuth.basic:
headers["Authorization"] = f"Basic {server.authentication_token}"

# Add any extra headers from server config
# Note: extra_headers is a List[str] of header names to forward, not a dict
# For OpenAPI tools, we'll just use the authentication headers
# If extra_headers were needed, they would be processed separately
# Add any static headers from server config.
#
# Note: `extra_headers` on MCPServer is a List[str] of header names to forward
# from the client request (not available in this OpenAPI tool generation step).
# `static_headers` is a dict of concrete headers to always send.
headers = merge_mcp_headers(
extra_headers=headers,
static_headers=server.static_headers,
) or {}

verbose_logger.debug(
f"Using headers for OpenAPI tools (excluding sensitive values): "
Expand Down
9 changes: 8 additions & 1 deletion litellm/proxy/_experimental/mcp_server/rest_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from litellm.proxy._experimental.mcp_server.ui_session_utils import (
build_effective_auth_contexts,
)
from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.types.mcp import MCPAuth
Expand Down Expand Up @@ -438,16 +439,22 @@ async def _execute_with_mcp_client(
command=request.command,
args=request.args,
env=request.env,
static_headers=request.static_headers,
)

stdio_env = global_mcp_server_manager._build_stdio_env(
server_model, raw_headers
)

merged_headers = merge_mcp_headers(
extra_headers=oauth2_headers,
static_headers=request.static_headers,
)

client = global_mcp_server_manager._create_mcp_client(
server=server_model,
mcp_auth_header=mcp_auth_header,
extra_headers=oauth2_headers,
extra_headers=merged_headers,
stdio_env=stdio_env,
)

Expand Down
30 changes: 29 additions & 1 deletion litellm/proxy/_experimental/mcp_server/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
MCP Server Utilities
"""
from typing import Tuple, Any
from typing import Any, Dict, Mapping, Optional, Tuple

import os
import importlib
Expand Down Expand Up @@ -137,3 +137,31 @@ def validate_mcp_server_name(
)
else:
raise Exception(error_message)


def merge_mcp_headers(
*,
extra_headers: Optional[Mapping[str, str]] = None,
static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
"""Merge outbound HTTP headers for MCP calls.

This is used when calling out to external MCP servers (or OpenAPI-based MCP tools).

Merge rules:
- Start with `extra_headers` (typically OAuth2-derived headers)
- Overlay `static_headers` (user-configured per MCP server)

If both contain the same key, `static_headers` wins. This matches the existing
behavior in `MCPServerManager` where `server.static_headers` is applied after
any caller-provided headers.
"""
merged: Dict[str, str] = {}

if extra_headers:
merged.update({str(k): str(v) for k, v in extra_headers.items()})

if static_headers:
merged.update({str(k): str(v) for k, v in static_headers.items()})

return merged or None
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import importlib
import logging
import os
Expand Down Expand Up @@ -989,6 +990,68 @@ def capture_create_mcp_client(server, mcp_auth_header, extra_headers, stdio_env)
assert result.status == "healthy"
assert result.health_check_error is None

@pytest.mark.asyncio
async def test_register_openapi_tools_includes_static_headers(self, tmp_path):
"""Ensure OpenAPI-to-MCP tool calls include server.static_headers (Issue #19341)."""
manager = MCPServerManager()

spec_path = tmp_path / "openapi.json"
spec_path.write_text(
json.dumps(
{
"openapi": "3.0.0",
"info": {"title": "Demo", "version": "1.0.0"},
"paths": {
"/health": {
"get": {
"operationId": "health_check",
"summary": "health",
}
}
},
}
)
)

server = MCPServer(
server_id="openapi-server",
name="openapi-server",
server_name="openapi-server",
url="https://example.com",
transport=MCPTransport.http,
auth_type=MCPAuth.none,
static_headers={"Authorization": "STATIC token"},
)

captured: dict = {}

def fake_create_tool_function(path, method, operation, base_url, headers=None):
captured["headers"] = headers

async def tool_func(**kwargs):
return "ok"

return tool_func

with patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.create_tool_function",
side_effect=fake_create_tool_function,
), patch(
"litellm.proxy._experimental.mcp_server.openapi_to_mcp_generator.build_input_schema",
return_value={"type": "object", "properties": {}, "required": []},
), patch(
"litellm.proxy._experimental.mcp_server.tool_registry.global_mcp_tool_registry.register_tool",
return_value=None,
):
manager._register_openapi_tools(
spec_path=str(spec_path),
server=server,
base_url="https://example.com",
)

assert captured["headers"] is not None
assert captured["headers"]["Authorization"] == "STATIC token"

@pytest.mark.asyncio
async def test_pre_call_tool_check_allowed_tools_list_allows_tool(self):
"""Test pre_call_tool_check allows tool when it's in allowed_tools list"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,59 @@ async def failing_operation(client):
assert result["status"] == "error"
assert "stack_trace" not in result

@pytest.mark.asyncio
async def test_forwards_static_headers(self, monkeypatch):
"""Ensure static_headers are forwarded to the MCP client during test calls.

This is required for `/mcp-rest/test/tools/list` (Issue #19341), where the UI
sends `static_headers` but the backend must forward them during
`session.initialize()` and tool discovery.
"""
captured: dict = {}

def fake_build_stdio_env(server, raw_headers):
return None

def fake_create_client(*args, **kwargs):
captured["extra_headers"] = kwargs.get("extra_headers")
return object()

monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager,
"_build_stdio_env",
fake_build_stdio_env,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager,
"_create_mcp_client",
fake_create_client,
raising=False,
)

async def ok_operation(client):
return {"status": "ok"}

payload = NewMCPServerRequest(
server_name="example",
url="https://example.com",
auth_type=MCPAuth.none,
static_headers={"Authorization": "STATIC token"},
)

result = await rest_endpoints._execute_with_mcp_client(
payload,
ok_operation,
oauth2_headers={"X-OAuth": "1"},
raw_headers={"x-test": "y"},
)

assert result["status"] == "ok"
assert captured["extra_headers"] == {
"X-OAuth": "1",
"Authorization": "STATIC token",
}


class TestTestConnection:
def test_requires_auth_dependency(self):
Expand Down
Loading