Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions litellm/integrations/custom_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,28 @@ async def async_pre_call_hook(
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
pass

async def async_post_call_response_headers_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response: Any,
    request_headers: Optional[Dict[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """
    Hook for injecting custom HTTP headers into the proxy's response.

    Invoked after every LLM API call, on both the success and the failure
    path.

    Args:
        data: The request payload that was sent to the LLM.
        user_api_key_dict: Auth metadata for the calling API key.
        response: The response object; ``None`` when the call failed.
        request_headers: The original incoming request headers, when
            available.

    Returns:
        A mapping of header name -> value to merge into the HTTP
        response, or ``None`` to inject nothing. This base
        implementation injects no headers.
    """
    return None

async def async_post_call_failure_hook(
self,
request_data: dict,
Expand Down
31 changes: 31 additions & 0 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ async def base_process_llm_request(
**additional_headers,
)

# Call response headers hook for streaming success
callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=response,
)
if callback_headers:
custom_headers.update(callback_headers)

# Preserve the original client-requested model (pre-alias mapping) for downstream
# streaming generators. Pre-call processing can rewrite `self.data["model"]` for
# aliasing/routing, but the OpenAI-compatible response `model` field should reflect
Expand Down Expand Up @@ -900,6 +909,16 @@ async def base_process_llm_request(
**additional_headers,
)
)

# Call response headers hook for non-streaming success
callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=response,
)
if callback_headers:
fastapi_response.headers.update(callback_headers)

await check_response_size_is_safe(response=response)

return response
Expand Down Expand Up @@ -1058,6 +1077,18 @@ async def _handle_llm_api_exception(
headers = get_response_headers(dict(_response_headers))
headers.update(custom_headers)

# Call response headers hook for failure
try:
callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
response=None,
)
if callback_headers:
headers.update(callback_headers)
except Exception:
pass

if isinstance(e, HTTPException):
raise ProxyException(
message=getattr(e, "detail", str(e)),
Expand Down
40 changes: 40 additions & 0 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1808,6 +1808,46 @@ async def post_call_success_hook(
raise e
return response

async def post_call_response_headers_hook(
    self,
    data: dict,
    user_api_key_dict: UserAPIKeyAuth,
    response: Any,
    request_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """
    Call ``async_post_call_response_headers_hook`` on every CustomLogger
    callback and merge the returned header dicts.

    Later callbacks override earlier ones on key collisions. Exceptions are
    handled per-callback: a failure in one callback is logged and skipped so
    it cannot suppress headers contributed by the remaining callbacks (the
    previous behavior aborted the whole loop on the first error).

    Args:
        data: The request payload.
        user_api_key_dict: Auth metadata for the calling API key.
        response: The response object (``None`` on failure paths).
        request_headers: The original incoming request headers, when
            available.

    Returns:
        Dict[str, str]: Merged headers from all callbacks; empty when no
        callback contributes headers.
    """
    merged_headers: Dict[str, str] = {}
    for callback in litellm.callbacks:
        try:
            _callback: Optional[CustomLogger] = None
            if isinstance(callback, str):
                # Resolve string callback identifiers to their registered
                # CustomLogger-compatible class.
                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
                    cast(_custom_logger_compatible_callbacks_literal, callback)
                )
            else:
                _callback = callback  # type: ignore

            if _callback is not None and isinstance(_callback, CustomLogger):
                result = await _callback.async_post_call_response_headers_hook(
                    data=data,
                    user_api_key_dict=user_api_key_dict,
                    response=response,
                    request_headers=request_headers,
                )
                if result is not None:
                    merged_headers.update(result)
        except Exception as e:
            # Best-effort: header injection must never break the response
            # path, and one bad callback must not block the others.
            verbose_proxy_logger.exception(
                "Error in post_call_response_headers_hook: %s", str(e)
            )
    return merged_headers

async def async_post_call_streaming_hook(
self,
data: dict,
Expand Down
197 changes: 197 additions & 0 deletions tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
Integration tests for async_post_call_response_headers_hook.

Tests verify that CustomLogger callbacks can inject custom HTTP response headers
into success (streaming and non-streaming) and failure responses.
"""

import os
import sys
import pytest
from typing import Any, Dict, Optional
from unittest.mock import patch

sys.path.insert(0, os.path.abspath("../../../.."))

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class HeaderInjectorLogger(CustomLogger):
    """Test double that injects a fixed set of headers into responses.

    Records the arguments the hook was invoked with so tests can assert
    on them afterwards.
    """

    def __init__(self, headers: Optional[Dict[str, str]] = None):
        # Initialize CustomLogger base state; the original omitted this,
        # leaving any base-class attributes unset.
        super().__init__()
        # Headers the hook should return; None means "inject nothing".
        self.headers = headers
        # Invocation tracking for test assertions.
        self.called = False
        self.received_response = None
        self.received_data = None

    async def async_post_call_response_headers_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_headers: Optional[Dict[str, str]] = None,
    ) -> Optional[Dict[str, str]]:
        """Record the call and return the configured headers."""
        self.called = True
        self.received_response = response
        self.received_data = data
        return self.headers


@pytest.mark.asyncio
async def test_response_headers_hook_returns_headers():
    """A single callback's headers are returned unchanged by the hook."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    hook_logger = HeaderInjectorLogger(headers={"x-custom-id": "abc123"})

    with patch("litellm.callbacks", [hook_logger]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response={"id": "resp-1"},
        )

        assert hook_logger.called is True
        assert headers == {"x-custom-id": "abc123"}


@pytest.mark.asyncio
async def test_response_headers_hook_returns_none():
    """A callback that returns None contributes nothing: result is {}."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    hook_logger = HeaderInjectorLogger(headers=None)

    with patch("litellm.callbacks", [hook_logger]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response={"id": "resp-1"},
        )

        assert hook_logger.called is True
        assert headers == {}


@pytest.mark.asyncio
async def test_response_headers_hook_multiple_callbacks_merge():
    """Distinct header keys from multiple callbacks are all present in the merge."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    first = HeaderInjectorLogger(headers={"x-header-a": "value-a"})
    second = HeaderInjectorLogger(headers={"x-header-b": "value-b"})

    with patch("litellm.callbacks", [first, second]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=None,
        )

        assert first.called is True
        assert second.called is True
        assert headers == {"x-header-a": "value-a", "x-header-b": "value-b"}


@pytest.mark.asyncio
async def test_response_headers_hook_later_callback_overrides():
    """On a key collision, the callback registered later wins."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    first = HeaderInjectorLogger(headers={"x-request-id": "first"})
    second = HeaderInjectorLogger(headers={"x-request-id": "second"})

    with patch("litellm.callbacks", [first, second]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=None,
        )

        assert headers == {"x-request-id": "second"}


@pytest.mark.asyncio
async def test_response_headers_hook_receives_response_on_success():
    """On success, the callback is handed the exact response object."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    hook_logger = HeaderInjectorLogger(headers={"x-ok": "1"})
    fake_response = {"id": "resp-success", "choices": []}

    with patch("litellm.callbacks", [hook_logger]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=fake_response,
        )

        assert hook_logger.received_response is fake_response


@pytest.mark.asyncio
async def test_response_headers_hook_receives_none_response_on_failure():
    """On the failure path, the callback receives response=None."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    hook_logger = HeaderInjectorLogger(headers={"x-error-id": "err-1"})

    with patch("litellm.callbacks", [hook_logger]):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=None,
        )

        assert hook_logger.received_response is None


@pytest.mark.asyncio
async def test_response_headers_hook_no_callbacks():
    """With an empty callback list the hook yields an empty dict."""
    from litellm.caching.caching import DualCache
    from litellm.proxy.utils import ProxyLogging

    with patch("litellm.callbacks", []):
        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())

        headers = await proxy_logging.post_call_response_headers_hook(
            data={"model": "test-model"},
            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
            response=None,
        )

        assert headers == {}


@pytest.mark.asyncio
async def test_default_hook_returns_none():
    """The base CustomLogger implementation is a no-op returning None."""
    base_logger = CustomLogger()
    returned = await base_logger.async_post_call_response_headers_hook(
        data={},
        user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
        response=None,
    )
    assert returned is None
Loading