Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions litellm/cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,7 @@ def completion_cost( # noqa: PLR0915
from litellm.llms.recraft.cost_calculator import (
cost_calculator as recraft_image_cost_calculator,
)

return recraft_image_cost_calculator(
model=model,
image_response=completion_response,
Expand All @@ -797,6 +798,7 @@ def completion_cost( # noqa: PLR0915
from litellm.llms.gemini.image_generation.cost_calculator import (
cost_calculator as gemini_image_cost_calculator,
)

return gemini_image_cost_calculator(
model=model,
image_response=completion_response,
Expand Down Expand Up @@ -867,7 +869,10 @@ def completion_cost( # noqa: PLR0915
from litellm.proxy._experimental.mcp_server.cost_calculator import (
MCPCostCalculator,
)
return MCPCostCalculator.calculate_mcp_tool_call_cost(litellm_logging_obj=litellm_logging_obj)

return MCPCostCalculator.calculate_mcp_tool_call_cost(
litellm_logging_obj=litellm_logging_obj
)
# Calculate cost based on prompt_tokens, completion_tokens
if (
"togethercomputer" in model
Expand Down Expand Up @@ -1318,15 +1323,16 @@ def combine_usage_objects(usage_objects: List[Usage]) -> Usage:
combined.completion_tokens_details = CompletionTokensDetails()

# Check what keys exist in the model's completion_tokens_details
for attr in dir(usage.completion_tokens_details):
for attr in usage.completion_tokens_details.model_fields:
if not attr.startswith("_") and not callable(
getattr(usage.completion_tokens_details, attr)
):
current_val = getattr(
combined.completion_tokens_details, attr, 0
)
new_val = getattr(usage.completion_tokens_details, attr, 0)
if new_val is not None:

if new_val is not None and current_val is not None:
setattr(
combined.completion_tokens_details,
attr,
Expand Down
49 changes: 45 additions & 4 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Callable,
Dict,
List,
Expand Down Expand Up @@ -146,7 +147,7 @@
from litellm.types.utils import GenericBudgetConfigType, LiteLLMBatch
from litellm.types.utils import ModelInfo
from litellm.types.utils import ModelInfo as ModelMapInfo
from litellm.types.utils import StandardLoggingPayload
from litellm.types.utils import ModelResponseStream, StandardLoggingPayload, Usage
from litellm.utils import (
CustomStreamWrapper,
EmbeddingResponse,
Expand Down Expand Up @@ -1092,7 +1093,7 @@ async def _acompletion_streaming_iterator(
from litellm.exceptions import MidStreamFallbackError

class FallbackStreamWrapper(CustomStreamWrapper):
def __init__(self, async_generator):
def __init__(self, async_generator: AsyncGenerator):
# Copy attributes from the original model_response
super().__init__(
completion_stream=async_generator,
Expand All @@ -1103,7 +1104,7 @@ def __init__(self, async_generator):
self._async_generator = async_generator

def __aiter__(self):
return self._async_generator
return self

async def __anext__(self):
return await self._async_generator.__anext__()
Expand All @@ -1113,6 +1114,15 @@ async def stream_with_fallbacks():
async for item in model_response:
yield item
except MidStreamFallbackError as e:
from litellm.main import stream_chunk_builder

complete_response_object = stream_chunk_builder(
chunks=model_response.chunks
)
complete_response_object_usage = cast(
Optional[Usage],
getattr(complete_response_object, "usage", None),
)
try:
# Use the router's fallback system
model_group = cast(str, initial_kwargs.get("model"))
Expand Down Expand Up @@ -1156,6 +1166,37 @@ async def stream_with_fallbacks():
# If fallback returns a streaming response, iterate over it
if hasattr(fallback_response, "__aiter__"):
async for fallback_item in fallback_response: # type: ignore
if (
fallback_item
and isinstance(fallback_item, ModelResponseStream)
and hasattr(fallback_item, "usage")
):
from litellm.cost_calculator import (
BaseTokenUsageProcessor,
)

usage = cast(
Optional[Usage],
getattr(fallback_item, "usage", None),
)
if usage is not None:
usage_objects = [usage]
else:
usage_objects = []

if (
complete_response_object_usage is not None
and hasattr(complete_response_object_usage, "usage")
and complete_response_object_usage.usage is not None # type: ignore
):
usage_objects.append(complete_response_object_usage)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Incorrect Attribute Access in Nested Object Check

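The code incorrectly checks for a nested `.usage` attribute on `complete_response_object_usage`. This variable has already been cast to `Optional[Usage]`, meaning it is the usage object itself, not an object that contains a `usage` attribute. As a result, the `hasattr(..., "usage")` condition will likely always be false, so the usage built from the chunks streamed before the fallback is never appended to `usage_objects` and is left out of the combined usage. Checking only `complete_response_object_usage is not None` would be sufficient.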

Locations (1)
Fix in Cursor Fix in Web


combined_usage = (
BaseTokenUsageProcessor.combine_usage_objects(
usage_objects=usage_objects
)
)
setattr(fallback_item, "usage", combined_usage)
yield fallback_item
else:
# If fallback returns a non-streaming response, yield None
Expand All @@ -1166,7 +1207,7 @@ async def stream_with_fallbacks():
verbose_router_logger.error(
f"Fallback also failed: {fallback_error}"
)
raise e # Re-raise the original error
raise fallback_error

return FallbackStreamWrapper(stream_with_fallbacks())

Expand Down
Loading