Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions litellm/cost_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,19 @@ def completion_cost( # noqa: PLR0915
router_model_id=router_model_id,
)

# When base_model overrides model and carries its own provider prefix
# (e.g. base_model="gemini/gemini-2.0-flash" on an anthropic deployment),
# align custom_llm_provider so cost_per_token builds the correct key.
# Skip when custom_pricing is True (base_model is ignored in that path).
_provider_overridden = False
if base_model is not None and selected_model is not None and not custom_pricing:
_parts = selected_model.split("/", 1)
if len(_parts) > 1 and _parts[0] in LlmProvidersSet:
extracted = _parts[0]
if extracted != custom_llm_provider:
custom_llm_provider = extracted
_provider_overridden = True

potential_model_names = [
selected_model,
_get_response_model(completion_response),
Expand Down Expand Up @@ -1174,9 +1187,10 @@ def completion_cost( # noqa: PLR0915

hidden_params = getattr(completion_response, "_hidden_params", None)
if hidden_params is not None:
custom_llm_provider = hidden_params.get(
"custom_llm_provider", custom_llm_provider or None
)
if not _provider_overridden:
custom_llm_provider = hidden_params.get(
"custom_llm_provider", custom_llm_provider or None
)
region_name = hidden_params.get("region_name", region_name)

# For Gemini/Vertex AI responses, trafficType is stored in
Expand Down
74 changes: 74 additions & 0 deletions tests/local_testing/test_completion_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2945,3 +2945,77 @@ def test_batch_cost_calculator():

cost = completion_cost(**args)
assert cost > 0


def test_cost_calculator_base_model_cross_provider():
    """
    Regression test for #22257: when base_model carries a provider prefix
    that differs from the deployment's provider, custom_llm_provider must
    be realigned so cost_per_token builds the correct model key and a
    positive cost is produced.
    """
    call_kwargs = {
        "model": "anthropic/my-custom-deployment",
        "messages": [{"role": "user", "content": "Hello"}],
        "base_model": "gemini/gemini-2.0-flash",
        "mock_response": "Hi there!",
    }
    response = litellm.completion(**call_kwargs)
    # A positive response_cost proves the gemini pricing key was resolved.
    computed_cost = response._hidden_params["response_cost"]
    assert computed_cost > 0


def test_cost_calculator_base_model_cross_provider_direct():
    """
    Unit-level check of completion_cost with a cross-provider base_model
    override: the cost must route to the base_model's provider pricing
    and come out positive.
    """
    from litellm import ModelResponse, Usage

    token_usage = Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30)
    mock_response = ModelResponse(
        id="chatcmpl-test",
        model="gemini/gemini-2.0-flash",
        usage=token_usage,
    )
    computed = completion_cost(
        model="anthropic/my-custom-deployment",
        completion_response=mock_response,
        base_model="gemini/gemini-2.0-flash",
        custom_llm_provider="anthropic",
    )
    assert computed > 0


def test_cost_calculator_base_model_cross_provider_hidden_params_guard():
    """
    Ensure a stale provider stored in the response's hidden_params cannot
    undo the base_model provider override inside completion_cost.
    """
    from litellm import ModelResponse, Usage

    guarded_response = ModelResponse(
        id="chatcmpl-guard",
        model="gemini/gemini-2.0-flash",
        usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
    )
    # Plant the original (wrong) provider where hidden_params is read back.
    guarded_response._hidden_params = {"custom_llm_provider": "anthropic"}

    computed = completion_cost(
        model="anthropic/my-custom-deployment",
        completion_response=guarded_response,
        base_model="gemini/gemini-2.0-flash",
        custom_llm_provider="anthropic",
    )
    assert computed > 0


def test_cost_calculator_base_model_same_provider_no_regression():
    """
    No-regression check: a base_model sharing the deployment's provider
    prefix must still yield a positive cost (custom_llm_provider stays
    untouched in this case).
    """
    same_provider_kwargs = {
        "model": "openai/my-custom-deployment",
        "messages": [{"role": "user", "content": "Hello"}],
        "base_model": "openai/gpt-4o",
        "mock_response": "Hi there!",
    }
    response = litellm.completion(**same_provider_kwargs)
    assert response._hidden_params["response_cost"] > 0
Loading