-
-
Notifications
You must be signed in to change notification settings - Fork 6.6k
fix(proxy): pass request_headers to response headers hook + fix guardrail gap #21385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6b73c3c
650acd0
6318692
36717bd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |||||
| Any, | ||||||
| AsyncGenerator, | ||||||
| Callable, | ||||||
| Dict, | ||||||
| Literal, | ||||||
| Optional, | ||||||
| Tuple, | ||||||
|
|
@@ -49,6 +50,7 @@ | |||||
| StreamErrorSerializer = Callable[[ProxyException], str] | ||||||
|
|
||||||
| if TYPE_CHECKING: | ||||||
| from litellm.integrations.custom_guardrail import ModifyResponseException | ||||||
| from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig | ||||||
|
|
||||||
| ProxyConfig = _ProxyConfig | ||||||
|
|
@@ -351,8 +353,23 @@ def _get_cost_breakdown_from_logging_obj( | |||||
|
|
||||||
|
|
||||||
| class ProxyBaseLLMRequestProcessing: | ||||||
| # Headers excluded from request_headers passed to callbacks to avoid leaking credentials | ||||||
| _SENSITIVE_HEADERS = frozenset({"authorization", "cookie", "proxy-authorization"}) | ||||||
|
|
||||||
| def __init__(self, data: dict): | ||||||
| self.data = data | ||||||
| self._request_headers: Optional[Dict[str, str]] = None | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing
Suggested change
Alternatively, add |
||||||
|
|
||||||
| @staticmethod | ||||||
| def _filter_sensitive_headers( | ||||||
| headers: "starlette.datastructures.Headers", | ||||||
| ) -> dict: | ||||||
| """Return a copy of request headers with sensitive values removed.""" | ||||||
| return { | ||||||
| k: v | ||||||
| for k, v in headers.items() | ||||||
| if k.lower() not in ProxyBaseLLMRequestProcessing._SENSITIVE_HEADERS | ||||||
| } | ||||||
|
|
||||||
| @staticmethod | ||||||
| def get_custom_headers( | ||||||
|
|
@@ -749,6 +766,8 @@ async def base_process_llm_request( | |||||
| """ | ||||||
| Common request processing logic for both chat completions and responses API endpoints | ||||||
| """ | ||||||
| self._request_headers = self._filter_sensitive_headers(request.headers) | ||||||
|
|
||||||
| requested_model_from_client: Optional[str] = ( | ||||||
| self.data.get("model") if isinstance(self.data.get("model"), str) else None | ||||||
| ) | ||||||
|
|
@@ -859,6 +878,7 @@ async def base_process_llm_request( | |||||
| data=self.data, | ||||||
| user_api_key_dict=user_api_key_dict, | ||||||
| response=response, | ||||||
| request_headers=self._request_headers, | ||||||
| ) | ||||||
| if callback_headers: | ||||||
| custom_headers.update(callback_headers) | ||||||
|
|
@@ -967,6 +987,7 @@ async def base_process_llm_request( | |||||
| data=self.data, | ||||||
| user_api_key_dict=user_api_key_dict, | ||||||
| response=response, | ||||||
| request_headers=self._request_headers, | ||||||
| ) | ||||||
| if callback_headers: | ||||||
| fastapi_response.headers.update(callback_headers) | ||||||
|
|
@@ -1065,6 +1086,33 @@ def _is_streaming_request( | |||||
| return True | ||||||
| return False | ||||||
|
|
||||||
| async def _handle_modify_response_exception( | ||||||
| self, | ||||||
| e: "ModifyResponseException", | ||||||
| user_api_key_dict: UserAPIKeyAuth, | ||||||
| proxy_logging_obj: ProxyLogging, | ||||||
| fastapi_response: Response, | ||||||
| ): | ||||||
| """Centralized handling for ModifyResponseException (guardrail passthrough). | ||||||
|
|
||||||
| Calls the failure hook and injects custom response headers — mirrors | ||||||
| the pattern in ``_handle_llm_api_exception`` for error responses. | ||||||
| """ | ||||||
| await proxy_logging_obj.post_call_failure_hook( | ||||||
| user_api_key_dict=user_api_key_dict, | ||||||
| original_exception=e, | ||||||
| request_data=e.request_data, | ||||||
| ) | ||||||
|
|
||||||
| callback_headers = await proxy_logging_obj.post_call_response_headers_hook( | ||||||
| data=e.request_data, | ||||||
| user_api_key_dict=user_api_key_dict, | ||||||
| response=None, | ||||||
| request_headers=self._request_headers, | ||||||
| ) | ||||||
| if callback_headers: | ||||||
| fastapi_response.headers.update(callback_headers) | ||||||
|
|
||||||
| async def _handle_llm_api_exception( | ||||||
| self, | ||||||
| e: Exception, | ||||||
|
|
@@ -1135,6 +1183,7 @@ async def _handle_llm_api_exception( | |||||
| data=self.data, | ||||||
| user_api_key_dict=user_api_key_dict, | ||||||
| response=None, | ||||||
| request_headers=self._request_headers, | ||||||
| ) | ||||||
| if callback_headers: | ||||||
| headers.update(callback_headers) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| """ | ||
| Demo CustomLogger that injects custom response headers. | ||
|
|
||
| Shows how to: | ||
| 1. Echo an incoming request header (e.g., APIGEE request ID) into the response | ||
| 2. Inject headers on both success and failure paths | ||
| 3. Works for /chat/completions, /embeddings, and /responses | ||
|
|
||
| Usage: | ||
| litellm --config tests/e2e_demo_response_headers_config.yaml | ||
|
|
||
| Test commands: | ||
| # /chat/completions (non-streaming) | ||
| curl -s -D- http://localhost:4000/chat/completions \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-001" \ | ||
| -d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}]}' | ||
|
|
||
| # /chat/completions (streaming) | ||
| curl -s -D- http://localhost:4000/chat/completions \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-002" \ | ||
| -d '{"model":"gpt-4o-mini","messages":[{"role":"user","content":"hi"}],"stream":true}' | ||
|
|
||
| # /embeddings | ||
| curl -s -D- http://localhost:4000/embeddings \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-003" \ | ||
| -d '{"model":"text-embedding-3-small","input":"hello"}' | ||
|
|
||
| # /v1/responses (non-streaming) | ||
| curl -s -D- http://localhost:4000/v1/responses \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-004" \ | ||
| -d '{"model":"gpt-4o-mini","input":"hi"}' | ||
|
|
||
| # /v1/responses (streaming) | ||
| curl -s -D- http://localhost:4000/v1/responses \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-005" \ | ||
| -d '{"model":"gpt-4o-mini","input":"hi","stream":true}' | ||
|
|
||
| # Failure path (bad model → headers still injected) | ||
| curl -s -D- http://localhost:4000/chat/completions \ | ||
| -H "Authorization: Bearer sk-1234" \ | ||
| -H "x-apigee-request-id: apigee-req-006" \ | ||
| -d '{"model":"nonexistent-model","messages":[{"role":"user","content":"hi"}]}' | ||
|
|
||
| Expected: All responses contain x-apigee-request-id, x-custom-header, and x-litellm-hook-model. | ||
| """ | ||
|
|
||
| from typing import Any, Dict, Optional | ||
|
|
||
| from litellm.integrations.custom_logger import CustomLogger | ||
| from litellm.proxy._types import UserAPIKeyAuth | ||
|
|
||
|
|
||
| class ResponseHeaderInjector(CustomLogger): | ||
| """ | ||
| Demonstrates injecting custom HTTP response headers via the proxy hook. | ||
|
|
||
| Key features: | ||
| - Echoes the incoming x-apigee-request-id header back in the response | ||
| - Adds a static custom header and the model name | ||
| - Works for success (streaming + non-streaming) and failure responses | ||
| - Works for all endpoints: /chat/completions, /embeddings, /responses | ||
| """ | ||
|
|
||
| async def async_post_call_response_headers_hook( | ||
| self, | ||
| data: dict, | ||
| user_api_key_dict: UserAPIKeyAuth, | ||
| response: Any, | ||
| request_headers: Optional[Dict[str, str]] = None, | ||
| ) -> Optional[Dict[str, str]]: | ||
| headers: Dict[str, str] = { | ||
| "x-custom-header": "hello-from-hook", | ||
| "x-litellm-hook-model": data.get("model", "unknown"), | ||
| } | ||
|
|
||
| # Echo the APIGEE request ID from the incoming request into the response | ||
| if request_headers: | ||
| apigee_id = request_headers.get("x-apigee-request-id") | ||
| if apigee_id: | ||
| headers["x-apigee-request-id"] = apigee_id | ||
|
|
||
| return headers | ||
|
|
||
|
|
||
| response_header_injector = ResponseHeaderInjector() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Docs say "original" but headers are filtered
The text says
`request_headers` contains "the original HTTP request headers," but `_filter_sensitive_headers` strips `authorization`, `cookie`, and `proxy-authorization` before passing them to callbacks. Consider updating the wording to note this filtering, e.g.: "The
`request_headers` parameter contains the HTTP request headers (with sensitive headers like `authorization` and `cookie` removed for security)..." This helps callback authors understand they won't have access to auth headers through this parameter.