diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 0d3e61b75c7..7bccc5745bc 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -237,6 +237,70 @@ async def combined_generator() -> AsyncGenerator[str, None]: ) +def _override_openai_response_model( + *, + response_obj: Any, + requested_model: str, + log_context: str, +) -> None: + """ + Force the OpenAI-compatible `model` field in the response to match what the client requested. + + LiteLLM internally prefixes some provider/deployment model identifiers (e.g. `hosted_vllm/...`). + That internal identifier should not be returned to clients in the OpenAI `model` field. + + Note: This is intentionally verbose. A model mismatch is a useful signal that an internal + model identifier is being stamped/preserved somewhere in the request/response pipeline. + We log mismatches as warnings (and then restamp to the client-requested value) so these + paths stay observable for maintainers/operators without breaking client compatibility. + + Errors are reserved for cases where the proxy cannot read/override the response model field. + """ + if not requested_model: + return + + if isinstance(response_obj, dict): + downstream_model = response_obj.get("model") + if downstream_model != requested_model: + verbose_proxy_logger.warning( + "%s: response model mismatch - requested=%r downstream=%r. Overriding response['model'] to requested model.", + log_context, + requested_model, + downstream_model, + ) + response_obj["model"] = requested_model + return + + if not hasattr(response_obj, "model"): + verbose_proxy_logger.error( + "%s: cannot override response model; missing `model` attribute. 
response_type=%s", + log_context, + type(response_obj), + ) + return + + downstream_model = getattr(response_obj, "model", None) + if downstream_model != requested_model: + verbose_proxy_logger.warning( + "%s: response model mismatch - requested=%r downstream=%r. Overriding response.model to requested model.", + log_context, + requested_model, + downstream_model, + ) + + try: + setattr(response_obj, "model", requested_model) + except Exception as e: + verbose_proxy_logger.error( + "%s: failed to override response.model=%r on response_type=%s. error=%s", + log_context, + requested_model, + type(response_obj), + str(e), + exc_info=True, + ) + + def _get_cost_breakdown_from_logging_obj( litellm_logging_obj: Optional[LiteLLMLoggingObj], ) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]: @@ -625,6 +689,9 @@ async def base_process_llm_request( """ Common request processing logic for both chat completions and responses API endpoints """ + requested_model_from_client: Optional[str] = ( + self.data.get("model") if isinstance(self.data.get("model"), str) else None + ) if verbose_proxy_logger.isEnabledFor(logging.DEBUG): verbose_proxy_logger.debug( "Request received by LiteLLM:\n{}".format( @@ -690,13 +757,15 @@ async def base_process_llm_request( model_info = litellm_metadata.get("model_info", {}) or {} model_id = model_info.get("id", "") or "" - cache_key = hidden_params.get("cache_key", None) or "" - api_base = hidden_params.get("api_base", None) or "" - response_cost = hidden_params.get("response_cost", None) or "" - fastest_response_batch_completion = hidden_params.get( - "fastest_response_batch_completion", None + cache_key, api_base, response_cost = ( + hidden_params.get("cache_key", None) or "", + hidden_params.get("api_base", None) or "", + hidden_params.get("response_cost", None) or "", + ) + fastest_response_batch_completion, additional_headers = ( + hidden_params.get("fastest_response_batch_completion", None), + 
hidden_params.get("additional_headers", {}) or {}, ) - additional_headers: dict = hidden_params.get("additional_headers", {}) or {} # Post Call Processing if llm_router is not None: @@ -726,6 +795,13 @@ async def base_process_llm_request( litellm_logging_obj=logging_obj, **additional_headers, ) + + # Preserve the original client-requested model (pre-alias mapping) for downstream + # streaming generators. Pre-call processing can rewrite `self.data["model"]` for + # aliasing/routing, but the OpenAI-compatible response `model` field should reflect + # what the client sent. + if requested_model_from_client: + self.data["_litellm_client_requested_model"] = requested_model_from_client if route_type == "allm_passthrough_route": # Check if response is an async generator if self._is_streaming_response(response): @@ -785,6 +861,15 @@ async def base_process_llm_request( data=self.data, user_api_key_dict=user_api_key_dict, response=response ) + # Always return the client-requested model name (not provider-prefixed internal identifiers) + # for OpenAI-compatible responses. + if requested_model_from_client: + _override_openai_response_model( + response_obj=response, + requested_model=requested_model_from_client, + log_context=f"litellm_call_id={logging_obj.litellm_call_id}", + ) + hidden_params = ( getattr(response, "_hidden_params", {}) or {} ) # get any updated response headers diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4fa58e9d244..180f1c4c83d 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4571,6 +4571,68 @@ async def async_assistants_data_generator( yield f"data: {error_returned}\n\n" +def _get_client_requested_model_for_streaming(request_data: dict) -> str: + """ + Prefer the original client-requested model (pre-alias mapping) when available. + + Pre-call processing can rewrite `request_data["model"]` for aliasing/routing purposes. 
+ The OpenAI-compatible public `model` field should reflect what the client sent. + """ + requested_model = request_data.get("_litellm_client_requested_model") + if isinstance(requested_model, str): + return requested_model + + requested_model = request_data.get("model") + return requested_model if isinstance(requested_model, str) else "" + + +def _restamp_streaming_chunk_model( + *, + chunk: Any, + requested_model_from_client: str, + request_data: dict, + model_mismatch_logged: bool, +) -> Tuple[Any, bool]: + # Always return the client-requested model name (not provider-prefixed internal identifiers) + # on streaming chunks. + # + # Note: This warning is intentionally verbose. A mismatch is a useful signal that an + # internal provider/deployment identifier is leaking into the public API, and helps + # maintainers/operators catch regressions while preserving OpenAI-compatible output. + if not requested_model_from_client or not isinstance(chunk, (BaseModel, dict)): + return chunk, model_mismatch_logged + + downstream_model = ( + chunk.get("model") if isinstance(chunk, dict) else getattr(chunk, "model", None) + ) + if not model_mismatch_logged and downstream_model != requested_model_from_client: + verbose_proxy_logger.warning( + "litellm_call_id=%s: streaming chunk model mismatch - requested=%r downstream=%r. Overriding model to requested.", + request_data.get("litellm_call_id"), + requested_model_from_client, + downstream_model, + ) + model_mismatch_logged = True + + if isinstance(chunk, dict): + chunk["model"] = requested_model_from_client + return chunk, model_mismatch_logged + + try: + setattr(chunk, "model", requested_model_from_client) + except Exception as e: + verbose_proxy_logger.error( + "litellm_call_id=%s: failed to override chunk.model=%r on chunk_type=%s. 
# Tests for response `model` field sanitization: the proxy must return the
# client-requested model name — never a provider-prefixed internal identifier
# such as `hosted_vllm/...` — for both non-streaming and streaming responses.
import asyncio
import json
import os
import sys
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock

import pytest
import yaml
from fastapi.testclient import TestClient

sys.path.insert(0, os.path.abspath("../../.."))

import litellm

pytestmark = pytest.mark.flaky(condition=False)


def _initialize_proxy_with_config(config: dict, tmp_path) -> TestClient:
    """
    Write *config* to a temporary YAML file, initialize the proxy from it,
    and return a TestClient bound to the proxy app.

    IMPORTANT: proxy_server.initialize() mutates module-level globals. We must call
    cleanup_router_config_variables() before initializing to prevent cross-test bleed.
    """
    from litellm.proxy.proxy_server import app, cleanup_router_config_variables, initialize

    cleanup_router_config_variables()

    config_path = tmp_path / "proxy_config.yaml"
    config_path.write_text(yaml.safe_dump(config))

    asyncio.run(initialize(config=str(config_path), debug=True))
    return TestClient(app)


def _make_minimal_chat_completion_response(model: str) -> litellm.ModelResponse:
    """Build the smallest valid chat-completion response stamped with *model*."""
    response = litellm.ModelResponse()
    response.model = model
    response.choices[0].message.content = "hello"  # type: ignore[union-attr]
    response.choices[0].finish_reason = "stop"  # type: ignore[union-attr]
    return response


def _make_model_response_stream_chunk(model: str) -> litellm.ModelResponseStream:
    """Build a minimal OpenAI-compatible chat.completion.chunk stamped with *model*."""
    chunk_payload = {
        "id": "chatcmpl-test",
        "object": "chat.completion.chunk",
        "created": 0,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "hi"},
                "finish_reason": None,
            }
        ],
    }
    return litellm.ModelResponseStream(**chunk_payload)


def test_proxy_chat_completion_does_not_return_provider_prefixed_model(tmp_path, monkeypatch):
    """
    Regression test (non-streaming):

    - Client asks for `model="vllm-model"` (no provider prefix).
    - The internal provider path uses `hosted_vllm/...`.
    - The proxy must not leak `hosted_vllm/` in the client-facing `model` field.
    """
    client_model = "vllm-model"
    internal_model = f"hosted_vllm/{client_model}"

    client = _initialize_proxy_with_config(
        config={
            "general_settings": {"master_key": "sk-1234"},
            "model_list": [
                {
                    "model_name": client_model,
                    "litellm_params": {"model": internal_model},
                }
            ],
        },
        tmp_path=tmp_path,
    )

    from litellm.proxy import proxy_server

    # Stub the router call (no real network request); the stubbed response is
    # deliberately stamped with the internal, provider-prefixed model name.
    monkeypatch.setattr(
        proxy_server.llm_router,  # type: ignore[arg-type]
        "acompletion",
        AsyncMock(return_value=_make_minimal_chat_completion_response(model=internal_model)),
    )

    # No-op the proxy logging hooks to keep the test focused and deterministic.
    monkeypatch.setattr(proxy_server.proxy_logging_obj, "during_call_hook", AsyncMock(return_value=None))
    monkeypatch.setattr(proxy_server.proxy_logging_obj, "update_request_status", AsyncMock(return_value=None))
    monkeypatch.setattr(proxy_server.proxy_logging_obj, "post_call_success_hook", AsyncMock(side_effect=lambda **kwargs: kwargs["response"]))

    response = client.post(
        "/v1/chat/completions",
        headers={"Authorization": "Bearer sk-1234"},
        json={"model": client_model, "messages": [{"role": "user", "content": "hi"}]},
    )

    assert response.status_code == 200, response.text
    payload = response.json()
    assert payload["model"] == client_model
    assert not payload["model"].startswith("hosted_vllm/")


@pytest.mark.asyncio
async def test_proxy_streaming_chunks_do_not_return_provider_prefixed_model(monkeypatch):
    """
    Regression test (streaming):

    Even when a streaming chunk arrives stamped with `model="hosted_vllm/<...>"`,
    the proxy SSE layer must not leak the provider prefix to the client.
    """
    client_model = "vllm-model"
    internal_model = f"hosted_vllm/{client_model}"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    # Make the streaming iterator hook yield exactly one internally-stamped chunk.
    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=internal_model)

    monkeypatch.setattr(proxy_server.proxy_logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook)
    monkeypatch.setattr(
        proxy_server.proxy_logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=lambda **kwargs: kwargs["response"]),
    )

    gen = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data={"model": client_model},
    )

    received = [item async for item in gen]

    # Expect at least one JSON data chunk followed by the [DONE] sentinel.
    assert len(received) >= 2
    first_chunk = received[0]
    assert first_chunk.startswith("data: ")

    payload = json.loads(first_chunk[len("data: ") :].strip())
    assert payload["model"] == client_model
    assert not payload["model"].startswith("hosted_vllm/")


@pytest.mark.asyncio
async def test_proxy_streaming_chunks_use_client_requested_model_before_alias_mapping(monkeypatch):
    """
    Regression test (alias mapping on streaming):

    - `common_processing_pre_call_logic` can rewrite `request_data["model"]` via
      model_alias_map / key-specific aliases.
    - Non-streaming responses are restamped with the original client-requested
      model (captured before the rewrite).
    - Streaming chunks must match that behavior so streaming and non-streaming
      responses never disagree on the `model` value.
    """
    client_model_alias = "alias-model"
    canonical_model = "vllm-model"
    internal_model = f"hosted_vllm/{canonical_model}"

    from litellm.proxy import proxy_server
    from litellm.proxy._types import UserAPIKeyAuth

    async def _iterator_hook(
        user_api_key_dict: UserAPIKeyAuth,
        response: AsyncGenerator,
        request_data: dict,
    ):
        yield _make_model_response_stream_chunk(model=internal_model)

    monkeypatch.setattr(proxy_server.proxy_logging_obj, "async_post_call_streaming_iterator_hook", _iterator_hook)
    monkeypatch.setattr(
        proxy_server.proxy_logging_obj,
        "async_post_call_streaming_hook",
        AsyncMock(side_effect=lambda **kwargs: kwargs["response"]),
    )

    gen = proxy_server.async_data_generator(
        response=MagicMock(),
        user_api_key_dict=UserAPIKeyAuth(api_key="sk-1234"),
        request_data={
            "model": canonical_model,
            "_litellm_client_requested_model": client_model_alias,
        },
    )

    received = [item async for item in gen]

    assert len(received) >= 2
    first_chunk = received[0]
    assert first_chunk.startswith("data: ")

    payload = json.loads(first_chunk[len("data: ") :].strip())
    assert payload["model"] == client_model_alias
    assert not payload["model"].startswith("hosted_vllm/")