diff --git a/docs/my-website/docs/proxy/call_hooks.md b/docs/my-website/docs/proxy/call_hooks.md index 17354725fd5..d0a6aa779dc 100644 --- a/docs/my-website/docs/proxy/call_hooks.md +++ b/docs/my-website/docs/proxy/call_hooks.md @@ -413,9 +413,6 @@ from litellm.proxy.proxy_server import UserAPIKeyAuth from typing import Any, Dict, Optional class CustomHeaderLogger(CustomLogger): - def __init__(self): - super().__init__() - async def async_post_call_response_headers_hook( self, data: dict, @@ -425,8 +422,25 @@ class CustomHeaderLogger(CustomLogger): ) -> Optional[Dict[str, str]]: """ Inject custom headers into all responses (success and failure). + Works for /chat/completions, /embeddings, and /responses. + + Use request_headers to echo incoming headers (e.g., API gateway request IDs). """ - return {"x-custom-header": "custom-value"} + headers = {"x-custom-header": "custom-value"} + + # Echo an incoming gateway request ID into the response + if request_headers: + gateway_id = request_headers.get("x-gateway-request-id") + if gateway_id: + headers["x-gateway-request-id"] = gateway_id + + return headers proxy_handler_instance = CustomHeaderLogger() ``` + +:::tip +This hook works for **all proxy endpoints**: `/chat/completions`, `/embeddings`, `/responses` (streaming and non-streaming), and failure responses. + +The `request_headers` parameter contains the original HTTP request headers, allowing you to echo incoming headers (e.g., API gateway request IDs) into the response. +::: diff --git a/litellm/proxy/anthropic_endpoints/endpoints.py b/litellm/proxy/anthropic_endpoints/endpoints.py index 77bb1f53e62..3d6c58a33ee 100644 --- a/litellm/proxy/anthropic_endpoints/endpoints.py +++ b/litellm/proxy/anthropic_endpoints/endpoints.py @@ -72,10 +72,11 @@ async def anthropic_response( # noqa: PLR0915 except ModifyResponseException as e: # Guardrail flagged content in passthrough mode - return 200 with violation message _data = e.request_data - await proxy_logging_obj.post_call_failure_hook( + await base_llm_response_processor._handle_modify_response_exception( + e=e, user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=_data, + proxy_logging_obj=proxy_logging_obj, + fastapi_response=fastapi_response, ) # Create Anthropic-formatted response with violation message @@ -110,7 +111,7 @@ async def _passthrough_stream_generator(): return await create_response( generator=selected_data_generator, media_type="text/event-stream", - headers={}, + headers=dict(fastapi_response.headers), ) return _anthropic_response diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 7dfa3bb239f..d1b36c07c6f 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -8,6 +8,7 @@ Any, AsyncGenerator, Callable, + Dict, Literal, Optional, Tuple, @@ -49,6 +50,7 @@ StreamErrorSerializer = Callable[[ProxyException], str] if TYPE_CHECKING: + from litellm.integrations.custom_guardrail import ModifyResponseException from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig ProxyConfig = _ProxyConfig @@ -351,8 +353,23 @@ def _get_cost_breakdown_from_logging_obj( class ProxyBaseLLMRequestProcessing: + # Headers excluded from request_headers passed to callbacks to avoid leaking credentials + _SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "proxy-authorization"}) + def __init__(self, data: dict): self.data = data + self._request_headers: Optional[Dict[str, str]] = None + + @staticmethod + def _filter_sensitive_headers( + headers: "starlette.datastructures.Headers", + ) -> dict: + """Return a copy of request headers with sensitive values removed.""" + return { + k: v + for k, v in headers.items() + if k.lower() not in ProxyBaseLLMRequestProcessing._SENSITIVE_HEADERS + } @staticmethod def get_custom_headers( @@ -749,6 +766,8 @@ async def base_process_llm_request( """ Common request processing logic for both chat completions and responses API endpoints """ + self._request_headers = self._filter_sensitive_headers(request.headers) + requested_model_from_client: Optional[str] = ( self.data.get("model") if isinstance(self.data.get("model"), str) else None ) @@ -859,6 +878,7 @@ async def base_process_llm_request( data=self.data, user_api_key_dict=user_api_key_dict, response=response, + request_headers=self._request_headers, ) if callback_headers: custom_headers.update(callback_headers) @@ -967,6 +987,7 @@ async def base_process_llm_request( data=self.data, user_api_key_dict=user_api_key_dict, response=response, + request_headers=self._request_headers, ) if callback_headers: fastapi_response.headers.update(callback_headers) @@ -1065,6 +1086,33 @@ def _is_streaming_request( return True return False + async def _handle_modify_response_exception( + self, + e: "ModifyResponseException", + user_api_key_dict: UserAPIKeyAuth, + proxy_logging_obj: ProxyLogging, + fastapi_response: Response, + ): + """Centralized handling for ModifyResponseException (guardrail passthrough). + + Calls the failure hook and injects custom response headers — mirrors + the pattern in ``_handle_llm_api_exception`` for error responses. + """ + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=e.request_data, + ) + + callback_headers = await proxy_logging_obj.post_call_response_headers_hook( + data=e.request_data, + user_api_key_dict=user_api_key_dict, + response=None, + request_headers=self._request_headers, + ) + if callback_headers: + fastapi_response.headers.update(callback_headers) + async def _handle_llm_api_exception( self, e: Exception, @@ -1135,6 +1183,7 @@ async def _handle_llm_api_exception( data=self.data, user_api_key_dict=user_api_key_dict, response=None, + request_headers=self._request_headers, ) if callback_headers: headers.update(callback_headers) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7d1067f7b05..9f7156273a6 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -6197,11 +6197,13 @@ async def chat_completion( # noqa: PLR0915 except ModifyResponseException as e: # Guardrail flagged content in passthrough mode - return 200 with violation message _data = e.request_data - await proxy_logging_obj.post_call_failure_hook( + await base_llm_response_processor._handle_modify_response_exception( + e=e, user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=_data, + proxy_logging_obj=proxy_logging_obj, + fastapi_response=fastapi_response, ) + _chat_response = litellm.ModelResponse() _chat_response.model = e.model # type: ignore _chat_response.choices[0].message.content = e.message # type: ignore @@ -6227,6 +6229,7 @@ async def chat_completion( # noqa: PLR0915 selected_data_generator, media_type="text/event-stream", status_code=200, # Return 200 for passthrough mode + headers=dict(fastapi_response.headers), ) _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0) _chat_response.usage = _usage # type: ignore @@ -6361,10 +6364,11 @@ async def completion( # noqa: PLR0915 except ModifyResponseException as e: # Guardrail flagged content in passthrough mode - return 200 with violation message _data = e.request_data - await proxy_logging_obj.post_call_failure_hook( + await base_llm_response_processor._handle_modify_response_exception( + e=e, user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=_data, + proxy_logging_obj=proxy_logging_obj, + fastapi_response=fastapi_response, ) if _data.get("stream", None) is not None and _data["stream"] is True: @@ -6397,6 +6401,7 @@ async def completion( # noqa: PLR0915 selected_data_generator, media_type="text/event-stream", status_code=200, # Return 200 for passthrough mode + headers=dict(fastapi_response.headers), ) else: _response = litellm.TextCompletionResponse() diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index 44e8c42b2c1..05f187cdf7e 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -219,11 +219,11 @@ async def responses_api( return response except ModifyResponseException as e: # Guardrail passthrough: return violation message in Responses API format (200) - _data = e.request_data - await proxy_logging_obj.post_call_failure_hook( + await processor._handle_modify_response_exception( + e=e, user_api_key_dict=user_api_key_dict, - original_exception=e, - request_data=_data, + proxy_logging_obj=proxy_logging_obj, + fastapi_response=fastapi_response, ) violation_text = e.message diff --git a/tests/e2e_demo_response_headers_callback.py b/tests/e2e_demo_response_headers_callback.py new file mode 100644 index 00000000000..aabe0008a99 --- /dev/null +++ b/tests/e2e_demo_response_headers_callback.py @@ -0,0 +1,90 @@ +""" +Demo CustomLogger that injects custom response headers. + +Shows how to: +1. Echo an incoming request header (e.g., APIGEE request ID) into the response +2. Inject headers on both success and failure paths +3. Works for /chat/completions, /embeddings, and /responses + +Usage: + litellm --config tests/e2e_demo_response_headers_config.yaml + +Test commands: + # /chat/completions (non-streaming) + curl -s -D- http://localhost:4000/chat/completions \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-001" \ + -d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}' + + # /chat/completions (streaming) + curl -s -D- http://localhost:4000/chat/completions \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-002" \ + -d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream":true}' + + # /embeddings + curl -s -D- http://localhost:4000/embeddings \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-003" \ + -d '{"model":"text-embedding-3-small","input":"hello"}' + + # /v1/responses (non-streaming) + curl -s -D- http://localhost:4000/v1/responses \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-004" \ + -d '{"model":"gpt-4o-mini","input":"hi"}' + + # /v1/responses (streaming) + curl -s -D- http://localhost:4000/v1/responses \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-005" \ + -d '{"model":"gpt-4o-mini","input":"hi","stream":true}' + + # Failure path (bad model → headers still injected) + curl -s -D- http://localhost:4000/chat/completions \ + -H "Authorization: Bearer sk-1234" \ + -H "x-apigee-request-id: apigee-req-006" \ + -d '{"model":"nonexistent-model","messages":[{"role":"user","content":"hi"}]}' + +Expected: All responses contain x-apigee-request-id, x-custom-header, and x-litellm-hook-model. +""" + +from typing import Any, Dict, Optional + +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class ResponseHeaderInjector(CustomLogger): + """ + Demonstrates injecting custom HTTP response headers via the proxy hook. + + Key features: + - Echoes the incoming x-apigee-request-id header back in the response + - Adds a static custom header and the model name + - Works for success (streaming + non-streaming) and failure responses + - Works for all endpoints: /chat/completions, /embeddings, /responses + """ + + async def async_post_call_response_headers_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Any, + request_headers: Optional[Dict[str, str]] = None, + ) -> Optional[Dict[str, str]]: + headers: Dict[str, str] = { + "x-custom-header": "hello-from-hook", + "x-litellm-hook-model": data.get("model", "unknown"), + } + + # Echo the APIGEE request ID from the incoming request into the response + if request_headers: + apigee_id = request_headers.get("x-apigee-request-id") + if apigee_id: + headers["x-apigee-request-id"] = apigee_id + + return headers + + +response_header_injector = ResponseHeaderInjector() diff --git a/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py b/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py index 6a12366fdd3..948fa6fcbe8 100644 --- a/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py +++ b/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py @@ -195,3 +195,66 @@ async def test_default_hook_returns_none(): response=None, ) assert result is None + + +@pytest.mark.asyncio +async def test_response_headers_hook_receives_request_headers(): + """Test that the hook receives request_headers when provided.""" + injector = HeaderInjectorLogger(headers={"x-echoed": "yes"}) + mock_request_headers = {"x-apigee-request-id": "req-abc-123", "authorization": "Bearer sk-xxx"} + + with patch("litellm.callbacks", [injector]): + from litellm.proxy.utils import ProxyLogging + from litellm.caching.caching import DualCache + + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + + result = await proxy_logging.post_call_response_headers_hook( + data={"model": "test-model"}, + user_api_key_dict=UserAPIKeyAuth(api_key="test-key"), + response={"id": "resp-1"}, + request_headers=mock_request_headers, + ) + + assert injector.called is True + assert result == {"x-echoed": "yes"} + + +@pytest.mark.asyncio +async def test_response_headers_hook_request_headers_passed_to_callback(): + """Test that request_headers are forwarded to the callback and can be used to echo incoming headers.""" + + class RequestHeaderAwareLogger(CustomLogger): + def __init__(self): + self.received_request_headers = None + + async def async_post_call_response_headers_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Any, + request_headers: Optional[Dict[str, str]] = None, + ) -> Optional[Dict[str, str]]: + self.received_request_headers = request_headers + if request_headers and "x-apigee-request-id" in request_headers: + return {"x-apigee-request-id": request_headers["x-apigee-request-id"]} + return None + + logger = RequestHeaderAwareLogger() + mock_request_headers = {"x-apigee-request-id": "apigee-123"} + + with patch("litellm.callbacks", [logger]): + from litellm.proxy.utils import ProxyLogging + from litellm.caching.caching import DualCache + + proxy_logging = ProxyLogging(user_api_key_cache=DualCache()) + + result = await proxy_logging.post_call_response_headers_hook( + data={"model": "test-model"}, + user_api_key_dict=UserAPIKeyAuth(api_key="test-key"), + response=None, + request_headers=mock_request_headers, + ) + + assert logger.received_request_headers == mock_request_headers + assert result == {"x-apigee-request-id": "apigee-123"} diff --git a/tests/test_litellm/proxy/response_api_endpoints/test_response_headers_on_guardrail_exception.py b/tests/test_litellm/proxy/response_api_endpoints/test_response_headers_on_guardrail_exception.py new file mode 100644 index 00000000000..9f8fe4142a1 --- /dev/null +++ b/tests/test_litellm/proxy/response_api_endpoints/test_response_headers_on_guardrail_exception.py @@ -0,0 +1,74 @@ +""" +Test that _handle_modify_response_exception (centralized in ProxyBaseLLMRequestProcessing) +is called on ModifyResponseException in the /responses endpoint, so custom headers appear +even on guardrail failures. +""" + +import os +import sys +import pytest +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock, patch + +sys.path.insert(0, os.path.abspath("../../../..")) + +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class GuardrailHeaderLogger(CustomLogger): + """Logger that injects headers — used to verify hook fires on guardrail path.""" + + async def async_post_call_response_headers_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Any, + request_headers: Optional[Dict[str, str]] = None, + ) -> Optional[Dict[str, str]]: + return {"x-guardrail-header": "injected"} + + +@pytest.mark.asyncio +async def test_modify_response_exception_calls_response_headers_hook(): + """ + When a guardrail raises ModifyResponseException on /responses, + the response should still include custom headers from the hook + via the centralized _handle_modify_response_exception method. + """ + from litellm.integrations.custom_guardrail import ModifyResponseException + from litellm.proxy.proxy_server import app + from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing + from fastapi.testclient import TestClient + + guardrail_logger = GuardrailHeaderLogger() + + with patch("litellm.callbacks", [guardrail_logger]): + with patch("litellm.proxy.proxy_server.user_api_key_auth") as mock_auth: + mock_auth.return_value = MagicMock( + token="test_token", + user_id="test_user", + team_id=None, + ) + + # Only mock base_process_llm_request so the real + # _handle_modify_response_exception runs and calls the hook. + with patch.object( + ProxyBaseLLMRequestProcessing, + "base_process_llm_request", + new_callable=AsyncMock, + side_effect=ModifyResponseException( + message="Content blocked by guardrail", + model="gpt-4o", + request_data={"model": "gpt-4o"}, + ), + ): + client = TestClient(app) + response = client.post( + "/v1/responses", + json={"model": "gpt-4o", "input": "blocked content"}, + headers={"Authorization": "Bearer sk-1234"}, + ) + + assert response.status_code == 200 + assert response.headers.get("x-guardrail-header") == "injected"