diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py
index 12243a19184..07d237c4758 100644
--- a/litellm/integrations/custom_logger.py
+++ b/litellm/integrations/custom_logger.py
@@ -371,6 +371,28 @@ async def async_pre_call_hook(
     ]:  # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
         pass
 
+    async def async_post_call_response_headers_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_headers: Optional[Dict[str, str]] = None,
+    ) -> Optional[Dict[str, str]]:
+        """
+        Called after an LLM API call (success or failure) to allow injecting custom HTTP response headers.
+
+        Args:
+        - data: dict - The request data.
+        - user_api_key_dict: UserAPIKeyAuth - The user API key dictionary.
+        - response: Any - The response object (None for failure cases).
+        - request_headers: Optional[Dict[str, str]] - The original request headers.
+
+        Returns:
+        - Optional[Dict[str, str]]: A dictionary of headers to inject into the HTTP response.
+            Return None to not inject any headers.
+        """
+        return None
+
     async def async_post_call_failure_hook(
         self,
         request_data: dict,
diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py
index 032c7dffbe8..136ce696511 100644
--- a/litellm/proxy/common_request_processing.py
+++ b/litellm/proxy/common_request_processing.py
@@ -804,6 +804,15 @@ async def base_process_llm_request(
                 **additional_headers,
             )
 
+            # Call response headers hook for streaming success
+            callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
+                data=self.data,
+                user_api_key_dict=user_api_key_dict,
+                response=response,
+            )
+            if callback_headers:
+                custom_headers.update(callback_headers)
+
             # Preserve the original client-requested model (pre-alias mapping) for downstream
             # streaming generators. Pre-call processing can rewrite `self.data["model"]` for
             # aliasing/routing, but the OpenAI-compatible response `model` field should reflect
@@ -900,6 +909,16 @@ async def base_process_llm_request(
                 **additional_headers,
             )
         )
+
+        # Call response headers hook for non-streaming success
+        callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
+            data=self.data,
+            user_api_key_dict=user_api_key_dict,
+            response=response,
+        )
+        if callback_headers:
+            fastapi_response.headers.update(callback_headers)
+
         await check_response_size_is_safe(response=response)
 
         return response
