Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
364 changes: 251 additions & 113 deletions litellm/proxy/_experimental/mcp_server/server.py

Large diffs are not rendered by default.

78 changes: 43 additions & 35 deletions litellm/responses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ async def aresponses_api_with_mcp(

# Process MCP tools through the complete pipeline (fetch + filter + deduplicate + transform)
# Extract user_api_key_auth from litellm_metadata (where it's added by add_user_api_key_auth_to_request_metadata)
user_api_key_auth = kwargs.get("user_api_key_auth") or kwargs.get("litellm_metadata", {}).get("user_api_key_auth")
user_api_key_auth = kwargs.get("user_api_key_auth") or kwargs.get(
"litellm_metadata", {}
).get("user_api_key_auth")

# Get original MCP tools (for events) and OpenAI tools (for LLM) by reusing existing methods
(
Expand Down Expand Up @@ -280,7 +282,7 @@ async def aresponses_api_with_mcp(
user_api_key_auth = kwargs.get("litellm_metadata", {}).get(
"user_api_key_auth"
)

# Extract MCP auth headers from the request to pass to MCP server
secret_fields: Optional[Dict[str, Any]] = kwargs.get("secret_fields")
(
Expand All @@ -292,7 +294,7 @@ async def aresponses_api_with_mcp(
secret_fields=secret_fields,
tools=tools,
)

tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
tool_server_map=tool_server_map,
tool_calls=tool_calls,
Expand All @@ -301,6 +303,8 @@ async def aresponses_api_with_mcp(
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers_from_request,
litellm_call_id=kwargs.get("litellm_call_id"),
litellm_trace_id=kwargs.get("litellm_trace_id"),
)

if tool_results:
Expand Down Expand Up @@ -349,6 +353,7 @@ async def aresponses_api_with_mcp(
tool_server_map=tool_server_map,
base_iterator=final_response,
mcp_events=tool_execution_events,
user_api_key_auth=user_api_key_auth,
)

# Add custom output elements to the final response (for non-streaming)
Expand Down Expand Up @@ -587,9 +592,12 @@ def responses(
#########################################################
# Update input with provider-specific file IDs if managed files are used
#########################################################
input = cast(Union[str, ResponseInputParam], update_responses_input_with_model_file_ids(input=input))
input = cast(
Union[str, ResponseInputParam],
update_responses_input_with_model_file_ids(input=input),
)
local_vars["input"] = input

#########################################################
# Native MCP Responses API
#########################################################
Expand Down Expand Up @@ -624,11 +632,11 @@ def responses(
)

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)

local_vars.update(kwargs)
Expand Down Expand Up @@ -823,11 +831,11 @@ def delete_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
Expand Down Expand Up @@ -1003,11 +1011,11 @@ def get_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
Expand Down Expand Up @@ -1160,11 +1168,11 @@ def list_input_items(
if custom_llm_provider is None:
raise ValueError("custom_llm_provider is required but passed as None")

responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
Expand Down Expand Up @@ -1318,11 +1326,11 @@ def cancel_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
Expand Down Expand Up @@ -1500,11 +1508,11 @@ def compact_responses(
raise ValueError("custom_llm_provider is required but passed as None")

# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=litellm.LlmProviders(custom_llm_provider),
)

if responses_api_provider_config is None:
Expand Down
2 changes: 2 additions & 0 deletions litellm/responses/mcp/chat_completions_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ async def acompletion_with_mcp(
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers,
litellm_call_id=kwargs.get("litellm_call_id"),
litellm_trace_id=kwargs.get("litellm_trace_id"),
)

if not tool_results:
Expand Down
Loading
Loading