Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 31 additions & 86 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from litellm.proxy.route_llm_request import route_request
from litellm.proxy.utils import ProxyLogging
from litellm.router import Router
from litellm.types.utils import ServerToolUse
from litellm.types.utils import ServerToolUse, LlmProvidersSet

# Type alias for streaming chunk serializer (chunk after hooks + cost injection -> wire format)
StreamChunkSerializer = Callable[[Any], str]
Expand Down Expand Up @@ -248,85 +248,44 @@ async def combined_generator() -> AsyncGenerator[str, None]:
def _override_openai_response_model(
*,
response_obj: Any,
requested_model: str,
log_context: str,
) -> None:
"""
Force the OpenAI-compatible `model` field in the response to match what the client requested.
Strip known LiteLLM provider prefixes (e.g. hosted_vllm/) from the response model field.

LiteLLM internally prefixes some provider/deployment model identifiers (e.g. `hosted_vllm/...`).
That internal identifier should not be returned to clients in the OpenAI `model` field.

Note: This is intentionally verbose. A model mismatch is a useful signal that an internal
model identifier is being stamped/preserved somewhere in the request/response pipeline.
We log mismatches as warnings (and then restamp to the client-requested value) so these
paths stay observable for maintainers/operators without breaking client compatibility.

Errors are reserved for cases where the proxy cannot read/override the response model field.

Exception: If a fallback occurred (indicated by x-litellm-attempted-fallbacks header),
we should preserve the actual model that was used (the fallback model) rather than
overriding it with the originally requested model.
Previously this replaced response.model with the client-requested alias, but that
hid the actual model name from callers (see #21665). Now we only strip internal
provider routing prefixes, preserving the real model name.
"""
if not requested_model:
if isinstance(response_obj, dict):
downstream_model = response_obj.get("model")
elif hasattr(response_obj, "model"):
downstream_model = getattr(response_obj, "model", None)
else:
return

# Check if a fallback occurred - if so, preserve the actual model used
hidden_params = getattr(response_obj, "_hidden_params", {}) or {}
if isinstance(hidden_params, dict):
fallback_headers = hidden_params.get("additional_headers", {}) or {}
attempted_fallbacks = fallback_headers.get(
"x-litellm-attempted-fallbacks", None
)
if attempted_fallbacks is not None and attempted_fallbacks > 0:
# A fallback occurred - preserve the actual model that was used
verbose_proxy_logger.debug(
"%s: fallback detected (attempted_fallbacks=%d), preserving actual model used instead of overriding to requested model.",
log_context,
attempted_fallbacks,
)
return
if not downstream_model or not isinstance(downstream_model, str):
return

if isinstance(response_obj, dict):
downstream_model = response_obj.get("model")
if downstream_model != requested_model:
verbose_proxy_logger.debug(
"%s: response model mismatch - requested=%r downstream=%r. Overriding response['model'] to requested model.",
log_context,
requested_model,
downstream_model,
)
response_obj["model"] = requested_model
if "/" not in downstream_model:
return

if not hasattr(response_obj, "model"):
verbose_proxy_logger.error(
"%s: cannot override response model; missing `model` attribute. response_type=%s",
log_context,
type(response_obj),
)
prefix = downstream_model.split("/", 1)[0]
if prefix not in LlmProvidersSet:
return

downstream_model = getattr(response_obj, "model", None)
if downstream_model != requested_model:
verbose_proxy_logger.debug(
"%s: response model mismatch - requested=%r downstream=%r. Overriding response.model to requested model.",
log_context,
requested_model,
downstream_model,
)
stripped = downstream_model.split("/", 1)[1]

try:
setattr(response_obj, "model", requested_model)
except Exception as e:
verbose_proxy_logger.error(
"%s: failed to override response.model=%r on response_type=%s. error=%s",
log_context,
requested_model,
type(response_obj),
str(e),
exc_info=True,
)
if isinstance(response_obj, dict):
response_obj["model"] = stripped
else:
try:
setattr(response_obj, "model", stripped)
except Exception as e:
verbose_proxy_logger.debug(
"%s: failed to strip provider prefix on response.model, error=%s",
log_context, str(e),
)


