diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index ce39ecf52dc..d801461bb99 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -62,6 +62,13 @@ from litellm.types.utils import ModelResponse, ModelResponseStream, Usage +def _get_hidden_params(response: Any) -> dict: + """Extract _hidden_params from a response object or dict.""" + if isinstance(response, dict): + return response.get("_hidden_params", {}) or {} + return getattr(response, "_hidden_params", {}) or {} + + async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional[int]: """Parses an event line and returns an error code if present, else None.""" event_line = ( @@ -273,7 +280,7 @@ def _override_openai_response_model( return # Check if a fallback occurred - if so, preserve the actual model used - hidden_params = getattr(response_obj, "_hidden_params", {}) or {} + hidden_params = _get_hidden_params(response_obj) if isinstance(hidden_params, dict): fallback_headers = hidden_params.get("additional_headers", {}) or {} attempted_fallbacks = fallback_headers.get( @@ -873,7 +880,7 @@ async def base_process_llm_request( response = responses[1] - hidden_params = getattr(response, "_hidden_params", {}) or {} + hidden_params = _get_hidden_params(response) model_id = self._get_model_id_from_response(hidden_params, self.data) cache_key, api_base, response_cost = ( @@ -1000,9 +1007,7 @@ async def base_process_llm_request( log_context=f"litellm_call_id={logging_obj.litellm_call_id}", ) - hidden_params = ( - getattr(response, "_hidden_params", {}) or {} - ) # get any updated response headers + hidden_params = _get_hidden_params(response) # get any updated response headers additional_headers = hidden_params.get("additional_headers", {}) or {} fastapi_response.headers.update( diff --git a/tests/test_litellm/proxy/test_common_request_processing.py b/tests/test_litellm/proxy/test_common_request_processing.py index ba1084eafe0..d120c0bfe1b 100644 --- a/tests/test_litellm/proxy/test_common_request_processing.py +++ b/tests/test_litellm/proxy/test_common_request_processing.py @@ -15,6 +15,7 @@ ProxyConfig, _extract_error_from_sse_chunk, _get_cost_breakdown_from_logging_obj, + _get_hidden_params, _override_openai_response_model, _parse_event_data_for_error, create_response, @@ -1622,3 +1623,42 @@ def test_only_model_tagged_when_no_key_info(self): ) mock_set_tag.assert_called_once_with("litellm.requested_model", "claude-3-5-sonnet") + + +class TestGetHiddenParams: + """Tests for _get_hidden_params helper that extracts _hidden_params from response objects or dicts.""" + + def test_dict_response_with_hidden_params(self): + response = {"_hidden_params": {"api_base": "http://example.com", "additional_headers": {"x-foo": "bar"}}} + result = _get_hidden_params(response) + assert result == {"api_base": "http://example.com", "additional_headers": {"x-foo": "bar"}} + + def test_dict_response_without_hidden_params(self): + response = {"some_key": "value"} + result = _get_hidden_params(response) + assert result == {} + + def test_dict_response_with_none_hidden_params(self): + response = {"_hidden_params": None} + result = _get_hidden_params(response) + assert result == {} + + def test_object_response_with_hidden_params(self): + response = MagicMock() + response._hidden_params = {"cache_key": "abc"} + result = _get_hidden_params(response) + assert result == {"cache_key": "abc"} + + def test_object_response_without_hidden_params(self): + class NoParams: + pass + result = _get_hidden_params(NoParams()) + assert result == {} + + def test_none_response(self): + result = _get_hidden_params(None) + assert result == {} + + def test_empty_dict_response(self): + result = _get_hidden_params({}) + assert result == {}