@@ -1058,6 +1077,23 @@ async def _handle_llm_api_exception(
         headers = get_response_headers(dict(_response_headers))
         headers.update(custom_headers)
 
+        # Call response headers hook for failure
+        try:
+            callback_headers = await proxy_logging_obj.post_call_response_headers_hook(
+                data=self.data,
+                user_api_key_dict=user_api_key_dict,
+                response=None,
+            )
+            if callback_headers:
+                headers.update(callback_headers)
+        except Exception as hook_err:
+            # Header injection is best-effort: never mask the original
+            # error, but leave a trace so a broken callback can be found.
+            verbose_proxy_logger.exception(
+                "Error in post_call_response_headers_hook (failure path): %s",
+                str(hook_err),
+            )
+
         if isinstance(e, HTTPException):
             raise ProxyException(
                 message=getattr(e, "detail", str(e)),
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8922ed032e2..6bbf0df74de 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -1808,6 +1808,49 @@ async def post_call_success_hook(
             raise e
         return response
 
+    async def post_call_response_headers_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, str]:
+        """
+        Calls async_post_call_response_headers_hook on all CustomLogger callbacks.
+        Merges all returned header dicts (later callbacks override earlier ones).
+
+        Each callback is isolated: if one raises (or returns something that cannot
+        be merged), the error is logged and the remaining callbacks still run.
+
+        Returns:
+            Dict[str, str]: Merged headers from all callbacks.
+        """
+        merged_headers: Dict[str, str] = {}
+        for callback in litellm.callbacks:
+            try:
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        cast(_custom_logger_compatible_callbacks_literal, callback)
+                    )
+                else:
+                    _callback = callback  # type: ignore
+
+                if _callback is not None and isinstance(_callback, CustomLogger):
+                    result = await _callback.async_post_call_response_headers_hook(
+                        data=data,
+                        user_api_key_dict=user_api_key_dict,
+                        response=response,
+                        request_headers=request_headers,
+                    )
+                    if result is not None:
+                        merged_headers.update(result)
+            except Exception as e:
+                verbose_proxy_logger.exception(
+                    "Error in post_call_response_headers_hook: %s", str(e)
+                )
+        return merged_headers
+
     async def async_post_call_streaming_hook(
         self,
         data: dict,
diff --git a/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py b/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py
new file mode 100644
index 00000000000..6a12366fdd3
--- /dev/null
+++ b/tests/test_litellm/proxy/hooks/test_post_call_response_headers_hook.py
@@ -0,0 +1,233 @@
+"""
+Integration tests for async_post_call_response_headers_hook.
+
+Tests verify that CustomLogger callbacks can inject custom HTTP response headers
+into success (streaming and non-streaming) and failure responses.
+"""
+
+import os
+import sys
+import pytest
+from typing import Any, Dict, Optional
+from unittest.mock import patch
+
+sys.path.insert(0, os.path.abspath("../../../.."))
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class HeaderInjectorLogger(CustomLogger):
+    """Logger that injects custom headers into responses."""
+
+    def __init__(self, headers: Optional[Dict[str, str]] = None):
+        super().__init__()
+        self.headers = headers
+        self.called = False
+        self.received_response = None
+        self.received_data = None
+
+    async def async_post_call_response_headers_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_headers: Optional[Dict[str, str]] = None,
+    ) -> Optional[Dict[str, str]]:
+        self.called = True
+        self.received_response = response
+        self.received_data = data
+        return self.headers
+
+
+class RaisingLogger(CustomLogger):
+    """Logger whose response-headers hook always raises."""
+
+    async def async_post_call_response_headers_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Any,
+        request_headers: Optional[Dict[str, str]] = None,
+    ) -> Optional[Dict[str, str]]:
+        raise RuntimeError("boom")
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_returns_headers():
+    """Test that the hook returns headers from a single callback."""
+    injector = HeaderInjectorLogger(headers={"x-custom-id": "abc123"})
+
+    with patch("litellm.callbacks", [injector]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response={"id": "resp-1"},
+        )
+
+    assert injector.called is True
+    assert result == {"x-custom-id": "abc123"}
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_returns_none():
+    """Test that returning None results in empty headers dict."""
+    injector = HeaderInjectorLogger(headers=None)
+
+    with patch("litellm.callbacks", [injector]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response={"id": "resp-1"},
+        )
+
+    assert injector.called is True
+    assert result == {}
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_multiple_callbacks_merge():
+    """Test that headers from multiple callbacks are merged."""
+    injector1 = HeaderInjectorLogger(headers={"x-header-a": "value-a"})
+    injector2 = HeaderInjectorLogger(headers={"x-header-b": "value-b"})
+
+    with patch("litellm.callbacks", [injector1, injector2]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=None,
+        )
+
+    assert injector1.called is True
+    assert injector2.called is True
+    assert result == {"x-header-a": "value-a", "x-header-b": "value-b"}
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_later_callback_overrides():
+    """Test that later callbacks override earlier ones for the same header key."""
+    injector1 = HeaderInjectorLogger(headers={"x-request-id": "first"})
+    injector2 = HeaderInjectorLogger(headers={"x-request-id": "second"})
+
+    with patch("litellm.callbacks", [injector1, injector2]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=None,
+        )
+
+    assert result == {"x-request-id": "second"}
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_receives_response_on_success():
+    """Test that the hook receives the response object on success."""
+    injector = HeaderInjectorLogger(headers={"x-ok": "1"})
+    mock_response = {"id": "resp-success", "choices": []}
+
+    with patch("litellm.callbacks", [injector]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=mock_response,
+        )
+
+    assert injector.received_response is mock_response
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_receives_none_response_on_failure():
+    """Test that the hook receives None response for failure cases."""
+    injector = HeaderInjectorLogger(headers={"x-error-id": "err-1"})
+
+    with patch("litellm.callbacks", [injector]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=None,
+        )
+
+    assert injector.received_response is None
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_no_callbacks():
+    """Test that no callbacks results in empty headers."""
+    with patch("litellm.callbacks", []):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=None,
+        )
+
+    assert result == {}
+
+
+@pytest.mark.asyncio
+async def test_response_headers_hook_isolates_failing_callback():
+    """Test that a raising callback does not stop later callbacks from running."""
+    failing = RaisingLogger()
+    injector = HeaderInjectorLogger(headers={"x-after-error": "1"})
+
+    with patch("litellm.callbacks", [failing, injector]):
+        from litellm.proxy.utils import ProxyLogging
+        from litellm.caching.caching import DualCache
+
+        proxy_logging = ProxyLogging(user_api_key_cache=DualCache())
+
+        result = await proxy_logging.post_call_response_headers_hook(
+            data={"model": "test-model"},
+            user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+            response=None,
+        )
+
+    assert injector.called is True
+    assert result == {"x-after-error": "1"}
+
+
+@pytest.mark.asyncio
+async def test_default_hook_returns_none():
+    """Test that the base CustomLogger hook returns None by default."""
+    logger = CustomLogger()
+    result = await logger.async_post_call_response_headers_hook(
+        data={},
+        user_api_key_dict=UserAPIKeyAuth(api_key="test-key"),
+        response=None,
+    )
+    assert result is None