Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 91 additions & 6 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,70 @@ async def combined_generator() -> AsyncGenerator[str, None]:
)


def _override_openai_response_model(
    *,
    response_obj: Any,
    requested_model: str,
    log_context: str,
) -> None:
    """
    Restamp the OpenAI-compatible `model` field of a response with the client-requested model.

    LiteLLM internally prefixes some provider/deployment model identifiers
    (e.g. `hosted_vllm/...`); those internal identifiers should not be returned
    to clients in the OpenAI `model` field.

    Behavior:
    - Empty/falsy `requested_model`: no-op.
    - `dict` responses: `response_obj["model"]` is always set to `requested_model`.
    - Other objects exposing a `model` attribute: the attribute is always set to
      `requested_model`; a failure to set it is logged as an error (with traceback)
      and swallowed.
    - Objects without a `model` attribute: an error is logged and nothing is changed.

    A mismatch between the downstream value and the requested value is logged as a
    warning *before* restamping, so paths that stamp/preserve internal identifiers
    stay observable without breaking client compatibility.

    Args:
        response_obj: The response to mutate in place (dict or attribute-bearing object).
        requested_model: The model name the client sent.
        log_context: Prefix placed at the start of every log line emitted here.
    """
    if not requested_model:
        # Nothing to restamp with.
        return

    is_dict_response = isinstance(response_obj, dict)

    # Read the current (downstream) value and pick the matching warning template.
    if is_dict_response:
        current_model = response_obj.get("model")
        mismatch_template = "%s: response model mismatch - requested=%r downstream=%r. Overriding response['model'] to requested model."
    elif hasattr(response_obj, "model"):
        current_model = getattr(response_obj, "model", None)
        mismatch_template = "%s: response model mismatch - requested=%r downstream=%r. Overriding response.model to requested model."
    else:
        # Neither a dict nor an object with a `model` attribute: cannot override.
        verbose_proxy_logger.error(
            "%s: cannot override response model; missing `model` attribute. response_type=%s",
            log_context,
            type(response_obj),
        )
        return

    if current_model != requested_model:
        verbose_proxy_logger.warning(
            mismatch_template,
            log_context,
            requested_model,
            current_model,
        )

    # The override itself is unconditional (also applied when the values already match).
    if is_dict_response:
        response_obj["model"] = requested_model
        return

    try:
        response_obj.model = requested_model
    except Exception as exc:
        # Best-effort: some response types may reject attribute assignment.
        verbose_proxy_logger.error(
            "%s: failed to override response.model=%r on response_type=%s. error=%s",
            log_context,
            requested_model,
            type(response_obj),
            str(exc),
            exc_info=True,
        )


