Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion litellm/litellm_core_utils/core_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# What is this?
## Helper utilities
from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union

import httpx

Expand Down Expand Up @@ -189,6 +189,58 @@ def get_litellm_metadata_from_kwargs(kwargs: dict):
return {}


# Sentinel distinguishing "key absent" from an explicitly-set value
# (including an explicitly empty dict or None).
_MISSING_MODEL_INFO = object()


def get_model_info_from_litellm_params(
    litellm_params: Optional[dict],
) -> Dict[str, Any]:
    """
    Read deployment model_info from supported litellm_params locations.

    metadata.model_info takes precedence, including an explicitly empty dict.
    """
    if litellm_params is None:
        return {}

    # Precedence order: metadata.model_info, then top-level model_info,
    # then litellm_metadata.model_info. The first container that actually
    # carries the key wins — even when its value is empty or non-dict.
    lookup_chain = (
        litellm_params.get("metadata", {}) or {},
        litellm_params,
        litellm_params.get("litellm_metadata", {}) or {},
    )
    for container in lookup_chain:
        if container is litellm_params:
            # Top-level model_info only applies when the key is present.
            if "model_info" not in container:
                continue
            candidate = container["model_info"]
        else:
            candidate = container.get("model_info", _MISSING_MODEL_INFO)
            if candidate is _MISSING_MODEL_INFO:
                continue
        # Non-dict values (None, strings, ...) degrade to an empty dict.
        return candidate if isinstance(candidate, dict) else {}

    return {}


def merge_metadata_preserving_deployment_model_info(
    litellm_metadata: Optional[dict],
    user_metadata: Optional[dict],
    model_info: Optional[dict] = None,
) -> dict:
    """
    Merge user metadata into LiteLLM metadata while preserving deployment model_info.

    Router-provided deployment model_info should win over any user-supplied
    metadata.model_info, but direct callers may still pass model_info top-level.
    """
    merged = dict(litellm_metadata or {})

    # Remember whether LiteLLM metadata carried model_info at all — a key
    # explicitly set to None (or {}) must still override the user's value.
    had_deployment_model_info = "model_info" in merged
    deployment_model_info = merged.pop("model_info", None)

    merged.update(user_metadata or {})

    if had_deployment_model_info:
        # Re-apply deployment model_info on top of whatever the user sent.
        merged["model_info"] = deployment_model_info
    elif model_info is not None and "model_info" not in merged:
        # Fallback for direct callers passing model_info top-level.
        merged["model_info"] = model_info

    return merged


