diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index c8892cd26a5..6fecc7fa976 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -789,6 +789,7 @@ def completion_cost( # noqa: PLR0915 from litellm.llms.recraft.cost_calculator import ( cost_calculator as recraft_image_cost_calculator, ) + return recraft_image_cost_calculator( model=model, image_response=completion_response, @@ -797,6 +798,7 @@ def completion_cost( # noqa: PLR0915 from litellm.llms.gemini.image_generation.cost_calculator import ( cost_calculator as gemini_image_cost_calculator, ) + return gemini_image_cost_calculator( model=model, image_response=completion_response, @@ -867,7 +869,10 @@ def completion_cost( # noqa: PLR0915 from litellm.proxy._experimental.mcp_server.cost_calculator import ( MCPCostCalculator, ) - return MCPCostCalculator.calculate_mcp_tool_call_cost(litellm_logging_obj=litellm_logging_obj) + + return MCPCostCalculator.calculate_mcp_tool_call_cost( + litellm_logging_obj=litellm_logging_obj + ) # Calculate cost based on prompt_tokens, completion_tokens if ( "togethercomputer" in model @@ -1318,7 +1323,7 @@ def combine_usage_objects(usage_objects: List[Usage]) -> Usage: combined.completion_tokens_details = CompletionTokensDetails() # Check what keys exist in the model's completion_tokens_details - for attr in dir(usage.completion_tokens_details): + for attr in usage.completion_tokens_details.model_fields: if not attr.startswith("_") and not callable( getattr(usage.completion_tokens_details, attr) ): @@ -1326,7 +1331,8 @@ def combine_usage_objects(usage_objects: List[Usage]) -> Usage: combined.completion_tokens_details, attr, 0 ) new_val = getattr(usage.completion_tokens_details, attr, 0) - if new_val is not None: + + if new_val is not None and current_val is not None: setattr( combined.completion_tokens_details, attr, diff --git a/litellm/router.py b/litellm/router.py index 2673df1504b..ca724006fdf 100644 --- a/litellm/router.py +++ 
b/litellm/router.py @@ -23,6 +23,7 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, Callable, Dict, List, @@ -146,7 +147,7 @@ from litellm.types.utils import GenericBudgetConfigType, LiteLLMBatch from litellm.types.utils import ModelInfo from litellm.types.utils import ModelInfo as ModelMapInfo -from litellm.types.utils import StandardLoggingPayload +from litellm.types.utils import ModelResponseStream, StandardLoggingPayload, Usage from litellm.utils import ( CustomStreamWrapper, EmbeddingResponse, @@ -1092,7 +1093,7 @@ async def _acompletion_streaming_iterator( from litellm.exceptions import MidStreamFallbackError class FallbackStreamWrapper(CustomStreamWrapper): - def __init__(self, async_generator): + def __init__(self, async_generator: AsyncGenerator): # Copy attributes from the original model_response super().__init__( completion_stream=async_generator, @@ -1103,7 +1104,7 @@ def __init__(self, async_generator): self._async_generator = async_generator def __aiter__(self): - return self._async_generator + return self async def __anext__(self): return await self._async_generator.__anext__() @@ -1113,6 +1114,15 @@ async def stream_with_fallbacks(): async for item in model_response: yield item except MidStreamFallbackError as e: + from litellm.main import stream_chunk_builder + + complete_response_object = stream_chunk_builder( + chunks=model_response.chunks + ) + complete_response_object_usage = cast( + Optional[Usage], + getattr(complete_response_object, "usage", None), + ) try: # Use the router's fallback system model_group = cast(str, initial_kwargs.get("model")) @@ -1156,6 +1166,37 @@ async def stream_with_fallbacks(): # If fallback returns a streaming response, iterate over it if hasattr(fallback_response, "__aiter__"): async for fallback_item in fallback_response: # type: ignore + if ( + fallback_item + and isinstance(fallback_item, ModelResponseStream) + and hasattr(fallback_item, "usage") + ): + from litellm.cost_calculator import ( + 
BaseTokenUsageProcessor, ) usage = cast( Optional[Usage], getattr(fallback_item, "usage", None), ) if usage is not None: usage_objects = [usage] else: usage_objects = [] if complete_response_object_usage is not None: usage_objects.append(complete_response_object_usage) combined_usage = ( BaseTokenUsageProcessor.combine_usage_objects( usage_objects=usage_objects ) ) setattr(fallback_item, "usage", combined_usage) yield fallback_item else: # If fallback returns a non-streaming response, yield None @@ -1166,7 +1207,7 @@ async def stream_with_fallbacks(): verbose_router_logger.error( f"Fallback also failed: {fallback_error}" ) - raise e # Re-raise the original error + raise fallback_error from e return FallbackStreamWrapper(stream_with_fallbacks())