Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@
from litellm.types.utils import ModelResponse, ModelResponseStream, Usage


def _get_hidden_params(response: Any) -> dict:
"""Extract _hidden_params from a response object or dict."""
if isinstance(response, dict):
return response.get("_hidden_params", {}) or {}
return getattr(response, "_hidden_params", {}) or {}


async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional[int]:
"""Parses an event line and returns an error code if present, else None."""
event_line = (
Expand Down Expand Up @@ -273,7 +280,7 @@ def _override_openai_response_model(
return

# Check if a fallback occurred - if so, preserve the actual model used
hidden_params = getattr(response_obj, "_hidden_params", {}) or {}
hidden_params = _get_hidden_params(response_obj)
if isinstance(hidden_params, dict):
fallback_headers = hidden_params.get("additional_headers", {}) or {}
attempted_fallbacks = fallback_headers.get(
Expand Down Expand Up @@ -873,7 +880,7 @@ async def base_process_llm_request(

response = responses[1]

hidden_params = getattr(response, "_hidden_params", {}) or {}
hidden_params = _get_hidden_params(response)
model_id = self._get_model_id_from_response(hidden_params, self.data)

cache_key, api_base, response_cost = (
Expand Down Expand Up @@ -1000,9 +1007,7 @@ async def base_process_llm_request(
log_context=f"litellm_call_id={logging_obj.litellm_call_id}",
)

hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
hidden_params = _get_hidden_params(response) # get any updated response headers
additional_headers = hidden_params.get("additional_headers", {}) or {}

fastapi_response.headers.update(
Expand Down
40 changes: 40 additions & 0 deletions tests/test_litellm/proxy/test_common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ProxyConfig,
_extract_error_from_sse_chunk,
_get_cost_breakdown_from_logging_obj,
_get_hidden_params,
_override_openai_response_model,
_parse_event_data_for_error,
create_response,
Expand Down Expand Up @@ -1622,3 +1623,42 @@ def test_only_model_tagged_when_no_key_info(self):
)

mock_set_tag.assert_called_once_with("litellm.requested_model", "claude-3-5-sonnet")


class TestGetHiddenParams:
    """Tests for the _get_hidden_params helper, which pulls _hidden_params
    out of either a dict response or an attribute-bearing response object."""

    def test_dict_response_with_hidden_params(self):
        payload = {"api_base": "http://example.com", "additional_headers": {"x-foo": "bar"}}
        assert _get_hidden_params({"_hidden_params": payload}) == payload

    def test_dict_response_without_hidden_params(self):
        # A dict lacking the key should normalize to an empty dict.
        assert _get_hidden_params({"some_key": "value"}) == {}

    def test_dict_response_with_none_hidden_params(self):
        # An explicit None value must also normalize to {}.
        assert _get_hidden_params({"_hidden_params": None}) == {}

    def test_object_response_with_hidden_params(self):
        mock_response = MagicMock()
        mock_response._hidden_params = {"cache_key": "abc"}
        assert _get_hidden_params(mock_response) == {"cache_key": "abc"}

    def test_object_response_without_hidden_params(self):
        class Bare:
            pass

        assert _get_hidden_params(Bare()) == {}

    def test_none_response(self):
        assert _get_hidden_params(None) == {}

    def test_empty_dict_response(self):
        assert _get_hidden_params({}) == {}
Loading