litellm/cost_calculator.py (56 changes: 42 additions & 14 deletions)
@@ -23,7 +23,11 @@
 from litellm.litellm_core_utils.llm_cost_calc.utils import (
     CostCalculatorUtils,
     _generic_cost_per_character,
+    _get_service_tier_cost_key,
+    _parse_prompt_tokens_details,
+    calculate_cost_component,
     generic_cost_per_token,
+    get_billable_input_tokens,
     select_cost_metric_for_model,
 )
 from litellm.llms.anthropic.cost_calculation import (
@@ -427,12 +431,18 @@ def cost_per_token(  # noqa: PLR0915
         model=model, custom_llm_provider=custom_llm_provider
     )
 
-    if model_info["input_cost_per_token"] > 0:
-        ## COST PER TOKEN ##
-        prompt_tokens_cost_usd_dollar = (
-            model_info["input_cost_per_token"] * prompt_tokens
-        )
-    elif (
+    if (
+        model_info.get("input_cost_per_token", 0) > 0
+        or model_info.get("output_cost_per_token", 0) > 0
+    ):
+        return generic_cost_per_token(
+            model=model,
+            usage=usage_block,
+            custom_llm_provider=custom_llm_provider,
+            service_tier=service_tier,
+        )
+
+    if (
         model_info.get("input_cost_per_second", None) is not None
         and response_time_ms is not None
     ):
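Note on the early return above: any model with per-token pricing now delegates to `generic_cost_per_token`, which already handles service tiers and cached-token discounts, instead of the inline multiplication that lived here before. A quick sketch of a caller hitting this path (illustrative values; `litellm.cost_per_token` returns a `(prompt_cost, completion_cost)` tuple):

```python
import litellm

# Models priced per token now route through generic_cost_per_token internally.
prompt_cost, completion_cost = litellm.cost_per_token(
    model="gpt-4o-mini",
    prompt_tokens=1_000,
    completion_tokens=200,
)
print(f"prompt: ${prompt_cost:.6f}, completion: ${completion_cost:.6f}")
```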
@@ -447,11 +457,7 @@ def cost_per_token(  # noqa: PLR0915
             model_info["input_cost_per_second"] * response_time_ms / 1000  # type: ignore
         )
 
-    if model_info["output_cost_per_token"] > 0:
-        completion_tokens_cost_usd_dollar = (
-            model_info["output_cost_per_token"] * completion_tokens
-        )
-    elif (
+    if (
         model_info.get("output_cost_per_second", None) is not None
         and response_time_ms is not None
     ):
@@ -951,7 +957,10 @@ def completion_cost(  # noqa: PLR0915
         router_model_id=router_model_id,
     )
 
-    potential_model_names = [selected_model, _get_response_model(completion_response)]
+    potential_model_names = [
+        selected_model,
+        _get_response_model(completion_response),
+    ]
     if model is not None:
         potential_model_names.append(model)
 
@@ -1706,10 +1715,16 @@ def default_image_cost_calculator(
     )
 
     # Priority 1: Use per-image pricing if available (for gpt-image-1 and similar models)
-    if "input_cost_per_image" in cost_info and cost_info["input_cost_per_image"] is not None:
+    if (
+        "input_cost_per_image" in cost_info
+        and cost_info["input_cost_per_image"] is not None
+    ):
         return cost_info["input_cost_per_image"] * n
     # Priority 2: Fall back to per-pixel pricing for backward compatibility
-    elif "input_cost_per_pixel" in cost_info and cost_info["input_cost_per_pixel"] is not None:
+    elif (
+        "input_cost_per_pixel" in cost_info
+        and cost_info["input_cost_per_pixel"] is not None
+    ):
         return cost_info["input_cost_per_pixel"] * height * width * n
     else:
         raise Exception(
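For reference, the two branches behave like this (standalone sketch: only the two pricing keys come from litellm's cost_info, the prices are made up):

```python
def image_cost(cost_info: dict, width: int, height: int, n: int) -> float:
    # Priority 1: flat per-image pricing (gpt-image-1 style).
    if cost_info.get("input_cost_per_image") is not None:
        return cost_info["input_cost_per_image"] * n
    # Priority 2: legacy per-pixel pricing.
    if cost_info.get("input_cost_per_pixel") is not None:
        return cost_info["input_cost_per_pixel"] * height * width * n
    raise ValueError("no image pricing configured")

print(image_cost({"input_cost_per_image": 0.02}, 1024, 1024, 2))  # 0.04
print(image_cost({"input_cost_per_pixel": 1e-8}, 1024, 1024, 2))  # ~0.0210
```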
@@ -1829,9 +1844,22 @@ def batch_cost_calculator(
     if input_cost_per_token_batches:
         total_prompt_cost = usage.prompt_tokens * input_cost_per_token_batches
     elif input_cost_per_token:
+        # Subtract cached tokens from prompt_tokens before calculating cost
+        # Fixes issue where cached tokens are being charged again
         total_prompt_cost = (
-            usage.prompt_tokens * (input_cost_per_token) / 2
+            get_billable_input_tokens(usage) * (input_cost_per_token) / 2
         )  # batch cost is usually half of the regular token cost
+
+        # Add cache read cost if applicable
+        details = _parse_prompt_tokens_details(usage)
+        cache_read_tokens = details["cache_hit_tokens"]
+        cache_read_cost_key = _get_service_tier_cost_key(
+            "cache_read_input_token_cost", None
+        )
+        total_prompt_cost += (
+            calculate_cost_component(model_info, cache_read_cost_key, cache_read_tokens)
+            / 2
+        )
     if output_cost_per_token_batches:
         total_completion_cost = usage.completion_tokens * output_cost_per_token_batches
     elif output_cost_per_token:
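Worked example of the corrected batch input-cost math (plain numbers stand in for litellm's `Usage`/`ModelInfo`; note the 50% batch discount applies to both the full-price tokens and the cache-read charge, mirroring the two `/ 2` terms above):

```python
prompt_tokens = 1_000
cached_tokens = 400                 # cache hits from usage.prompt_tokens_details
input_cost_per_token = 2e-6         # $2.00 / 1M input tokens (illustrative)
cache_read_cost_per_token = 5e-7    # $0.50 / 1M cached tokens (illustrative)

billable = prompt_tokens - cached_tokens                # get_billable_input_tokens
cost = billable * input_cost_per_token / 2              # batch price: half of regular
cost += cached_tokens * cache_read_cost_per_token / 2   # cache reads, also halved

print(f"${cost:.6f}")  # 600*2e-6/2 + 400*5e-7/2 = $0.000700
```

Before this fix, all 1,000 prompt tokens were charged at the full input rate, double-billing the 400 cached tokens.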
litellm/litellm_core_utils/llm_cost_calc/utils.py (16 changes: 14 additions & 2 deletions)
@@ -23,6 +23,15 @@ def _is_above_128k(tokens: float) -> bool:
     return False
 
 
+def get_billable_input_tokens(usage: Usage) -> int:
+    """
+    Returns the number of billable input tokens.
+    Subtracts cached tokens from prompt tokens if applicable.
+    """
+    details = _parse_prompt_tokens_details(usage)
+    return usage.prompt_tokens - details["cache_hit_tokens"]
+
+
 def select_cost_metric_for_model(
     model_info: ModelInfo,
 ) -> Literal["cost_per_character", "cost_per_token"]:
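The rule the new helper encodes, as a self-contained sketch (`_parse_prompt_tokens_details` returns a dict whose `cache_hit_tokens` entry counts cached prompt tokens; the standalone function below is illustrative, not the litellm implementation):

```python
def billable_input_tokens(prompt_tokens: int, cache_hit_tokens: int) -> int:
    # Cache hits are billed separately at the cache-read rate, so they are
    # excluded from the tokens charged at the full input price.
    return prompt_tokens - cache_hit_tokens

assert billable_input_tokens(1_000, 400) == 600
assert billable_input_tokens(1_000, 0) == 1_000  # no cache usage: unchanged
```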
@@ -190,7 +199,6 @@ def _get_token_base_cost(
                 1000 if "k" in threshold_str else 1
             )
             if usage.prompt_tokens > threshold:
-
                 prompt_base_cost = cast(
                     float, _get_cost_per_unit(model_info, key, prompt_base_cost)
                 )
@@ -619,7 +627,11 @@ def generic_cost_per_token(  # noqa: PLR0915
         # Calculate text tokens as remainder when we have a breakdown
         # This handles cases like OpenAI's reasoning models where text_tokens isn't provided
         text_tokens = max(
-            0, usage.completion_tokens - reasoning_tokens - audio_tokens - image_tokens
+            0,
+            usage.completion_tokens
+            - reasoning_tokens
+            - audio_tokens
+            - image_tokens,
         )
     else:
         # No breakdown at all, all tokens are text tokens
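The remainder rule in the last hunk, with concrete numbers (illustrative values for a reasoning-model response where the provider reports reasoning tokens but no explicit text-token count):

```python
completion_tokens = 500
reasoning_tokens, audio_tokens, image_tokens = 350, 0, 0

# Text tokens are whatever is left over; max(0, ...) guards against
# providers whose detail counts exceed the completion total.
text_tokens = max(
    0, completion_tokens - reasoning_tokens - audio_tokens - image_tokens
)
assert text_tokens == 150
```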