Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions docs/my-website/docs/proxy/call_hooks.md
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,6 @@ from litellm.proxy.proxy_server import UserAPIKeyAuth
from typing import Any, Dict, Optional

class CustomHeaderLogger(CustomLogger):
def __init__(self):
super().__init__()

async def async_post_call_response_headers_hook(
self,
data: dict,
Expand All @@ -425,8 +422,25 @@ class CustomHeaderLogger(CustomLogger):
) -> Optional[Dict[str, str]]:
"""
Inject custom headers into all responses (success and failure).
Works for /chat/completions, /embeddings, and /responses.

Use request_headers to echo incoming headers (e.g., API gateway request IDs).
"""
return {"x-custom-header": "custom-value"}
headers = {"x-custom-header": "custom-value"}

# Echo an incoming gateway request ID into the response
if request_headers:
gateway_id = request_headers.get("x-gateway-request-id")
if gateway_id:
headers["x-gateway-request-id"] = gateway_id

return headers

proxy_handler_instance = CustomHeaderLogger()
```

:::tip
This hook works for **all proxy endpoints**: `/chat/completions`, `/embeddings`, `/responses` (streaming and non-streaming), and failure responses.

The `request_headers` parameter contains the incoming HTTP request headers (with sensitive headers such as `authorization`, `cookie`, and `proxy-authorization` removed for security), allowing you to echo incoming headers (e.g., API gateway request IDs) into the response.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docs say "original" but headers are filtered

The text says request_headers contains "the original HTTP request headers," but _filter_sensitive_headers strips authorization, cookie, and proxy-authorization before passing them to callbacks. Consider updating the wording to note this filtering, e.g.:

"The request_headers parameter contains the HTTP request headers (with sensitive headers like authorization and cookie removed for security)..."

This helps callback authors understand they won't have access to auth headers through this parameter.

:::
9 changes: 5 additions & 4 deletions litellm/proxy/anthropic_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ async def anthropic_response( # noqa: PLR0915
except ModifyResponseException as e:
# Guardrail flagged content in passthrough mode - return 200 with violation message
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
await base_llm_response_processor._handle_modify_response_exception(
e=e,
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
proxy_logging_obj=proxy_logging_obj,
fastapi_response=fastapi_response,
)

# Create Anthropic-formatted response with violation message
Expand Down Expand Up @@ -110,7 +111,7 @@ async def _passthrough_stream_generator():
return await create_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers={},
headers=dict(fastapi_response.headers),
)

return _anthropic_response
Expand Down
49 changes: 49 additions & 0 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
AsyncGenerator,
Callable,
Dict,
Literal,
Optional,
Tuple,
Expand Down Expand Up @@ -49,6 +50,7 @@
StreamErrorSerializer = Callable[[ProxyException], str]

if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import ModifyResponseException
from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig

ProxyConfig = _ProxyConfig
Expand Down Expand Up @@ -351,8 +353,23 @@ def _get_cost_breakdown_from_logging_obj(


class ProxyBaseLLMRequestProcessing:
# Headers excluded from request_headers passed to callbacks to avoid leaking credentials
_SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "proxy-authorization"})

def __init__(self, data: dict):
self.data = data
self._request_headers: Optional[Dict[str, str]] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing Dict import causes NameError

Dict is used in the type annotation here but is not imported from typing (lines 6-15 import TYPE_CHECKING, Any, AsyncGenerator, Callable, Literal, Optional, Tuple, Union β€” but not Dict). In Python without from __future__ import annotations, variable annotations in function bodies are evaluated at runtime. This will raise a NameError: name 'Dict' is not defined when ProxyBaseLLMRequestProcessing.__init__ is called.

Suggested change
self._request_headers: Optional[Dict[str, str]] = None
self._request_headers: Optional[dict] = None

Alternatively, add Dict to the typing imports at line 6.


@staticmethod
def _filter_sensitive_headers(
    headers: "starlette.datastructures.Headers",
) -> dict:
    """Copy *headers* into a plain dict, dropping credential-bearing entries.

    Header names are compared case-insensitively against
    ``ProxyBaseLLMRequestProcessing._SENSITIVE_HEADERS`` so callbacks never
    see ``authorization``, ``cookie``, or ``proxy-authorization`` values.
    """
    blocked = ProxyBaseLLMRequestProcessing._SENSITIVE_HEADERS
    safe_headers: dict = {}
    for name, value in headers.items():
        if name.lower() in blocked:
            continue  # never forward credentials to callbacks
        safe_headers[name] = value
    return safe_headers

@staticmethod
def get_custom_headers(
Expand Down Expand Up @@ -749,6 +766,8 @@ async def base_process_llm_request(
"""
Common request processing logic for both chat completions and responses API endpoints
"""
self._request_headers = self._filter_sensitive_headers(request.headers)

requested_model_from_client: Optional[str] = (
self.data.get("model") if isinstance(self.data.get("model"), str) else None
)
Expand Down Expand Up @@ -859,6 +878,7 @@ async def base_process_llm_request(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=response,
request_headers=self._request_headers,
)
if callback_headers:
custom_headers.update(callback_headers)
Expand Down Expand Up @@ -967,6 +987,7 @@ async def base_process_llm_request(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=response,
request_headers=self._request_headers,
)
if callback_headers:
fastapi_response.headers.update(callback_headers)
Expand Down Expand Up @@ -1065,6 +1086,33 @@ def _is_streaming_request(
return True
return False

async def _handle_modify_response_exception(
    self,
    e: "ModifyResponseException",
    user_api_key_dict: UserAPIKeyAuth,
    proxy_logging_obj: ProxyLogging,
    fastapi_response: Response,
):
    """Centralized handling for ModifyResponseException (guardrail passthrough).

    Runs the post-call failure hook for the flagged request, then asks
    callbacks for custom response headers and copies them onto
    ``fastapi_response`` — the same pattern ``_handle_llm_api_exception``
    uses for error responses.
    """
    # Let callbacks observe the guardrail "failure" even though a 200 is returned.
    await proxy_logging_obj.post_call_failure_hook(
        user_api_key_dict=user_api_key_dict,
        original_exception=e,
        request_data=e.request_data,
    )

    # Callbacks may inject headers (e.g. echoed gateway request IDs);
    # request_headers were captured (and sanitized) in base_process_llm_request.
    extra_headers = await proxy_logging_obj.post_call_response_headers_hook(
        data=e.request_data,
        user_api_key_dict=user_api_key_dict,
        response=None,
        request_headers=self._request_headers,
    )
    if extra_headers:
        fastapi_response.headers.update(extra_headers)

async def _handle_llm_api_exception(
self,
e: Exception,
Expand Down Expand Up @@ -1135,6 +1183,7 @@ async def _handle_llm_api_exception(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=None,
request_headers=self._request_headers,
)
if callback_headers:
headers.update(callback_headers)
Expand Down
17 changes: 11 additions & 6 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6197,11 +6197,13 @@ async def chat_completion( # noqa: PLR0915
except ModifyResponseException as e:
# Guardrail flagged content in passthrough mode - return 200 with violation message
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
await base_llm_response_processor._handle_modify_response_exception(
e=e,
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
proxy_logging_obj=proxy_logging_obj,
fastapi_response=fastapi_response,
)

_chat_response = litellm.ModelResponse()
_chat_response.model = e.model # type: ignore
_chat_response.choices[0].message.content = e.message # type: ignore
Expand All @@ -6227,6 +6229,7 @@ async def chat_completion( # noqa: PLR0915
selected_data_generator,
media_type="text/event-stream",
status_code=200, # Return 200 for passthrough mode
headers=dict(fastapi_response.headers),
)
_usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
_chat_response.usage = _usage # type: ignore
Expand Down Expand Up @@ -6361,10 +6364,11 @@ async def completion( # noqa: PLR0915
except ModifyResponseException as e:
# Guardrail flagged content in passthrough mode - return 200 with violation message
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
await base_llm_response_processor._handle_modify_response_exception(
e=e,
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
proxy_logging_obj=proxy_logging_obj,
fastapi_response=fastapi_response,
)

if _data.get("stream", None) is not None and _data["stream"] is True:
Expand Down Expand Up @@ -6397,6 +6401,7 @@ async def completion( # noqa: PLR0915
selected_data_generator,
media_type="text/event-stream",
status_code=200, # Return 200 for passthrough mode
headers=dict(fastapi_response.headers),
)
else:
_response = litellm.TextCompletionResponse()
Expand Down
8 changes: 4 additions & 4 deletions litellm/proxy/response_api_endpoints/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ async def responses_api(
return response
except ModifyResponseException as e:
# Guardrail passthrough: return violation message in Responses API format (200)
_data = e.request_data
await proxy_logging_obj.post_call_failure_hook(
await processor._handle_modify_response_exception(
e=e,
user_api_key_dict=user_api_key_dict,
original_exception=e,
request_data=_data,
proxy_logging_obj=proxy_logging_obj,
fastapi_response=fastapi_response,
)

violation_text = e.message
Expand Down
90 changes: 90 additions & 0 deletions tests/e2e_demo_response_headers_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""
Demo CustomLogger that injects custom response headers.

Shows how to:
1. Echo an incoming request header (e.g., APIGEE request ID) into the response
2. Inject headers on both success and failure paths
3. Works for /chat/completions, /embeddings, and /responses

Usage:
litellm --config tests/e2e_demo_response_headers_config.yaml

Test commands:
# /chat/completions (non-streaming)
curl -s -D- http://localhost:4000/chat/completions \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-001" \
-d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}'

# /chat/completions (streaming)
curl -s -D- http://localhost:4000/chat/completions \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-002" \
-d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream":true}'

# /embeddings
curl -s -D- http://localhost:4000/embeddings \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-003" \
-d '{"model":"text-embedding-3-small","input":"hello"}'

# /v1/responses (non-streaming)
curl -s -D- http://localhost:4000/v1/responses \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-004" \
-d '{"model":"gpt-4o-mini","input":"hi"}'

# /v1/responses (streaming)
curl -s -D- http://localhost:4000/v1/responses \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-005" \
-d '{"model":"gpt-4o-mini","input":"hi","stream":true}'

# Failure path (bad model β†’ headers still injected)
curl -s -D- http://localhost:4000/chat/completions \
-H "Authorization: Bearer sk-1234" \
-H "x-apigee-request-id: apigee-req-006" \
-d '{"model":"nonexistent-model","messages":[{"role":"user","content":"hi"}]}'

Expected: All responses contain x-apigee-request-id, x-custom-header, and x-litellm-hook-model.
"""

from typing import Any, Dict, Optional

from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class ResponseHeaderInjector(CustomLogger):
    """Demo callback that adds custom HTTP headers to every proxy response.

    Behavior:
    - Always emits a static ``x-custom-header`` plus ``x-litellm-hook-model``
      (the requested model name, or ``"unknown"``).
    - Echoes the incoming ``x-apigee-request-id`` request header back into
      the response when present.
    - Applies to success (streaming + non-streaming) and failure responses
      across /chat/completions, /embeddings, and /responses.
    """

    async def async_post_call_response_headers_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_headers: Optional[Dict[str, str]] = None,
    ) -> Optional[Dict[str, str]]:
        """Return the headers to merge into the outgoing HTTP response."""
        injected: Dict[str, str] = {
            "x-custom-header": "hello-from-hook",
            "x-litellm-hook-model": data.get("model", "unknown"),
        }

        # Mirror the APIGEE request ID from the request into the response.
        incoming = request_headers or {}
        apigee_id = incoming.get("x-apigee-request-id")
        if apigee_id:
            injected["x-apigee-request-id"] = apigee_id

        return injected


response_header_injector = ResponseHeaderInjector()
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,66 @@ async def test_default_hook_returns_none():
response=None,
)
assert result is None


@pytest.mark.asyncio
async def test_response_headers_hook_receives_request_headers():
    """The headers hook runs and returns its headers when request_headers is supplied."""
    injector = HeaderInjectorLogger(headers={"x-echoed": "yes"})
    incoming_headers = {
        "x-apigee-request-id": "req-abc-123",
        "authorization": "Bearer sk-xxx",
    }

    with patch("litellm.callbacks", [injector]):
        from litellm.caching.caching import DualCache
        from litellm.proxy.utils import ProxyLogging

        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        returned_headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response={"id": "resp-1"},
            request_headers=incoming_headers,
        )

        assert injector.called is True
        assert returned_headers == {"x-echoed": "yes"}


@pytest.mark.asyncio
async def test_response_headers_hook_request_headers_passed_to_callback():
    """request_headers are forwarded verbatim and can be echoed back by the callback."""

    class RequestHeaderAwareLogger(CustomLogger):
        # Captures the request_headers the hook receives so the test can inspect them.
        def __init__(self):
            self.received_request_headers = None

        async def async_post_call_response_headers_hook(
            self,
            data: dict,
            user_api_key_dict: UserAPIKeyAuth,
            response: Any,
            request_headers: Optional[Dict[str, str]] = None,
        ) -> Optional[Dict[str, str]]:
            self.received_request_headers = request_headers
            incoming = request_headers or {}
            echo_value = incoming.get("x-apigee-request-id")
            if echo_value is not None:
                return {"x-apigee-request-id": echo_value}
            return None

    aware_logger = RequestHeaderAwareLogger()
    incoming_headers = {"x-apigee-request-id": "apigee-123"}

    with patch("litellm.callbacks", [aware_logger]):
        from litellm.caching.caching import DualCache
        from litellm.proxy.utils import ProxyLogging

        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        returned_headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=None,
            request_headers=incoming_headers,
        )

        assert aware_logger.received_request_headers == incoming_headers
        assert returned_headers == {"x-apigee-request-id": "apigee-123"}
Loading
Loading