def _get_cost_breakdown_from_logging_obj(
litellm_logging_obj: Optional[LiteLLMLoggingObj],
) -> Tuple[Optional[float], Optional[float], Optional[float], Optional[float]]:
Expand Down Expand Up @@ -625,6 +689,9 @@ async def base_process_llm_request(
"""
Common request processing logic for both chat completions and responses API endpoints
"""
requested_model_from_client: Optional[str] = (
self.data.get("model") if isinstance(self.data.get("model"), str) else None
)
if verbose_proxy_logger.isEnabledFor(logging.DEBUG):
verbose_proxy_logger.debug(
"Request received by LiteLLM:\n{}".format(
Expand Down Expand Up @@ -690,13 +757,15 @@ async def base_process_llm_request(
model_info = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id", "") or ""

cache_key = hidden_params.get("cache_key", None) or ""
api_base = hidden_params.get("api_base", None) or ""
response_cost = hidden_params.get("response_cost", None) or ""
fastest_response_batch_completion = hidden_params.get(
"fastest_response_batch_completion", None
cache_key, api_base, response_cost = (
hidden_params.get("cache_key", None) or "",
hidden_params.get("api_base", None) or "",
hidden_params.get("response_cost", None) or "",
)
fastest_response_batch_completion, additional_headers = (
hidden_params.get("fastest_response_batch_completion", None),
hidden_params.get("additional_headers", {}) or {},
)
additional_headers: dict = hidden_params.get("additional_headers", {}) or {}

# Post Call Processing
if llm_router is not None:
Expand Down Expand Up @@ -726,6 +795,13 @@ async def base_process_llm_request(
litellm_logging_obj=logging_obj,
**additional_headers,
)

# Preserve the original client-requested model (pre-alias mapping) for downstream
# streaming generators. Pre-call processing can rewrite `self.data["model"]` for
# aliasing/routing, but the OpenAI-compatible response `model` field should reflect
# what the client sent.
if requested_model_from_client:
self.data["_litellm_client_requested_model"] = requested_model_from_client
if route_type == "allm_passthrough_route":
# Check if response is an async generator
if self._is_streaming_response(response):
Expand Down Expand Up @@ -785,6 +861,15 @@ async def base_process_llm_request(
data=self.data, user_api_key_dict=user_api_key_dict, response=response
)

# Always return the client-requested model name (not provider-prefixed internal identifiers)
# for OpenAI-compatible responses.
if requested_model_from_client:
_override_openai_response_model(
response_obj=response,
requested_model=requested_model_from_client,
log_context=f"litellm_call_id={logging_obj.litellm_call_id}",
)

hidden_params = (
getattr(response, "_hidden_params", {}) or {}
) # get any updated response headers
Expand Down
73 changes: 73 additions & 0 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4571,6 +4571,68 @@ async def async_assistants_data_generator(
yield f"data: {error_returned}\n\n"


def _get_client_requested_model_for_streaming(request_data: dict) -> str:
    """
    Resolve the model name that should be reported back to the client on streaming chunks.

    Pre-call processing can rewrite `request_data["model"]` for aliasing/routing, so the
    original client-requested value stored under `_litellm_client_requested_model` takes
    priority; `model` is the fallback.

    Returns:
        The first of those two entries whose value is a `str` (an empty string counts),
        or `""` when neither holds a string.
    """
    # Keys in priority order: preserved client value first, then the (possibly rewritten) model.
    for key in ("_litellm_client_requested_model", "model"):
        candidate = request_data.get(key)
        if isinstance(candidate, str):
            return candidate
    return ""


def _restamp_streaming_chunk_model(
    *,
    chunk: Any,
    requested_model_from_client: str,
    request_data: dict,
    model_mismatch_logged: bool,
) -> Tuple[Any, bool]:
    """
    Set the `model` field of a streaming chunk to the client-requested model name.

    Streaming chunks should always carry the client-requested model name rather than a
    provider-prefixed internal identifier. A mismatch warning is emitted at most once per
    stream (tracked through `model_mismatch_logged`); it is intentionally verbose because
    a mismatch signals an internal identifier leaking into the public API.

    Args:
        chunk: The streamed item. Only `BaseModel` instances and dicts are modified
            (in place); anything else is returned untouched.
        requested_model_from_client: Model name to stamp; falsy disables restamping.
        request_data: Request payload; only `litellm_call_id` is read, for log context.
        model_mismatch_logged: Whether the mismatch warning was already emitted.

    Returns:
        `(chunk, model_mismatch_logged)` where the flag is `True` once a warning
        has been logged.
    """
    # Guard clauses: nothing to stamp, or a chunk type we do not touch.
    if not requested_model_from_client:
        return chunk, model_mismatch_logged
    if not isinstance(chunk, (BaseModel, dict)):
        return chunk, model_mismatch_logged

    chunk_is_dict = isinstance(chunk, dict)
    if chunk_is_dict:
        current_model = chunk.get("model")
    else:
        current_model = getattr(chunk, "model", None)

    mismatch_reported = model_mismatch_logged
    if not mismatch_reported and current_model != requested_model_from_client:
        verbose_proxy_logger.warning(
            "litellm_call_id=%s: streaming chunk model mismatch - requested=%r downstream=%r. Overriding model to requested.",
            request_data.get("litellm_call_id"),
            requested_model_from_client,
            current_model,
        )
        mismatch_reported = True

    # The restamp is unconditional, even when the values already agree.
    if chunk_is_dict:
        chunk["model"] = requested_model_from_client
    else:
        try:
            chunk.model = requested_model_from_client
        except Exception as exc:
            # Best-effort: the chunk is still returned if assignment is rejected.
            verbose_proxy_logger.error(
                "litellm_call_id=%s: failed to override chunk.model=%r on chunk_type=%s. error=%s",
                request_data.get("litellm_call_id"),
                requested_model_from_client,
                type(chunk),
                str(exc),
                exc_info=True,
            )

    return chunk, mismatch_reported


async def async_data_generator(
response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
):
Expand All @@ -4579,6 +4641,10 @@ async def async_data_generator(
# Use a list to accumulate response segments to avoid O(n^2) string concatenation
str_so_far_parts: list[str] = []
error_message: Optional[str] = None
requested_model_from_client = _get_client_requested_model_for_streaming(
request_data=request_data
)
model_mismatch_logged = False
async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
response=response,
Expand All @@ -4600,6 +4666,13 @@ async def async_data_generator(
response_str = litellm.get_response_string(response_obj=chunk)
str_so_far_parts.append(response_str)

chunk, model_mismatch_logged = _restamp_streaming_chunk_model(
chunk=chunk,
requested_model_from_client=requested_model_from_client,
request_data=request_data,
model_mismatch_logged=model_mismatch_logged,
)

if isinstance(chunk, BaseModel):
chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
elif isinstance(chunk, str) and chunk.startswith("data: "):
Expand Down
Loading
Loading