def _get_cost_breakdown_from_logging_obj(
Expand Down Expand Up @@ -809,9 +768,6 @@ async def base_process_llm_request(
"""
Common request processing logic for both chat completions and responses API endpoints
"""
requested_model_from_client: Optional[str] = (
self.data.get("model") if isinstance(self.data.get("model"), str) else None
)
self._debug_log_request_payload()

self.data, logging_obj = await self.common_processing_pre_call_logic(
Expand Down Expand Up @@ -918,14 +874,6 @@ async def base_process_llm_request(
if callback_headers:
custom_headers.update(callback_headers)

# Preserve the original client-requested model (pre-alias mapping) for downstream
# streaming generators. Pre-call processing can rewrite `self.data["model"]` for
# aliasing/routing, but the OpenAI-compatible response `model` field should reflect
# what the client sent.
if requested_model_from_client:
self.data[
"_litellm_client_requested_model"
] = requested_model_from_client
if route_type == "allm_passthrough_route":
# Check if response is an async generator
if self._is_streaming_response(response):
Expand Down Expand Up @@ -985,14 +933,11 @@ async def base_process_llm_request(
data=self.data, user_api_key_dict=user_api_key_dict, response=response
)

# Always return the client-requested model name (not provider-prefixed internal identifiers)
# for OpenAI-compatible responses.
if requested_model_from_client:
_override_openai_response_model(
response_obj=response,
requested_model=requested_model_from_client,
log_context=f"litellm_call_id={logging_obj.litellm_call_id}",
)
# Strip any internal provider prefixes from the response model field.
_override_openai_response_model(
response_obj=response,
log_context=f"litellm_call_id={logging_obj.litellm_call_id}",
)

hidden_params = (
getattr(response, "_hidden_params", {}) or {}
Expand Down
71 changes: 25 additions & 46 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
ModelResponseStream,
TextCompletionResponse,
TokenCountResponse,
LlmProvidersSet,
)
from litellm.utils import (
_invalidate_model_cost_lowercase_map,
Expand Down Expand Up @@ -5229,64 +5230,46 @@ async def async_assistants_data_generator(
yield f"data: {error_returned}\n\n"


def _get_client_requested_model_for_streaming(request_data: dict) -> str:
"""
Prefer the original client-requested model (pre-alias mapping) when available.

Pre-call processing can rewrite `request_data["model"]` for aliasing/routing purposes.
The OpenAI-compatible public `model` field should reflect what the client sent.
"""
requested_model = request_data.get("_litellm_client_requested_model")
if isinstance(requested_model, str):
return requested_model

requested_model = request_data.get("model")
return requested_model if isinstance(requested_model, str) else ""


def _restamp_streaming_chunk_model(
*,
chunk: Any,
requested_model_from_client: str,
request_data: dict,
model_mismatch_logged: bool,
) -> Tuple[Any, bool]:
# Always return the client-requested model name (not provider-prefixed internal identifiers)
# on streaming chunks.
#
# Note: This warning is intentionally verbose. A mismatch is a useful signal that an
# internal provider/deployment identifier is leaking into the public API, and helps
# maintainers/operators catch regressions while preserving OpenAI-compatible output.
if not requested_model_from_client or not isinstance(chunk, (BaseModel, dict)):
"""Strip known provider prefixes from streaming chunk model field."""
if not isinstance(chunk, (BaseModel, dict)):
return chunk, model_mismatch_logged

downstream_model = (
chunk.get("model") if isinstance(chunk, dict) else getattr(chunk, "model", None)
)
if not model_mismatch_logged and downstream_model != requested_model_from_client:

if not downstream_model or not isinstance(downstream_model, str) or "/" not in downstream_model:
return chunk, model_mismatch_logged

prefix = downstream_model.split("/", 1)[0]
if prefix not in LlmProvidersSet:
return chunk, model_mismatch_logged

stripped = downstream_model.split("/", 1)[1]

if not model_mismatch_logged:
verbose_proxy_logger.debug(
"litellm_call_id=%s: streaming chunk model mismatch - requested=%r downstream=%r. Overriding model to requested.",
request_data.get("litellm_call_id"),
requested_model_from_client,
downstream_model,
"litellm_call_id=%s: stripping provider prefix %r from chunk model %r",
request_data.get("litellm_call_id"), prefix, downstream_model,
)
model_mismatch_logged = True

if isinstance(chunk, dict):
chunk["model"] = requested_model_from_client
return chunk, model_mismatch_logged

try:
setattr(chunk, "model", requested_model_from_client)
except Exception as e:
verbose_proxy_logger.error(
"litellm_call_id=%s: failed to override chunk.model=%r on chunk_type=%s. error=%s",
request_data.get("litellm_call_id"),
requested_model_from_client,
type(chunk),
str(e),
exc_info=True,
)
chunk["model"] = stripped
else:
try:
setattr(chunk, "model", stripped)
except Exception as e:
verbose_proxy_logger.debug(
"litellm_call_id=%s: failed to strip provider prefix on chunk.model, error=%s",
request_data.get("litellm_call_id"), str(e),
)

return chunk, model_mismatch_logged

Expand All @@ -5299,9 +5282,6 @@ async def async_data_generator(
# Use a list to accumulate response segments to avoid O(n^2) string concatenation
str_so_far_parts: list[str] = []
error_message: Optional[str] = None
requested_model_from_client = _get_client_requested_model_for_streaming(
request_data=request_data
)
model_mismatch_logged = False
async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
Expand All @@ -5322,7 +5302,6 @@ async def async_data_generator(

chunk, model_mismatch_logged = _restamp_streaming_chunk_model(
chunk=chunk,
requested_model_from_client=requested_model_from_client,
request_data=request_data,
model_mismatch_logged=model_mismatch_logged,
)
Expand Down
Loading
Loading