def reconstruct_model_name(
model_name: str,
custom_llm_provider: Optional[str],
Expand Down
48 changes: 31 additions & 17 deletions litellm/litellm_core_utils/litellm_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@
from litellm.integrations.deepeval.deepeval import DeepEvalLogger
from litellm.integrations.mlflow import MlflowLogger
from litellm.integrations.sqs import SQSLogger
from litellm.litellm_core_utils.core_helpers import reconstruct_model_name
from litellm.litellm_core_utils.core_helpers import (
get_model_info_from_litellm_params,
reconstruct_model_name,
)
from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
StandardBuiltInToolCostTracking,
Expand Down Expand Up @@ -1648,8 +1651,12 @@ def _process_hidden_params_and_response_cost(
# handlers like Gemini/Vertex which call completion_cost directly)
pass
else:
model_info = get_model_info_from_litellm_params(
self.model_call_details.get("litellm_params", {}) or {}
)
self.model_call_details["response_cost"] = self._response_cost_calculator(
result=logging_result
result=logging_result,
router_model_id=model_info.get("id"),
)

self.model_call_details[
Expand Down Expand Up @@ -1943,12 +1950,18 @@ def success_handler( # noqa: PLR0915
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
model_info = get_model_info_from_litellm_params(
self.model_call_details.get("litellm_params", {}) or {}
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response,
router_model_id=model_info.get("id"),
)
)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details[
"standard_logging_object"
Expand Down Expand Up @@ -2468,11 +2481,14 @@ async def async_success_handler( # noqa: PLR0915
_get_base_model_from_metadata(
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
model_info = get_model_info_from_litellm_params(
self.model_call_details.get("litellm_params", {}) or {}
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response,
router_model_id=model_info.get("id"),
)
)

verbose_logger.debug(
Expand Down Expand Up @@ -4446,10 +4462,8 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool:
if litellm_params.get(key) is not None:
return True

# Check model_info
metadata: dict = litellm_params.get("metadata", {}) or {}
model_info: dict = metadata.get("model_info", {}) or {}

# Check model_info in all supported locations.
model_info = get_model_info_from_litellm_params(litellm_params)
if model_info:
matching_keys = _CUSTOM_PRICING_KEYS & model_info.keys()
for key in matching_keys:
Expand Down
11 changes: 10 additions & 1 deletion litellm/llms/custom_httpx/llm_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
update_headers_with_filtered_beta,
)
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
from litellm.litellm_core_utils.core_helpers import (
merge_metadata_preserving_deployment_model_info,
)
from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
Expand Down Expand Up @@ -1884,14 +1887,20 @@ async def async_anthropic_messages_handler(
headers=headers, provider=custom_llm_provider
)

merged_metadata = merge_metadata_preserving_deployment_model_info(
litellm_metadata=kwargs.get("litellm_metadata"),
user_metadata=kwargs.get("metadata"),
model_info=kwargs.get("model_info"),
)

logging_obj.update_environment_variables(
model=model,
optional_params=dict(anthropic_messages_optional_request_params),
litellm_params={
"metadata": kwargs.get("metadata", {}),
"preset_cache_key": None,
"stream_response": {},
**anthropic_messages_optional_request_params,
"metadata": merged_metadata,
},
custom_llm_provider=custom_llm_provider,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.core_helpers import (
get_model_info_from_litellm_params,
merge_metadata_preserving_deployment_model_info,
)
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
use_custom_pricing_for_model,
)
from litellm.llms.anthropic import get_anthropic_config
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
Expand Down Expand Up @@ -118,6 +125,8 @@ def _create_anthropic_response_logging_payload(
custom_llm_provider = logging_obj.model_call_details.get(
"custom_llm_provider"
)
litellm_params = logging_obj.model_call_details.get("litellm_params", {}) or {}
model_info = get_model_info_from_litellm_params(litellm_params)

# Prepend custom_llm_provider to model if not already present
model_for_cost = model
Expand All @@ -128,10 +137,22 @@ def _create_anthropic_response_logging_payload(
completion_response=litellm_model_response,
model=model_for_cost,
custom_llm_provider=custom_llm_provider,
custom_pricing=use_custom_pricing_for_model(litellm_params),
router_model_id=model_info.get("id"),
litellm_logging_obj=logging_obj,
)

kwargs["response_cost"] = response_cost
kwargs["model"] = model
kwargs.setdefault("litellm_params", {})
raw_logging_metadata = litellm_params.get("metadata")
if raw_logging_metadata is not None:
kwargs["litellm_params"]["metadata"] = (
merge_metadata_preserving_deployment_model_info(
litellm_metadata=raw_logging_metadata,
user_metadata=kwargs["litellm_params"].get("metadata"),
)
)
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
kwargs.get("passthrough_logging_payload")
)
Expand Down
9 changes: 8 additions & 1 deletion litellm/responses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
)
from litellm.constants import request_timeout
from litellm.litellm_core_utils.asyncify import run_async_function
from litellm.litellm_core_utils.core_helpers import (
merge_metadata_preserving_deployment_model_info,
)
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.prompt_templates.common_utils import (
update_responses_input_with_model_file_ids,
Expand Down Expand Up @@ -740,7 +743,11 @@ def responses(

# Pre Call logging - preserve metadata for custom callbacks
# When called from completion bridge (codex models), metadata is in litellm_metadata
metadata_for_callbacks = metadata or kwargs.get("litellm_metadata") or {}
metadata_for_callbacks = merge_metadata_preserving_deployment_model_info(
litellm_metadata=kwargs.get("litellm_metadata"),
user_metadata=metadata,
model_info=kwargs.get("model_info"),
)

litellm_logging_obj.update_environment_variables(
model=model,
Expand Down
Loading
Loading