From 6267f1689b96725bc00f28386c85eccbb42d7fb6 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 12:00:22 +0900 Subject: [PATCH 1/7] feat: save mcp fail log --- .../proxy/_experimental/mcp_server/server.py | 138 ++++++++++-------- 1 file changed, 74 insertions(+), 64 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 76a28344856..e32ae85f9ad 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -6,6 +6,7 @@ import asyncio import contextlib from datetime import datetime +import traceback from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union, cast from fastapi import FastAPI, HTTPException @@ -13,6 +14,7 @@ from starlette.types import Receive, Scope, Send from litellm._logging import verbose_logger +from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( MCPRequestHandler, @@ -25,7 +27,7 @@ from litellm.proxy._types import UserAPIKeyAuth from litellm.types.mcp import MCPAuth from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer -from litellm.types.utils import StandardLoggingMCPToolCall +from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall from litellm.utils import client # Check if MCP is available @@ -1320,33 +1322,6 @@ async def execute_mcp_tool( content=cast(Any, local_content), isError=False ) - ######################################################### - # Post MCP Tool Call Hook - # Allow modifying the MCP tool call response before it is returned to the user - ######################################################### - if litellm_logging_obj: - litellm_logging_obj.post_call(original_response=response) - end_time = datetime.now() - await litellm_logging_obj.async_post_mcp_tool_call_hook( - kwargs=litellm_logging_obj.model_call_details, - response_obj=response, - start_time=start_time, - end_time=end_time, - ) - # Set call_type to call_mcp_tool so cost calculator recognizes it - from litellm.types.utils import CallTypes - - litellm_logging_obj.call_type = CallTypes.call_mcp_tool.value - # Trigger success logging to build standard_logging_object and call callbacks - # async_success_handler will: - # 1. Call _success_handler_helper_fn which recognizes call_mcp_tool - # 2. Call _process_hidden_params_and_response_cost which: - # - Calculates cost via _response_cost_calculator -> MCPCostCalculator - # - Builds standard_logging_object - # 3. Call async_log_success_event on all callbacks - await litellm_logging_obj.async_success_handler( - result=response, start_time=start_time, end_time=end_time - ) return response @client @@ -1365,49 +1340,84 @@ async def call_mcp_tool( Call a specific tool with the provided arguments (handles prefixed tool names). """ start_time = datetime.now() - if arguments is None: - raise HTTPException( - status_code=400, detail="Request arguments are required" - ) + litellm_logging_obj: Optional[LiteLLMLoggingObj] = kwargs.get( + "litellm_logging_obj", None + ) - ## CHECK IF USER IS ALLOWED TO CALL THIS TOOL - allowed_mcp_server_ids = ( - await global_mcp_server_manager.get_allowed_mcp_servers( - user_api_key_auth=user_api_key_auth, + try: + if arguments is None: + raise HTTPException( + status_code=400, detail="Request arguments are required" + ) + + ## CHECK IF USER IS ALLOWED TO CALL THIS TOOL + allowed_mcp_server_ids = ( + await global_mcp_server_manager.get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth, + ) ) - ) - allowed_mcp_servers: List[MCPServer] = [] - for allowed_mcp_server_id in allowed_mcp_server_ids: - allowed_server = global_mcp_server_manager.get_mcp_server_by_id( - allowed_mcp_server_id + allowed_mcp_servers: List[MCPServer] = [] + for allowed_mcp_server_id in allowed_mcp_server_ids: + allowed_server = global_mcp_server_manager.get_mcp_server_by_id( + allowed_mcp_server_id + ) + if allowed_server is not None: + allowed_mcp_servers.append(allowed_server) + + allowed_mcp_servers = await _get_allowed_mcp_servers_from_mcp_server_names( + mcp_servers=mcp_servers, + allowed_mcp_servers=allowed_mcp_servers, ) - if allowed_server is not None: - allowed_mcp_servers.append(allowed_server) + if not allowed_mcp_servers: + raise HTTPException( + status_code=403, + detail="User not allowed to call this tool.", + ) - allowed_mcp_servers = await _get_allowed_mcp_servers_from_mcp_server_names( - mcp_servers=mcp_servers, - allowed_mcp_servers=allowed_mcp_servers, - ) - if not allowed_mcp_servers: - raise HTTPException( - status_code=403, - detail="User not allowed to call this tool.", + # Delegate to execute_mcp_tool for execution + response = await execute_mcp_tool( + name=name, + arguments=arguments, + allowed_mcp_servers=allowed_mcp_servers, + start_time=start_time, + user_api_key_auth=user_api_key_auth, + mcp_auth_header=mcp_auth_header, + mcp_server_auth_headers=mcp_server_auth_headers, + oauth2_headers=oauth2_headers, + raw_headers=raw_headers, + **kwargs, ) + except Exception as e: + traceback_str = traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG + ) + from litellm.proxy.proxy_server import proxy_logging_obj + + if proxy_logging_obj and user_api_key_auth: + await proxy_logging_obj.post_call_failure_hook( + request_data=kwargs, + original_exception=e, + user_api_key_dict=user_api_key_auth, + route="/mcp/call_tool", + traceback_str=traceback_str, + ) + raise - # Delegate to execute_mcp_tool for execution - return await execute_mcp_tool( - name=name, - arguments=arguments, - allowed_mcp_servers=allowed_mcp_servers, - start_time=start_time, - user_api_key_auth=user_api_key_auth, - mcp_auth_header=mcp_auth_header, - mcp_server_auth_headers=mcp_server_auth_headers, - oauth2_headers=oauth2_headers, - raw_headers=raw_headers, - **kwargs, - ) + if litellm_logging_obj: + litellm_logging_obj.post_call(original_response=response) + end_time = datetime.now() + await litellm_logging_obj.async_post_mcp_tool_call_hook( + kwargs=litellm_logging_obj.model_call_details, + response_obj=response, + start_time=start_time, + end_time=end_time, + ) + litellm_logging_obj.call_type = CallTypes.call_mcp_tool.value + await litellm_logging_obj.async_success_handler( + result=response, start_time=start_time, end_time=end_time + ) + return response async def mcp_get_prompt( name: str, From ae4d92ad509ebadbb3e72dc790f6621ca2e13de3 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 13:54:24 +0900 Subject: [PATCH 2/7] feat: save mcp call log via responses --- litellm/responses/main.py | 3 + .../responses/mcp/chat_completions_handler.py | 2 + .../mcp/litellm_proxy_mcp_handler.py | 190 +++++++++++++++++- .../responses/mcp/mcp_streaming_iterator.py | 4 + 4 files changed, 196 insertions(+), 3 deletions(-) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 71d94287e82..78d358a1e39 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -301,6 +301,8 @@ async def aresponses_api_with_mcp( mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers_from_request, + litellm_call_id=kwargs.get("litellm_call_id"), + litellm_trace_id=kwargs.get("litellm_trace_id"), ) if tool_results: @@ -349,6 +351,7 @@ async def aresponses_api_with_mcp( tool_server_map=tool_server_map, base_iterator=final_response, mcp_events=tool_execution_events, + user_api_key_auth=user_api_key_auth, ) # Add custom output elements to the final response (for non-streaming) diff --git a/litellm/responses/mcp/chat_completions_handler.py b/litellm/responses/mcp/chat_completions_handler.py index 6ce59e3e67f..26853b30596 100644 --- a/litellm/responses/mcp/chat_completions_handler.py +++ b/litellm/responses/mcp/chat_completions_handler.py @@ -142,6 +142,8 @@ async def acompletion_with_mcp( mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + litellm_call_id=kwargs.get("litellm_call_id"), + litellm_trace_id=kwargs.get("litellm_trace_id"), ) if not tool_results: diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 9cdcd3894e0..f5757e4d52e 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -1,3 +1,5 @@ +import traceback +from datetime import datetime from typing import ( TYPE_CHECKING, Any, @@ -11,14 +13,18 @@ ) from litellm._logging import verbose_logger +from litellm.constants import MAXIMUM_TRACEBACK_LINES_TO_LOG +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy._experimental.mcp_server.utils import split_server_prefix_from_name from litellm.responses.main import aresponses from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator from litellm.types.llms.openai import ResponsesAPIResponse, ToolParam -from litellm.types.utils import Choices, ModelResponse +from litellm.types.utils import CallTypes, Choices, ModelResponse, StandardLoggingMCPToolCall +from litellm.utils import Rules, function_setup if TYPE_CHECKING: from mcp.types import Tool as MCPTool + from litellm.proxy.utils import ProxyLogging else: MCPTool = Any @@ -470,6 +476,8 @@ async def _execute_tool_calls( mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, oauth2_headers: Optional[Dict[str, str]] = None, raw_headers: Optional[Dict[str, str]] = None, + litellm_call_id: Optional[str] = None, + litellm_trace_id: Optional[str] = None, ) -> List[Dict[str, Any]]: """Execute tool calls and return results.""" from fastapi import HTTPException @@ -478,10 +486,16 @@ async def _execute_tool_calls( from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( global_mcp_server_manager, ) + from litellm.proxy.proxy_server import proxy_logging_obj + + from litellm._uuid import uuid tool_results = [] tool_call_id: Optional[str] = None + rules_obj = Rules() for tool_call in tool_calls: + logging_request_data: Dict[str, Any] = {} + tool_name: str = "" try: ( tool_name, @@ -514,6 +528,101 @@ async def _execute_tool_calls( ): sanitized_tool_name = unprefixed_name + start_time = datetime.now() + logging_input = [ + { + "role": "tool", + "content": { + "tool_name": sanitized_tool_name, + "arguments": parsed_arguments, + }, + } + ] + tool_logging_call_id = litellm_call_id or str(uuid.uuid4()) + logging_request_data: Dict[str, Any] = { + "model": f"MCP: {tool_name}", + "metadata": { + "tool_call_id": tool_call_id, + "tool_name": sanitized_tool_name, + "server_name": server_name, + }, + "input": logging_input, + "call_type": CallTypes.call_mcp_tool.value, + "litellm_call_id": tool_logging_call_id, + } + if litellm_trace_id: + logging_request_data["litellm_trace_id"] = litellm_trace_id + user_identifier = None + if user_api_key_auth is not None: + user_api_key = getattr(user_api_key_auth, "api_key", None) + if user_api_key: + logging_request_data["metadata"]["user_api_key"] = user_api_key + + user_identifier = getattr(user_api_key_auth, "end_user_id", None) or getattr( + user_api_key_auth, "user_id", None + ) + if user_identifier: + logging_request_data["user"] = user_identifier + + litellm_logging_obj: Optional[LiteLLMLoggingObj] = None + try: + litellm_logging_obj, _ = function_setup( + original_function="call_mcp_tool", + rules_obj=rules_obj, + start_time=start_time, + **logging_request_data, + ) + except Exception as logging_error: + verbose_logger.debug( + "Failed to initialize logging for MCP tool call %s: %s", + tool_name, + logging_error, + ) + litellm_logging_obj = None + + logging_request_data["litellm_logging_obj"] = litellm_logging_obj + logging_request_data["arguments"] = parsed_arguments + + if litellm_logging_obj: + try: + litellm_logging_obj.pre_call( + input=logging_input, + api_key="", + ) + except Exception: + verbose_logger.exception( + "Failed to run pre_call for MCP tool logging" + ) + + standard_logging_mcp_tool_call: StandardLoggingMCPToolCall = { + "name": sanitized_tool_name, + "arguments": parsed_arguments, + "namespaced_tool_name": tool_name, + } + mcp_server = global_mcp_server_manager._get_mcp_server_from_tool_name( + tool_name + ) + if mcp_server: + mcp_info = mcp_server.mcp_info or {} + standard_logging_mcp_tool_call["mcp_server_name"] = ( + mcp_info.get("server_name") + or getattr(mcp_server, "server_name", None) + or server_name + ) + logo_url = mcp_info.get("logo_url") + if logo_url: + standard_logging_mcp_tool_call["mcp_server_logo_url"] = logo_url + cost_info = mcp_info.get("mcp_server_cost_info") + if cost_info: + standard_logging_mcp_tool_call["mcp_server_cost_info"] = cost_info + + if litellm_logging_obj: + litellm_logging_obj.model_call_details[ + "mcp_tool_call_metadata" + ] = standard_logging_mcp_tool_call + litellm_logging_obj.model = f"MCP: {tool_name}" + litellm_logging_obj.call_type = CallTypes.call_mcp_tool.value + result = await global_mcp_server_manager.call_tool( server_name=server_name, name=sanitized_tool_name, @@ -526,6 +635,26 @@ async def _execute_tool_calls( proxy_logging_obj=proxy_logging_obj, ) + if litellm_logging_obj: + try: + litellm_logging_obj.post_call(original_response=result) + end_time = datetime.now() + await litellm_logging_obj.async_post_mcp_tool_call_hook( + kwargs=litellm_logging_obj.model_call_details, + response_obj=result, + start_time=start_time, + end_time=end_time, + ) + await litellm_logging_obj.async_success_handler( + result=result, + start_time=start_time, + end_time=end_time, + ) + except Exception: + verbose_logger.exception( + "Failed to log MCP tool call success for %s", tool_name + ) + # Format result for inclusion in response result_text = LiteLLM_Proxy_MCP_Handler._parse_mcp_result(result) tool_results.append( @@ -537,6 +666,12 @@ async def _execute_tool_calls( ) except BlockedPiiEntityError as e: + await LiteLLM_Proxy_MCP_Handler._log_mcp_tool_failure( + proxy_logging_obj=proxy_logging_obj, + user_api_key_auth=user_api_key_auth, + request_data=logging_request_data, + error=e, + ) verbose_logger.error( f"BlockedPiiEntityError in MCP tool call: {str(e)}" ) @@ -549,6 +684,12 @@ async def _execute_tool_calls( } ) except GuardrailRaisedException as e: + await LiteLLM_Proxy_MCP_Handler._log_mcp_tool_failure( + proxy_logging_obj=proxy_logging_obj, + user_api_key_auth=user_api_key_auth, + request_data=logging_request_data, + error=e, + ) verbose_logger.error( f"GuardrailRaisedException in MCP tool call: {str(e)}" ) @@ -561,12 +702,28 @@ async def _execute_tool_calls( } ) except HTTPException as e: + await LiteLLM_Proxy_MCP_Handler._log_mcp_tool_failure( + proxy_logging_obj=proxy_logging_obj, + user_api_key_auth=user_api_key_auth, + request_data=logging_request_data, + error=e, + ) verbose_logger.error(f"HTTPException in MCP tool call: {str(e)}") error_message = f"Tool call failed: {str(e.detail) if hasattr(e, 'detail') else str(e)}" tool_results.append( - {"tool_call_id": tool_call_id, "result": error_message} + { + "tool_call_id": tool_call_id, + "result": error_message, + "name": tool_name, + } ) except Exception as e: + await LiteLLM_Proxy_MCP_Handler._log_mcp_tool_failure( + proxy_logging_obj=proxy_logging_obj, + user_api_key_auth=user_api_key_auth, + request_data=logging_request_data, + error=e, + ) verbose_logger.exception(f"Error executing MCP tool call: {e}") tool_results.append( { @@ -718,6 +875,33 @@ async def _make_follow_up_call( **call_params, ) + @staticmethod + async def _log_mcp_tool_failure( + *, + proxy_logging_obj: Optional["ProxyLogging"], + user_api_key_auth: Any, + request_data: Dict[str, Any], + error: Exception, + ) -> None: + """Log MCP tool failures via proxy logging hooks.""" + + if proxy_logging_obj is None or user_api_key_auth is None: + return + + try: + traceback_str = traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG + ) + await proxy_logging_obj.post_call_failure_hook( + request_data=request_data, + original_exception=error, + user_api_key_dict=user_api_key_auth, + route="/responses/mcp/call_tool", + traceback_str=traceback_str, + ) + except Exception: + verbose_logger.exception("Failed to log MCP tool call failure") + @staticmethod def _create_mcp_streaming_response( input: Union[str, Any], @@ -758,7 +942,7 @@ def _create_mcp_streaming_response( mcp_events=mcp_discovery_events, # Pre-generated MCP discovery events tool_server_map=tool_server_map, mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, - user_api_key_auth=kwargs.get("user_api_key_auth"), + user_api_key_auth=kwargs.get("user_api_key_auth") or kwargs.get("litellm_metadata", {}).get("user_api_key_auth"), original_request_params=request_params, ) diff --git a/litellm/responses/mcp/mcp_streaming_iterator.py b/litellm/responses/mcp/mcp_streaming_iterator.py index ac040d3d6ec..53f39164e91 100644 --- a/litellm/responses/mcp/mcp_streaming_iterator.py +++ b/litellm/responses/mcp/mcp_streaming_iterator.py @@ -298,6 +298,8 @@ def __init__( self.custom_llm_provider = self.original_request_params.get( "custom_llm_provider", None ) + self.litellm_call_id = self.original_request_params.get("litellm_call_id") + self.litellm_trace_id = self.original_request_params.get("litellm_trace_id") self._extract_mcp_headers_from_params() @@ -568,6 +570,8 @@ async def _generate_tool_execution_events(self) -> None: mcp_server_auth_headers=self.mcp_server_auth_headers, oauth2_headers=self.oauth2_headers, raw_headers=self.raw_headers, + litellm_call_id=self.litellm_call_id, + litellm_trace_id=self.litellm_trace_id, ) # Create completion events and output_item.done events for tool execution From 872e5b98977eac7eea95be6fd6069e751709960a Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 14:32:08 +0900 Subject: [PATCH 3/7] feat: log mcp list_tools calls to SpendLogs --- .../proxy/_experimental/mcp_server/server.py | 224 ++++++++++++++---- .../mcp/litellm_proxy_mcp_handler.py | 12 +- litellm/types/utils.py | 2 + 3 files changed, 187 insertions(+), 51 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index e32ae85f9ad..52117d86066 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -7,6 +7,7 @@ import contextlib from datetime import datetime import traceback +import uuid from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union, cast from fastapi import FastAPI, HTTPException @@ -28,7 +29,7 @@ from litellm.types.mcp import MCPAuth from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall -from litellm.utils import client +from litellm.utils import Rules, client, function_setup # Check if MCP is available # "mcp" requires python 3.10 or higher, but several litellm users use python 3.8 @@ -228,6 +229,8 @@ async def list_tools() -> List[MCPTool]: mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + log_list_tools_to_spendlogs=True, + list_tools_log_source="mcp_protocol", ) verbose_logger.info( f"MCP list_tools - Successfully returned {len(tools)} tools" @@ -742,6 +745,8 @@ async def _get_tools_from_mcp_servers( mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, oauth2_headers: Optional[Dict[str, str]] = None, raw_headers: Optional[Dict[str, str]] = None, + log_list_tools_to_spendlogs: bool = False, + list_tools_log_source: Optional[str] = None, ) -> List[MCPTool]: """ Helper method to fetch tools from MCP servers based on server filtering criteria. @@ -759,67 +764,186 @@ async def _get_tools_from_mcp_servers( if not MCP_AVAILABLE: return [] - allowed_mcp_servers = await _get_allowed_mcp_servers( - user_api_key_auth=user_api_key_auth, - mcp_servers=mcp_servers, - ) + list_tools_start_time = datetime.now() + litellm_logging_obj: Optional[LiteLLMLoggingObj] = None + list_tools_request_data: Dict[str, Any] = {} - # Decide whether to add prefix based on number of allowed servers - add_prefix = not (len(allowed_mcp_servers) == 1) + if log_list_tools_to_spendlogs: + # This is intentionally minimal: only async_success_handler / post_call_failure_hook + rules_obj = Rules() + list_tools_call_id = str(uuid.uuid4()) + spend_logs_metadata: Dict[str, Any] = { + "mcp_operation": "list_tools", + } + if isinstance(list_tools_log_source, str): + spend_logs_metadata["source"] = list_tools_log_source + if isinstance(mcp_servers, list): + spend_logs_metadata["requested_mcp_servers"] = mcp_servers + + list_tools_request_data = { + "model": "MCP: list_tools", + "call_type": CallTypes.list_mcp_tools.value, + "litellm_call_id": list_tools_call_id, + "metadata": { + "spend_logs_metadata": spend_logs_metadata, + }, + # Provide a small input payload for standard logging + "input": [ + { + "role": "system", + "content": { + "mcp_operation": "list_tools", + "requested_mcp_servers": mcp_servers, + }, + } + ], + } - async def _fetch_and_filter_server_tools(server: MCPServer) -> List[MCPTool]: - """Fetch and filter tools from a single server with error handling.""" - if server is None: - return [] + # Attach user identifiers when available (matches call_mcp_tool style) + if user_api_key_auth is not None: + user_api_key = getattr(user_api_key_auth, "api_key", None) + if user_api_key: + cast(dict, list_tools_request_data["metadata"])[ + "user_api_key" + ] = user_api_key + + user_identifier = getattr(user_api_key_auth, "end_user_id", None) or getattr( + user_api_key_auth, "user_id", None + ) + if user_identifier: + list_tools_request_data["user"] = user_identifier - server_auth_header, extra_headers = _prepare_mcp_server_headers( - server=server, - mcp_server_auth_headers=mcp_server_auth_headers, - mcp_auth_header=mcp_auth_header, - oauth2_headers=oauth2_headers, - raw_headers=raw_headers, + try: + litellm_logging_obj, _ = function_setup( + original_function="list_mcp_tools", + rules_obj=rules_obj, + start_time=list_tools_start_time, + **list_tools_request_data, + ) + if litellm_logging_obj: + litellm_logging_obj.call_type = CallTypes.list_mcp_tools.value + litellm_logging_obj.model = "MCP: list_tools" + except Exception as logging_error: + verbose_logger.debug( + "Failed to initialize logging for MCP list_tools: %s", logging_error + ) + litellm_logging_obj = None + + try: + allowed_mcp_servers = await _get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_servers=mcp_servers, ) - try: - tools = await global_mcp_server_manager._get_tools_from_server( + # Decide whether to add prefix based on number of allowed servers + add_prefix = not (len(allowed_mcp_servers) == 1) + + async def _fetch_and_filter_server_tools(server: MCPServer) -> List[MCPTool]: + """Fetch and filter tools from a single server with error handling.""" + if server is None: + return [] + + server_auth_header, extra_headers = _prepare_mcp_server_headers( server=server, - mcp_auth_header=server_auth_header, - extra_headers=extra_headers, - add_prefix=add_prefix, + mcp_server_auth_headers=mcp_server_auth_headers, + mcp_auth_header=mcp_auth_header, + oauth2_headers=oauth2_headers, raw_headers=raw_headers, ) - filtered_tools = filter_tools_by_allowed_tools(tools, server) - filtered_tools = await filter_tools_by_key_team_permissions( - tools=filtered_tools, - server_id=server.server_id, - user_api_key_auth=user_api_key_auth, - ) + try: + tools = await global_mcp_server_manager._get_tools_from_server( + server=server, + mcp_auth_header=server_auth_header, + extra_headers=extra_headers, + add_prefix=add_prefix, + raw_headers=raw_headers, + ) + filtered_tools = filter_tools_by_allowed_tools(tools, server) - verbose_logger.debug( - f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering" - ) - return filtered_tools - except Exception as e: - verbose_logger.exception( - f"Error getting tools from server {server.name}: {str(e)}" - ) - return [] + filtered_tools = await filter_tools_by_key_team_permissions( + tools=filtered_tools, + server_id=server.server_id, + user_api_key_auth=user_api_key_auth, + ) - # Fetch tools from all servers in parallel - tasks = [ - _fetch_and_filter_server_tools(server) for server in allowed_mcp_servers - ] - results = await asyncio.gather(*tasks) + verbose_logger.debug( + f"Successfully fetched {len(tools)} tools from server {server.name}, {len(filtered_tools)} after filtering" + ) + return filtered_tools + except Exception as e: + verbose_logger.exception( + f"Error getting tools from server {server.name}: {str(e)}" + ) + return [] - # Flatten results into single list - all_tools: List[MCPTool] = [tool for tools in results for tool in tools] + # Fetch tools from all servers in parallel + tasks = [ + _fetch_and_filter_server_tools(server) for server in allowed_mcp_servers + ] + results = await asyncio.gather(*tasks) + + # Flatten results into single list + all_tools: List[MCPTool] = [tool for tools in results for tool in tools] + + # If logging is enabled, enrich spend_logs_metadata with counts + if litellm_logging_obj: + per_server_tool_counts: Dict[str, int] = {} + for server, server_tools in zip(allowed_mcp_servers, results): + if server is None: + continue + server_key = ( + getattr(server, "server_name", None) + or getattr(server, "alias", None) + or getattr(server, "name", None) + or "unknown" + ) + per_server_tool_counts[str(server_key)] = len(server_tools) + + metadata_dict = litellm_logging_obj.model_call_details.get("metadata") + if isinstance(metadata_dict, dict): + spend_meta = metadata_dict.get("spend_logs_metadata") + if not isinstance(spend_meta, dict): + spend_meta = {} + metadata_dict["spend_logs_metadata"] = spend_meta + spend_meta["allowed_server_count"] = len(allowed_mcp_servers) + spend_meta["tool_count_total"] = len(all_tools) + spend_meta["per_server_tool_counts"] = per_server_tool_counts + + end_time = datetime.now() + await litellm_logging_obj.async_success_handler( + result=all_tools, + start_time=list_tools_start_time, + end_time=end_time, + ) - verbose_logger.info( - f"Successfully fetched {len(all_tools)} tools total from all MCP servers" - ) + verbose_logger.info( + f"Successfully fetched {len(all_tools)} tools total from all MCP servers" + ) - return all_tools + return all_tools + except Exception as e: + # Only fire failure hook if logging was requested for this list-tools execution + if log_list_tools_to_spendlogs and user_api_key_auth is not None: + try: + from litellm.proxy.proxy_server import proxy_logging_obj + + if proxy_logging_obj: + traceback_str = traceback.format_exc( + limit=MAXIMUM_TRACEBACK_LINES_TO_LOG + ) + await proxy_logging_obj.post_call_failure_hook( + request_data=list_tools_request_data or {}, + original_exception=e, + user_api_key_dict=user_api_key_auth, + route="/mcp/list_tools", + traceback_str=traceback_str, + ) + except Exception: + verbose_logger.debug( + "Failed to log MCP list_tools failure via post_call_failure_hook" + ) + raise async def _get_prompts_from_mcp_servers( user_api_key_auth: Optional[UserAPIKeyAuth], @@ -1052,6 +1176,8 @@ async def _list_mcp_tools( mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None, oauth2_headers: Optional[Dict[str, str]] = None, raw_headers: Optional[Dict[str, str]] = None, + log_list_tools_to_spendlogs: bool = False, + list_tools_log_source: Optional[str] = None, ) -> List[MCPTool]: """ List all available MCP tools. @@ -1077,6 +1203,8 @@ async def _list_mcp_tools( mcp_server_auth_headers=mcp_server_auth_headers, oauth2_headers=oauth2_headers, raw_headers=raw_headers, + log_list_tools_to_spendlogs=log_list_tools_to_spendlogs, + list_tools_log_source=list_tools_log_source, ) verbose_logger.debug( f"Successfully fetched {len(managed_tools)} tools from managed MCP servers" diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index f5757e4d52e..e22898ae0ff 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -18,7 +18,7 @@ from litellm.proxy._experimental.mcp_server.utils import split_server_prefix_from_name from litellm.responses.main import aresponses from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator -from litellm.types.llms.openai import ResponsesAPIResponse, ToolParam +from litellm.types.llms.openai import ResponsesAPIResponse from litellm.types.utils import CallTypes, Choices, ModelResponse, StandardLoggingMCPToolCall from litellm.utils import Rules, function_setup @@ -28,6 +28,10 @@ else: MCPTool = Any +# NOTE: We intentionally keep ToolParam as a broad type here to avoid tight coupling +# to optional OpenAI SDK typing symbols in environments that may not have them available. +ToolParam = Dict[str, Any] + LITELLM_PROXY_MCP_SERVER_URL = "litellm_proxy" LITELLM_PROXY_MCP_SERVER_URL_PREFIX = f"{LITELLM_PROXY_MCP_SERVER_URL}/mcp/" @@ -123,6 +127,8 @@ async def _get_mcp_tools_from_manager( mcp_auth_header=None, mcp_servers=mcp_servers, mcp_server_auth_headers=None, + log_list_tools_to_spendlogs=True, + list_tools_log_source="responses", ) allowed_mcp_server_ids = ( await global_mcp_server_manager.get_allowed_mcp_servers(user_api_key_auth) @@ -495,7 +501,7 @@ async def _execute_tool_calls( rules_obj = Rules() for tool_call in tool_calls: logging_request_data: Dict[str, Any] = {} - tool_name: str = "" + tool_name: Optional[str] = None try: ( tool_name, @@ -539,7 +545,7 @@ async def _execute_tool_calls( } ] tool_logging_call_id = litellm_call_id or str(uuid.uuid4()) - logging_request_data: Dict[str, Any] = { + logging_request_data = { "model": f"MCP: {tool_name}", "metadata": { "tool_call_id": tool_call_id, diff --git a/litellm/types/utils.py b/litellm/types/utils.py index f5e217d8b46..324380db2b4 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -384,6 +384,7 @@ class CallTypes(str, Enum): # MCP Call Types ######################################################### call_mcp_tool = "call_mcp_tool" + list_mcp_tools = "list_mcp_tools" ######################################################### # A2A Call Types @@ -448,6 +449,7 @@ class CallTypes(str, Enum): "vector_store_file_delete", "avector_store_file_delete", "call_mcp_tool", + "list_mcp_tools", "asend_message", "send_message", "aresponses", From caf5f7f8aeb248b64fa95d246b7376a8428656ce Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 14:51:56 +0900 Subject: [PATCH 4/7] test: add test --- .../mcp_server/test_mcp_server.py | 135 ++++++++++++++++ .../mcp/test_litellm_proxy_mcp_handler.py | 151 +++++++++++++++++- 2 files changed, 283 insertions(+), 3 deletions(-) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 7aa176aaac5..8c29cf6a590 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -1817,3 +1817,138 @@ async def test_rebuilds_server_when_updated_at_changes(self): mock_get_all.assert_awaited_once() mock_build.assert_awaited_once_with(db_row) assert manager.registry["server-1"] is rebuilt_server + + +@pytest.mark.asyncio +async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook(): + """ + Regression test for 6267f168...: + Ensure proxy-side `call_mcp_tool` logs failures via `proxy_logging_obj.post_call_failure_hook`. + """ + try: + from litellm.proxy._experimental.mcp_server.server import ( + call_mcp_tool, + global_mcp_server_manager, + ) + from litellm.proxy._types import MCPTransport, UserAPIKeyAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + except ImportError: + pytest.skip("MCP server not available") + + mock_server = MCPServer( + server_id="server-123", + name="test_server", + alias="test_server", + server_name="test_server", + url="https://test-server.com/mcp", + transport=MCPTransport.http, + mcp_info={"server_name": "test_server"}, + ) + + proxy_logging_mock = MagicMock() + proxy_logging_mock.post_call_failure_hook = AsyncMock() + + user_auth = UserAPIKeyAuth(api_key="test-key", user_id="test-user") + + with patch.object( + global_mcp_server_manager, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + return_value=[mock_server.server_id], + ), patch.object( + global_mcp_server_manager, + "get_mcp_server_by_id", + return_value=mock_server, + ), patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers_from_mcp_server_names", + new_callable=AsyncMock, + return_value=[mock_server], + ), patch( + "litellm.proxy._experimental.mcp_server.server.execute_mcp_tool", + new_callable=AsyncMock, + side_effect=Exception("boom"), + ), patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + proxy_logging_mock, + ): + with pytest.raises(Exception): + await call_mcp_tool( + name="test_server-any_tool", + arguments={"x": 1}, + user_api_key_auth=user_auth, + litellm_call_id="cid", + ) + + proxy_logging_mock.post_call_failure_hook.assert_awaited_once() + assert ( + proxy_logging_mock.post_call_failure_hook.await_args.kwargs.get("route") + == "/mcp/call_tool" + ) + + +@pytest.mark.asyncio +async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enabled(): + """ + Regression test for 872e5b98...: + Ensure list-tools logging path calls `async_success_handler` when enabled. + """ + try: + from litellm.proxy._experimental.mcp_server.server import _get_tools_from_mcp_servers + from litellm.proxy._types import UserAPIKeyAuth + except ImportError: + pytest.skip("MCP server not available") + + user_auth = UserAPIKeyAuth(api_key="test-key", user_id="test-user") + + server_a = MagicMock(name="server_a_obj") + server_a.name = "server_a" + server_a.alias = "server_a" + server_a.server_name = "server_a" + server_a.server_id = "a" + server_a.auth_type = None + server_a.extra_headers = None + + tool_1 = MagicMock() + tool_1.name = "server_a-tool_1" + + dummy_logging_obj = MagicMock() + dummy_logging_obj.model_call_details = {"metadata": {"spend_logs_metadata": {}}} + dummy_logging_obj.async_success_handler = AsyncMock() + + with patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + new=AsyncMock(return_value=[server_a]), + ), patch( + "litellm.proxy._experimental.mcp_server.server._prepare_mcp_server_headers", + return_value=(None, None), + ), patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", + side_effect=lambda tools, _server: tools, + ), patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", + new=AsyncMock(side_effect=lambda tools, **_: tools), + ), patch( + "litellm.proxy._experimental.mcp_server.server.function_setup", + return_value=(dummy_logging_obj, None), + ): + mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1]) + + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_auth, + mcp_auth_header=None, + mcp_servers=["server_a"], + mcp_server_auth_headers=None, + log_list_tools_to_spendlogs=True, + list_tools_log_source="mcp_protocol", + ) + + assert tools == [tool_1] + dummy_logging_obj.async_success_handler.assert_awaited_once() + assert dummy_logging_obj.async_success_handler.await_args.kwargs["result"] == [tool_1] + + spend_meta = dummy_logging_obj.model_call_details["metadata"]["spend_logs_metadata"] + assert spend_meta["tool_count_total"] == 1 + assert spend_meta["allowed_server_count"] == 1 + assert spend_meta["per_server_tool_counts"]["server_a"] == 1 diff --git a/tests/test_litellm/responses/mcp/test_litellm_proxy_mcp_handler.py b/tests/test_litellm/responses/mcp/test_litellm_proxy_mcp_handler.py index b632e72f567..15fdc7bd0c4 100644 --- a/tests/test_litellm/responses/mcp/test_litellm_proxy_mcp_handler.py +++ b/tests/test_litellm/responses/mcp/test_litellm_proxy_mcp_handler.py @@ -1,12 +1,15 @@ import sys import types -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest +from fastapi import HTTPException +import importlib from litellm.responses.mcp.litellm_proxy_mcp_handler import ( LiteLLM_Proxy_MCP_Handler, ) +from typing import Any, cast from litellm.types.utils import ModelResponse from litellm.types.responses.main import OutputFunctionToolCall @@ -22,7 +25,9 @@ def _setup_mcp_call_environment(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_module) fake_manager = types.SimpleNamespace( - call_tool=AsyncMock(return_value=_DummyMCPResult()) + call_tool=AsyncMock(return_value=_DummyMCPResult()), + # Newer logging path calls this to enrich spend logs metadata + _get_mcp_server_from_tool_name=MagicMock(return_value=None), ) monkeypatch.setattr( "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", @@ -31,6 +36,15 @@ def _setup_mcp_call_environment(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: return fake_manager.call_tool +def _setup_proxy_logging(monkeypatch: pytest.MonkeyPatch) -> AsyncMock: + """Patch proxy_logging_obj so failure hook can be asserted.""" + proxy_logging_obj = MagicMock() + proxy_logging_obj.post_call_failure_hook = AsyncMock() + proxy_module = types.SimpleNamespace(proxy_logging_obj=proxy_logging_obj) + monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_module) + return proxy_logging_obj.post_call_failure_hook + + def test_deduplicate_mcp_tools_single_allowed_server(): tools = [{"name": "search"}, {"name": "search"}] # duplicate on purpose @@ -184,7 +198,7 @@ def test_create_follow_up_input_handles_response_function_tool_call(): ) follow_up = LiteLLM_Proxy_MCP_Handler._create_follow_up_input( - response=response, + response=cast(Any, response), tool_results=[], original_input=None, ) @@ -216,6 +230,8 @@ async def test_execute_tool_calls_strips_server_prefix(monkeypatch): user_api_key_auth=None, ) + assert call_tool_mock.await_count == 1 + assert call_tool_mock.await_args is not None assert call_tool_mock.await_args.kwargs["name"] == "read_wiki_structure" @@ -236,6 +252,8 @@ async def test_execute_tool_calls_keeps_tool_name_without_prefix(monkeypatch): user_api_key_auth=None, ) + assert call_tool_mock.await_count == 1 + assert call_tool_mock.await_args is not None assert call_tool_mock.await_args.kwargs["name"] == tool_name @@ -256,4 +274,131 @@ async def test_execute_tool_calls_keeps_tool_name_when_equal_to_server(monkeypat user_api_key_auth=None, ) + assert call_tool_mock.await_count == 1 + assert call_tool_mock.await_args is not None assert call_tool_mock.await_args.kwargs["name"] == tool_name + + +@pytest.mark.asyncio +async def test_execute_tool_calls_logs_failure_via_post_call_failure_hook(monkeypatch): + """ + Regression test for ae4d92ad...: + Ensure responses-side MCP tool execution logs failures via proxy_logging_obj.post_call_failure_hook. + """ + post_call_failure_hook = _setup_proxy_logging(monkeypatch) + + fake_manager = types.SimpleNamespace( + call_tool=AsyncMock( + side_effect=HTTPException(status_code=500, detail="boom") + ) + ) + monkeypatch.setattr( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + fake_manager, + ) + + tool_name = "deepwiki-read_wiki_structure" + tool_calls = [ + {"id": "call-err", "function": {"name": tool_name, "arguments": "{}"}} + ] + + user_auth = types.SimpleNamespace(api_key="test_key", user_id="test_user") + + results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls( + tool_server_map={tool_name: "deepwiki"}, + tool_calls=tool_calls, + user_api_key_auth=user_auth, + litellm_call_id="cid", + litellm_trace_id="tid", + ) + + assert len(results) == 1 + assert results[0]["tool_call_id"] == "call-err" + assert results[0]["name"] == tool_name + + post_call_failure_hook.assert_awaited_once() + assert post_call_failure_hook.await_args is not None + assert ( + post_call_failure_hook.await_args.kwargs.get("route") + == "/responses/mcp/call_tool" + ) + + +@pytest.mark.asyncio +async def test_execute_tool_calls_passes_litellm_call_id_and_trace_id_to_function_setup( + monkeypatch, +): + """ + Regression test for ae4d92ad...: + Ensure litellm_call_id / litellm_trace_id are forwarded into function_setup kwargs. + """ + _setup_proxy_logging(monkeypatch) + call_tool_mock = _setup_mcp_call_environment(monkeypatch) + + captured = {} + + def fake_function_setup(*_args, **kwargs): + captured.update(kwargs) + return None, None + + # NOTE: Don't patch via dotted string path here because `litellm.responses` + # is a function attribute on the `litellm` package (shadowing the submodule), + # which breaks monkeypatch's importpath resolution. + handler_module = importlib.import_module( + "litellm.responses.mcp.litellm_proxy_mcp_handler" + ) + monkeypatch.setattr(handler_module, "function_setup", fake_function_setup) + + tool_name = "deepwiki-read_wiki_structure" + tool_calls = [ + {"id": "call-1", "function": {"name": tool_name, "arguments": "{}"}} + ] + + await LiteLLM_Proxy_MCP_Handler._execute_tool_calls( + tool_server_map={tool_name: "deepwiki"}, + tool_calls=tool_calls, + user_api_key_auth=None, + litellm_call_id="cid", + litellm_trace_id="tid", + ) + + # Ensure the tool call was attempted (sanity) + assert call_tool_mock.await_count == 1 + + assert captured.get("litellm_call_id") == "cid" + assert captured.get("litellm_trace_id") == "tid" + + +@pytest.mark.asyncio +async def test_get_mcp_tools_from_manager_enables_list_tools_logging(monkeypatch): + """ + Regression test for 872e5b98...: + Ensure responses-side tool discovery enables list-tools SpendLogs logging flags. + """ + mock_get_tools = AsyncMock(return_value=[]) + monkeypatch.setattr( + "litellm.proxy._experimental.mcp_server.server._get_tools_from_mcp_servers", + mock_get_tools, + ) + + # Patch manager methods used by _get_mcp_tools_from_manager to avoid needing full UserAPIKeyAuth fields. + fake_manager = types.SimpleNamespace( + get_allowed_mcp_servers=AsyncMock(return_value=[]), + get_mcp_servers_from_ids=MagicMock(return_value=[]), + ) + monkeypatch.setattr( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + fake_manager, + ) + + user_auth = types.SimpleNamespace(api_key="test_key", user_id="test_user") + tools, _server_names = await LiteLLM_Proxy_MCP_Handler._get_mcp_tools_from_manager( + user_api_key_auth=user_auth, + mcp_tools_with_litellm_proxy=[{"type": "mcp", "server_url": "litellm_proxy/mcp/deepwiki"}], + ) + + assert tools == [] + assert mock_get_tools.await_count == 1 + assert mock_get_tools.await_args is not None + assert mock_get_tools.await_args.kwargs["log_list_tools_to_spendlogs"] is True + assert mock_get_tools.await_args.kwargs["list_tools_log_source"] == "responses" From 72cbc295e597621b08bcfb7c1a89cd721d5e7685 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 14:54:50 +0900 Subject: [PATCH 5/7] chore: format --- .../proxy/_experimental/mcp_server/server.py | 16 ++-- litellm/responses/main.py | 75 ++++++++++--------- .../mcp/litellm_proxy_mcp_handler.py | 24 +++--- .../responses/mcp/mcp_streaming_iterator.py | 67 ++++++++++------- litellm/types/utils.py | 10 +-- 5 files changed, 110 insertions(+), 82 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 52117d86066..4e5c73be806 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -807,9 +807,9 @@ async def _get_tools_from_mcp_servers( "user_api_key" ] = user_api_key - user_identifier = getattr(user_api_key_auth, "end_user_id", None) or getattr( - user_api_key_auth, "user_id", None - ) + user_identifier = getattr( + user_api_key_auth, "end_user_id", None + ) or getattr(user_api_key_auth, "user_id", None) if user_identifier: list_tools_request_data["user"] = user_identifier @@ -838,7 +838,9 @@ async def _get_tools_from_mcp_servers( # Decide whether to add prefix based on number of allowed servers add_prefix = not (len(allowed_mcp_servers) == 1) - async def _fetch_and_filter_server_tools(server: MCPServer) -> List[MCPTool]: + async def _fetch_and_filter_server_tools( + server: MCPServer, + ) -> List[MCPTool]: """Fetch and filter tools from a single server with error handling.""" if server is None: return [] @@ -1517,11 +1519,9 @@ async def call_mcp_tool( **kwargs, ) except Exception as e: - traceback_str = traceback.format_exc( - limit=MAXIMUM_TRACEBACK_LINES_TO_LOG - ) + traceback_str = traceback.format_exc(limit=MAXIMUM_TRACEBACK_LINES_TO_LOG) from litellm.proxy.proxy_server import proxy_logging_obj - + if proxy_logging_obj and user_api_key_auth: await proxy_logging_obj.post_call_failure_hook( request_data=kwargs, diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 78d358a1e39..83c23a58500 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -169,7 +169,9 @@ async def aresponses_api_with_mcp( # Process MCP tools through the complete pipeline (fetch + filter + deduplicate + transform) # Extract user_api_key_auth from litellm_metadata (where it's added by add_user_api_key_auth_to_request_metadata) - user_api_key_auth = kwargs.get("user_api_key_auth") or kwargs.get("litellm_metadata", {}).get("user_api_key_auth") + user_api_key_auth = kwargs.get("user_api_key_auth") or kwargs.get( + "litellm_metadata", {} + ).get("user_api_key_auth") # Get original MCP tools (for events) and OpenAI tools (for LLM) by reusing existing methods ( @@ -280,7 +282,7 @@ async def aresponses_api_with_mcp( user_api_key_auth = kwargs.get("litellm_metadata", {}).get( "user_api_key_auth" ) - + # Extract MCP auth headers from the request to pass to MCP server secret_fields: Optional[Dict[str, Any]] = kwargs.get("secret_fields") ( @@ -292,7 +294,7 @@ async def aresponses_api_with_mcp( secret_fields=secret_fields, tools=tools, ) - + tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls( tool_server_map=tool_server_map, tool_calls=tool_calls, @@ -590,9 +592,12 @@ def responses( ######################################################### # Update input with provider-specific file IDs if managed files are used ######################################################### - input = cast(Union[str, ResponseInputParam], update_responses_input_with_model_file_ids(input=input)) + input = cast( + Union[str, ResponseInputParam], + update_responses_input_with_model_file_ids(input=input), + ) local_vars["input"] = input - + ######################################################### # Native MCP Responses API ######################################################### @@ -627,11 +632,11 @@ def responses( ) # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), ) local_vars.update(kwargs) @@ -826,11 +831,11 @@ def delete_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: @@ -1006,11 +1011,11 @@ def get_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: @@ -1163,11 +1168,11 @@ def list_input_items( if custom_llm_provider is None: raise ValueError("custom_llm_provider is required but passed as None") - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: @@ -1321,11 +1326,11 @@ def cancel_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=None, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: @@ -1503,11 +1508,11 @@ def compact_responses( raise ValueError("custom_llm_provider is required but passed as None") # get provider config - responses_api_provider_config: Optional[BaseResponsesAPIConfig] = ( - ProviderConfigManager.get_provider_responses_api_config( - model=model, - provider=litellm.LlmProviders(custom_llm_provider), - ) + responses_api_provider_config: Optional[ + BaseResponsesAPIConfig + ] = ProviderConfigManager.get_provider_responses_api_config( + model=model, + provider=litellm.LlmProviders(custom_llm_provider), ) if responses_api_provider_config is None: diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index e22898ae0ff..7cc821f5e02 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -19,7 +19,12 @@ from litellm.responses.main import aresponses from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator from litellm.types.llms.openai import ResponsesAPIResponse -from litellm.types.utils import CallTypes, Choices, ModelResponse, StandardLoggingMCPToolCall +from litellm.types.utils import ( + CallTypes, + Choices, + ModelResponse, + StandardLoggingMCPToolCall, +) from litellm.utils import Rules, function_setup if TYPE_CHECKING: @@ -564,9 +569,9 @@ async def _execute_tool_calls( if user_api_key: logging_request_data["metadata"]["user_api_key"] = user_api_key - user_identifier = getattr(user_api_key_auth, "end_user_id", None) or getattr( - user_api_key_auth, "user_id", None - ) + user_identifier = getattr( + user_api_key_auth, "end_user_id", None + ) or getattr(user_api_key_auth, "user_id", None) if user_identifier: logging_request_data["user"] = user_identifier @@ -620,7 +625,9 @@ async def _execute_tool_calls( standard_logging_mcp_tool_call["mcp_server_logo_url"] = logo_url cost_info = mcp_info.get("mcp_server_cost_info") if cost_info: - standard_logging_mcp_tool_call["mcp_server_cost_info"] = cost_info + standard_logging_mcp_tool_call[ + "mcp_server_cost_info" + ] = cost_info if litellm_logging_obj: litellm_logging_obj.model_call_details[ @@ -895,9 +902,7 @@ async def _log_mcp_tool_failure( return try: - traceback_str = traceback.format_exc( - limit=MAXIMUM_TRACEBACK_LINES_TO_LOG - ) + traceback_str = traceback.format_exc(limit=MAXIMUM_TRACEBACK_LINES_TO_LOG) await proxy_logging_obj.post_call_failure_hook( request_data=request_data, original_exception=error, @@ -948,7 +953,8 @@ def _create_mcp_streaming_response( mcp_events=mcp_discovery_events, # Pre-generated MCP discovery events tool_server_map=tool_server_map, mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy, - user_api_key_auth=kwargs.get("user_api_key_auth") or kwargs.get("litellm_metadata", {}).get("user_api_key_auth"), + user_api_key_auth=kwargs.get("user_api_key_auth") + or kwargs.get("litellm_metadata", {}).get("user_api_key_auth"), original_request_params=request_params, ) diff --git a/litellm/responses/mcp/mcp_streaming_iterator.py b/litellm/responses/mcp/mcp_streaming_iterator.py index 53f39164e91..731aa5c692b 100644 --- a/litellm/responses/mcp/mcp_streaming_iterator.py +++ b/litellm/responses/mcp/mcp_streaming_iterator.py @@ -273,9 +273,9 @@ def __init__( self.finished = False # Event queues and generation flags - self.mcp_discovery_events: List[ResponsesAPIStreamingResponse] = ( - mcp_events # Pre-generated MCP discovery events - ) + self.mcp_discovery_events: List[ + ResponsesAPIStreamingResponse + ] = mcp_events # Pre-generated MCP discovery events self.tool_execution_events: List[ResponsesAPIStreamingResponse] = [] self.mcp_discovery_generated = True # Events are already generated self.mcp_events = ( @@ -284,9 +284,9 @@ def __init__( self.tool_server_map = tool_server_map # Iterator references - self.base_iterator: Optional[Union[Any, ResponsesAPIResponse]] = ( - base_iterator # Will be created when needed - ) + self.base_iterator: Optional[ + Union[Any, ResponsesAPIResponse] + ] = base_iterator # Will be created when needed self.follow_up_iterator: Optional[Any] = None # Response collection for tool execution @@ -305,7 +305,7 @@ def __init__( # Mark as async iterator self.is_async = True - + def _extract_mcp_headers_from_params(self) -> None: """Extract MCP headers from original request params to pass to tool calls""" from typing import Dict, Optional @@ -313,25 +313,31 @@ def _extract_mcp_headers_from_params(self) -> None: from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( MCPRequestHandler, ) - + # Extract headers from secret_fields in original_request_params raw_headers_from_request: Optional[Dict[str, str]] = None secret_fields = self.original_request_params.get("secret_fields") if secret_fields and isinstance(secret_fields, dict): raw_headers_from_request = secret_fields.get("raw_headers") - + # Extract MCP-specific headers self.mcp_auth_header: Optional[str] = None self.mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None self.oauth2_headers: Optional[Dict[str, str]] = None self.raw_headers: Optional[Dict[str, str]] = raw_headers_from_request - + if raw_headers_from_request: headers_obj = Headers(raw_headers_from_request) - self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(headers_obj) - self.mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj) - self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers_obj) - + self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers( + headers_obj + ) + self.mcp_server_auth_headers = ( + MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj) + ) + self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers( + headers_obj + ) + # Also check if headers are provided in tools array (from request body) tools = self.original_request_params.get("tools") if tools: @@ -341,17 +347,26 @@ def _extract_mcp_headers_from_params(self) -> None: if tool_headers and isinstance(tool_headers, dict): # Merge tool headers into mcp_server_auth_headers headers_obj_from_tool = Headers(tool_headers) - tool_mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj_from_tool) - + tool_mcp_server_auth_headers = ( + MCPRequestHandler._get_mcp_server_auth_headers_from_headers( + headers_obj_from_tool + ) + ) + if tool_mcp_server_auth_headers: if self.mcp_server_auth_headers is None: self.mcp_server_auth_headers = {} # Merge the headers from tool into existing headers - for server_alias, headers_dict in tool_mcp_server_auth_headers.items(): + for ( + server_alias, + headers_dict, + ) in tool_mcp_server_auth_headers.items(): if server_alias not in self.mcp_server_auth_headers: self.mcp_server_auth_headers[server_alias] = {} - self.mcp_server_auth_headers[server_alias].update(headers_dict) - + self.mcp_server_auth_headers[server_alias].update( + headers_dict + ) + # Also merge raw headers if self.raw_headers is None: self.raw_headers = {} @@ -489,9 +504,9 @@ async def _create_initial_response_iterator(self) -> None: # Use the pre-fetched all_tools from original_request_params (no re-processing needed) params_for_llm = {} for key, value in params.items(): - params_for_llm[key] = ( - value # Copy all params as-is since tools are already processed - ) + params_for_llm[ + key + ] = value # Copy all params as-is since tools are already processed tools_count = ( len(params_for_llm.get("tools", [])) @@ -545,9 +560,11 @@ async def _generate_tool_execution_events(self) -> None: return for tool_call in tool_calls: - tool_name, tool_arguments, tool_call_id = ( - LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call) - ) + ( + tool_name, + tool_arguments, + tool_call_id, + ) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call) if tool_name and tool_call_id: # Create MCP call events for this tool execution call_events = create_mcp_call_events( diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 324380db2b4..cac2fe85541 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -63,11 +63,11 @@ def _generate_id(): # private helper function return "chatcmpl-" + str(uuid.uuid4()) - class SafeAttributeModel: """ A base model that provides safe attribute access. """ + def __delattr__(self, name): try: super().__delattr__(name) @@ -125,13 +125,14 @@ class SearchContextCostPerQuery(TypedDict, total=False): class AgenticLoopParams(TypedDict, total=False): """ Parameters passed to agentic loop hooks (e.g., WebSearch interception). - + Stored in logging_obj.model_call_details["agentic_loop_params"] to provide agentic hooks with the original request context needed for follow-up calls. """ + model: str """The model string with provider prefix (e.g., 'bedrock/invoke/...')""" - + custom_llm_provider: str """The LLM provider name (e.g., 'bedrock', 'anthropic')""" @@ -1345,8 +1346,7 @@ class CacheCreationTokenDetails(BaseModel): class PromptTokensDetailsWrapper( - SafeAttributeModel, - PromptTokensDetails + SafeAttributeModel, PromptTokensDetails ): # extends with image generation fields (text_tokens, image_tokens) text_tokens: Optional[int] = None """Text tokens sent to the model.""" From 165081059a5405d849f1daf01b976237f43a21bf Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 14:57:14 +0900 Subject: [PATCH 6/7] chore: ruff fix --- litellm/proxy/_experimental/mcp_server/server.py | 2 +- litellm/responses/mcp/litellm_proxy_mcp_handler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 4e5c73be806..03652ae155e 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -738,7 +738,7 @@ def _prepare_mcp_server_headers( return server_auth_header, extra_headers - async def _get_tools_from_mcp_servers( + async def _get_tools_from_mcp_servers( # noqa: PLR0915 user_api_key_auth: Optional[UserAPIKeyAuth], mcp_auth_header: Optional[str], mcp_servers: Optional[List[str]], diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 7cc821f5e02..2419126fe2f 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -479,7 +479,7 @@ def _parse_mcp_result(result: Any) -> str: return result_text or "Tool executed successfully" @staticmethod - async def _execute_tool_calls( + async def _execute_tool_calls( # noqa: PLR0915 tool_server_map: dict[str, str], tool_calls: List[Any], user_api_key_auth: Any, From a312ef31e39af511688a4d49101edacb519e14d6 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Wed, 21 Jan 2026 15:04:24 +0900 Subject: [PATCH 7/7] chore: fix mypy --- litellm/responses/mcp/litellm_proxy_mcp_handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 2419126fe2f..4376e076a95 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -35,7 +35,9 @@ # NOTE: We intentionally keep ToolParam as a broad type here to avoid tight coupling # to optional OpenAI SDK typing symbols in environments that may not have them available. -ToolParam = Dict[str, Any] +# `Any` is used to keep mypy compatible with the broader OpenAI tool union types +# passed around in Responses API while still allowing dict-style access at runtime. +ToolParam = Any LITELLM_PROXY_MCP_SERVER_URL = "litellm_proxy" LITELLM_PROXY_MCP_SERVER_URL_PREFIX = f"{LITELLM_PROXY_MCP_SERVER_URL}/mcp/"