Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1908,7 +1908,14 @@ async def pre_call_tool_check(
user_api_key_auth: Optional[UserAPIKeyAuth],
proxy_logging_obj: ProxyLogging,
server: MCPServer,
):
) -> Dict[str, Any]:
"""
Run pre-call checks and guardrail hooks for an MCP tool call.

Returns a dict that may contain:
- "arguments": hook-modified tool arguments (only if changed)
- "extra_headers": headers injected by pre_mcp_call guardrail hooks
"""
## check if the tool is allowed or banned for the given server
if not self.check_allowed_or_banned_tools(name, server):
raise HTTPException(
Expand Down Expand Up @@ -1969,6 +1976,7 @@ async def pre_call_tool_check(
mcp_request_obj, pre_hook_kwargs
)

hook_result: Dict[str, Any] = {}
try:
# Use standard pre_call_hook
modified_data = await proxy_logging_obj.pre_call_hook(
Expand All @@ -1984,7 +1992,9 @@ async def pre_call_tool_check(
)
)
if modified_kwargs.get("arguments") != arguments:
arguments = modified_kwargs["arguments"]
hook_result["arguments"] = modified_kwargs["arguments"]
if modified_kwargs.get("extra_headers"):
hook_result["extra_headers"] = modified_kwargs["extra_headers"]

except (
BlockedPiiEntityError,
Expand All @@ -1995,6 +2005,8 @@ async def pre_call_tool_check(
verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}")
raise e

return hook_result

def _create_during_hook_task(
self,
name: str,
Expand Down Expand Up @@ -2047,6 +2059,7 @@ async def _call_regular_mcp_tool(
raw_headers: Optional[Dict[str, str]],
proxy_logging_obj: Optional[ProxyLogging],
host_progress_callback: Optional[Callable] = None,
hook_extra_headers: Optional[Dict[str, str]] = None,
) -> CallToolResult:
"""
Call a regular MCP tool using the MCP client.
Expand All @@ -2061,6 +2074,9 @@ async def _call_regular_mcp_tool(
oauth2_headers: Optional OAuth2 headers
raw_headers: Optional raw headers from the request
proxy_logging_obj: Optional ProxyLogging object for hook integration
host_progress_callback: Optional callback for progress updates
hook_extra_headers: Optional headers injected by pre_mcp_call guardrail
hooks. Merged last (highest priority) into outbound request headers.

Returns:
CallToolResult from the MCP server
Expand Down Expand Up @@ -2116,6 +2132,11 @@ async def _call_regular_mcp_tool(
extra_headers = {}
extra_headers.update(mcp_server.static_headers)

if hook_extra_headers:
if extra_headers is None:
extra_headers = {}
extra_headers.update(hook_extra_headers)

stdio_env = self._build_stdio_env(mcp_server, raw_headers)

client = await self._create_mcp_client(
Expand Down Expand Up @@ -2201,15 +2222,33 @@ async def call_tool(
# Allow validation and modification of tool calls before execution
# Using standard pre_call_hook
#########################################################
hook_result: Dict[str, Any] = {}
if proxy_logging_obj:
await self.pre_call_tool_check(
hook_result = await self.pre_call_tool_check(
name=name,
arguments=arguments,
server_name=server_name,
user_api_key_auth=user_api_key_auth,
proxy_logging_obj=proxy_logging_obj,
server=mcp_server,
)
if "arguments" in hook_result:
arguments = hook_result["arguments"]

# OpenAPI-backed servers cannot forward hook-injected headers — reject early
# before scheduling any background tasks to avoid orphaned asyncio.Tasks.
if mcp_server.spec_path and hook_result.get("extra_headers"):
raise HTTPException(
status_code=400,
detail={
"error": (
"pre_mcp_call hook returned extra_headers for an "
"OpenAPI-backed MCP server, which does not support "
"hook header injection. Use a regular MCP server "
"(SSE/HTTP transport) for hook header support."
)
},
)

# Prepare tasks for during hooks
tasks = []
Expand Down Expand Up @@ -2247,6 +2286,7 @@ async def call_tool(
raw_headers=raw_headers,
proxy_logging_obj=proxy_logging_obj,
host_progress_callback=host_progress_callback,
hook_extra_headers=hook_result.get("extra_headers"),
)

# For OpenAPI tools, await outside the client context
Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,7 @@ class UserAPIKeyAuth(
Any
] = None # Expanded created_by user when expand=user is used
end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
jwt_claims: Optional[Dict] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
4 changes: 4 additions & 0 deletions litellm/proxy/auth/user_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
if valid_token is not None:
api_key = valid_token.token or ""
valid_token.jwt_claims = jwt_claims
do_standard_jwt_auth = False
# Fall through to virtual key checks

Expand Down Expand Up @@ -729,6 +730,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
team_membership: Optional[LiteLLM_TeamMembership] = result.get(
"team_membership", None
)
jwt_claims: Optional[dict] = result.get("jwt_claims", None)

global_proxy_spend = await get_global_proxy_spend(
litellm_proxy_admin_name=litellm_proxy_admin_name,
Expand Down Expand Up @@ -757,6 +759,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
org_id=org_id,
end_user_id=end_user_id,
parent_otel_span=parent_otel_span,
jwt_claims=jwt_claims,
)

valid_token = UserAPIKeyAuth(
Expand Down Expand Up @@ -803,6 +806,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
team_metadata=(
team_object.metadata if team_object is not None else None
),
jwt_claims=jwt_claims,
)

# Check if model has zero cost - if so, skip all budget checks
Expand Down
9 changes: 7 additions & 2 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,17 +824,22 @@ def _convert_mcp_hook_response_to_kwargs(
) -> dict:
"""
Helper function to convert pre_call_hook response back to kwargs for MCP usage.

Supports:
- modified_arguments: Override tool call arguments
- extra_headers: Inject custom headers into the outbound MCP request
"""
if not response_data:
return original_kwargs

# Apply any modifications (arguments, extra headers) from the hook response
modified_kwargs = original_kwargs.copy()

# If the response contains modified arguments, apply them
if response_data.get("modified_arguments"):
modified_kwargs["arguments"] = response_data["modified_arguments"]

if response_data.get("extra_headers"):
modified_kwargs["extra_headers"] = response_data["extra_headers"]

return modified_kwargs

async def process_pre_call_hook_response(self, response, data, call_type):
Expand Down
Loading
Loading