NOTE(review): line breaks and indentation below were reconstructed from a whitespace-collapsed copy of this patch; confirm context-line indentation against the target files before applying.
diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py
index 6354bf44943..4c857f18823 100644
--- a/litellm/cost_calculator.py
+++ b/litellm/cost_calculator.py
@@ -1093,6 +1093,19 @@ def completion_cost(  # noqa: PLR0915
             router_model_id=router_model_id,
         )
 
+        # When base_model overrides model and carries its own provider prefix
+        # (e.g. base_model="gemini/gemini-2.0-flash" on an anthropic deployment),
+        # align custom_llm_provider so cost_per_token builds the correct key.
+        # Skip when custom_pricing is True (base_model is ignored in that path).
+        _provider_overridden = False
+        if base_model is not None and selected_model is not None and not custom_pricing:
+            _parts = selected_model.split("/", 1)
+            if len(_parts) > 1 and _parts[0] in LlmProvidersSet:
+                extracted = _parts[0]
+                if extracted != custom_llm_provider:
+                    custom_llm_provider = extracted
+                    _provider_overridden = True
+
         potential_model_names = [
             selected_model,
             _get_response_model(completion_response),
@@ -1174,9 +1187,10 @@ def completion_cost(  # noqa: PLR0915
 
                     hidden_params = getattr(completion_response, "_hidden_params", None)
                     if hidden_params is not None:
-                        custom_llm_provider = hidden_params.get(
-                            "custom_llm_provider", custom_llm_provider or None
-                        )
+                        if not _provider_overridden:  # keep base_model's provider
+                            custom_llm_provider = hidden_params.get(
+                                "custom_llm_provider", custom_llm_provider or None
+                            )
                         region_name = hidden_params.get("region_name", region_name)
 
                         # For Gemini/Vertex AI responses, trafficType is stored in
diff --git a/tests/local_testing/test_completion_cost.py b/tests/local_testing/test_completion_cost.py
index 2f78f27361e..0485b2187d4 100644
--- a/tests/local_testing/test_completion_cost.py
+++ b/tests/local_testing/test_completion_cost.py
@@ -2945,3 +2945,77 @@ def test_batch_cost_calculator():
 
     cost = completion_cost(**args)
     assert cost > 0
+
+
+def test_cost_calculator_base_model_cross_provider():
+    """
+    When base_model has a different provider prefix than the deployment,
+    custom_llm_provider should be updated so cost_per_token builds the
+    correct model key. Regression test for #22257.
+    """
+    resp = litellm.completion(
+        model="anthropic/my-custom-deployment",
+        messages=[{"role": "user", "content": "Hello"}],
+        base_model="gemini/gemini-2.0-flash",
+        mock_response="Hi there!",
+    )
+    assert resp._hidden_params["response_cost"] > 0
+
+
+def test_cost_calculator_base_model_cross_provider_direct():
+    """
+    Direct completion_cost unit test for cross-provider base_model override.
+    Verifies that completion_cost correctly routes to the base_model provider.
+    """
+    from litellm import ModelResponse, Usage
+
+    resp = ModelResponse(
+        id="chatcmpl-test",
+        model="gemini/gemini-2.0-flash",
+        usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+    )
+    cost = completion_cost(
+        model="anthropic/my-custom-deployment",
+        completion_response=resp,
+        base_model="gemini/gemini-2.0-flash",
+        custom_llm_provider="anthropic",
+    )
+    assert cost > 0
+
+
+def test_cost_calculator_base_model_cross_provider_hidden_params_guard():
+    """
+    Verify that hidden_params.custom_llm_provider does not undo the
+    base_model provider override when the response carries a stale provider.
+    """
+    from litellm import ModelResponse, Usage
+
+    resp = ModelResponse(
+        id="chatcmpl-guard",
+        model="gemini/gemini-2.0-flash",
+        usage=Usage(prompt_tokens=10, completion_tokens=20, total_tokens=30),
+    )
+    # Simulate hidden_params carrying the original (wrong) provider
+    resp._hidden_params = {"custom_llm_provider": "anthropic"}
+
+    cost = completion_cost(
+        model="anthropic/my-custom-deployment",
+        completion_response=resp,
+        base_model="gemini/gemini-2.0-flash",
+        custom_llm_provider="anthropic",
+    )
+    assert cost > 0
+
+
+def test_cost_calculator_base_model_same_provider_no_regression():
+    """
+    When base_model has the same provider prefix as the deployment,
+    cost calculation must still succeed (asserted as a positive response_cost).
+    """
+    resp = litellm.completion(
+        model="openai/my-custom-deployment",
+        messages=[{"role": "user", "content": "Hello"}],
+        base_model="openai/gpt-4o",
+        mock_response="Hi there!",
+    )
+    assert resp._hidden_params["response_cost"] > 0