diff --git a/litellm/litellm_core_utils/core_helpers.py b/litellm/litellm_core_utils/core_helpers.py index ee111f35929..e789d28f01b 100644 --- a/litellm/litellm_core_utils/core_helpers.py +++ b/litellm/litellm_core_utils/core_helpers.py @@ -1,6 +1,6 @@ # What is this? ## Helper utilities -from typing import TYPE_CHECKING, Any, Iterable, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Union import httpx @@ -189,6 +189,58 @@ def get_litellm_metadata_from_kwargs(kwargs: dict): return {} +_MISSING_MODEL_INFO = object() + + +def get_model_info_from_litellm_params( + litellm_params: Optional[dict], +) -> Dict[str, Any]: + """ + Read deployment model_info from supported litellm_params locations. + + metadata.model_info takes precedence, including an explicitly empty dict. + """ + if litellm_params is None: + return {} + + metadata = litellm_params.get("metadata", {}) or {} + metadata_model_info = metadata.get("model_info", _MISSING_MODEL_INFO) + if metadata_model_info is not _MISSING_MODEL_INFO: + result = metadata_model_info + else: + litellm_model_info = litellm_params.get("model_info", _MISSING_MODEL_INFO) + if litellm_model_info is not _MISSING_MODEL_INFO: + result = litellm_model_info + else: + result = (litellm_params.get("litellm_metadata", {}) or {}).get( + "model_info", {} + ) + + return result if isinstance(result, dict) else {} + + +def merge_metadata_preserving_deployment_model_info( + litellm_metadata: Optional[dict], + user_metadata: Optional[dict], + model_info: Optional[dict] = None, +) -> dict: + """ + Merge user metadata into LiteLLM metadata while preserving deployment model_info. + + Router-provided deployment model_info should win over any user-supplied + metadata.model_info, but direct callers may still pass model_info top-level. + """ + merged_metadata = dict(litellm_metadata or {}) + deployment_model_info = merged_metadata.pop("model_info", _MISSING_MODEL_INFO) + merged_metadata.update(user_metadata or {}) + if deployment_model_info is not _MISSING_MODEL_INFO: + merged_metadata["model_info"] = deployment_model_info + elif "model_info" not in merged_metadata and model_info is not None: + merged_metadata["model_info"] = model_info + + return merged_metadata + + def reconstruct_model_name( model_name: str, custom_llm_provider: Optional[str], diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index e22d057bb69..3ffe02dce34 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -59,7 +59,10 @@ from litellm.integrations.deepeval.deepeval import DeepEvalLogger from litellm.integrations.mlflow import MlflowLogger from litellm.integrations.sqs import SQSLogger -from litellm.litellm_core_utils.core_helpers import reconstruct_model_name +from litellm.litellm_core_utils.core_helpers import ( + get_model_info_from_litellm_params, + reconstruct_model_name, +) from litellm.litellm_core_utils.get_litellm_params import get_litellm_params from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( StandardBuiltInToolCostTracking, @@ -1648,8 +1651,12 @@ def _process_hidden_params_and_response_cost( # handlers like Gemini/Vertex which call completion_cost directly) pass else: + model_info = get_model_info_from_litellm_params( + self.model_call_details.get("litellm_params", {}) or {} + ) self.model_call_details["response_cost"] = self._response_cost_calculator( - result=logging_result + result=logging_result, + router_model_id=model_info.get("id"), ) self.model_call_details[ @@ -1943,12 +1950,18 @@ def success_handler( # noqa: PLR0915 verbose_logger.debug( "Logging Details LiteLLM-Success Call streaming complete" ) - self.model_call_details[ - "complete_streaming_response" - ] = complete_streaming_response - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator(result=complete_streaming_response) + self.model_call_details["complete_streaming_response"] = ( + complete_streaming_response + ) + model_info = get_model_info_from_litellm_params( + self.model_call_details.get("litellm_params", {}) or {} + ) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator( + result=complete_streaming_response, + router_model_id=model_info.get("id"), + ) + ) ## STANDARDIZED LOGGING PAYLOAD self.model_call_details[ "standard_logging_object" @@ -2468,11 +2481,14 @@ async def async_success_handler( # noqa: PLR0915 _get_base_model_from_metadata( model_call_details=self.model_call_details ) - # base_model defaults to None if not set on model_info - self.model_call_details[ - "response_cost" - ] = self._response_cost_calculator( - result=complete_streaming_response + model_info = get_model_info_from_litellm_params( + self.model_call_details.get("litellm_params", {}) or {} + ) + self.model_call_details["response_cost"] = ( + self._response_cost_calculator( + result=complete_streaming_response, + router_model_id=model_info.get("id"), + ) ) verbose_logger.debug( @@ -4446,10 +4462,8 @@ def use_custom_pricing_for_model(litellm_params: Optional[dict]) -> bool: if litellm_params.get(key) is not None: return True - # Check model_info - metadata: dict = litellm_params.get("metadata", {}) or {} - model_info: dict = metadata.get("model_info", {}) or {} - + # Check model_info in all supported locations. + model_info = get_model_info_from_litellm_params(litellm_params) if model_info: matching_keys = _CUSTOM_PRICING_KEYS & model_info.keys() for key in matching_keys: diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 3e7a636640f..521eab8a29a 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -26,6 +26,9 @@ update_headers_with_filtered_beta, ) from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES +from litellm.litellm_core_utils.core_helpers import ( + merge_metadata_preserving_deployment_model_info, +) from litellm.litellm_core_utils.realtime_streaming import RealTimeStreaming from litellm.llms.base_llm.anthropic_messages.transformation import ( BaseAnthropicMessagesConfig, @@ -1884,14 +1887,20 @@ async def async_anthropic_messages_handler( headers=headers, provider=custom_llm_provider ) + merged_metadata = merge_metadata_preserving_deployment_model_info( + litellm_metadata=kwargs.get("litellm_metadata"), + user_metadata=kwargs.get("metadata"), + model_info=kwargs.get("model_info"), + ) + logging_obj.update_environment_variables( model=model, optional_params=dict(anthropic_messages_optional_request_params), litellm_params={ - "metadata": kwargs.get("metadata", {}), "preset_cache_key": None, "stream_response": {}, **anthropic_messages_optional_request_params, + "metadata": merged_metadata, }, custom_llm_provider=custom_llm_provider, ) diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py index 20d06b7d531..08aeea91791 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/anthropic_passthrough_logging_handler.py @@ -6,7 +6,14 @@ import litellm from litellm._logging import verbose_proxy_logger -from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.core_helpers import ( + get_model_info_from_litellm_params, + merge_metadata_preserving_deployment_model_info, +) +from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + use_custom_pricing_for_model, +) from litellm.llms.anthropic import get_anthropic_config from litellm.llms.anthropic.chat.handler import ( ModelResponseIterator as AnthropicModelResponseIterator, @@ -118,6 +125,8 @@ def _create_anthropic_response_logging_payload( custom_llm_provider = logging_obj.model_call_details.get( "custom_llm_provider" ) + litellm_params = logging_obj.model_call_details.get("litellm_params", {}) or {} + model_info = get_model_info_from_litellm_params(litellm_params) # Prepend custom_llm_provider to model if not already present model_for_cost = model @@ -128,10 +137,22 @@ def _create_anthropic_response_logging_payload( completion_response=litellm_model_response, model=model_for_cost, custom_llm_provider=custom_llm_provider, + custom_pricing=use_custom_pricing_for_model(litellm_params), + router_model_id=model_info.get("id"), + litellm_logging_obj=logging_obj, ) kwargs["response_cost"] = response_cost kwargs["model"] = model + kwargs.setdefault("litellm_params", {}) + raw_logging_metadata = litellm_params.get("metadata") + if raw_logging_metadata is not None: + kwargs["litellm_params"]["metadata"] = ( + merge_metadata_preserving_deployment_model_info( + litellm_metadata=raw_logging_metadata, + user_metadata=kwargs["litellm_params"].get("metadata"), + ) + ) passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore kwargs.get("passthrough_logging_payload") ) diff --git a/litellm/responses/main.py b/litellm/responses/main.py index 6f7e38dc8b8..389511e797f 100644 --- a/litellm/responses/main.py +++ b/litellm/responses/main.py @@ -25,6 +25,9 @@ ) from litellm.constants import request_timeout from litellm.litellm_core_utils.asyncify import run_async_function +from litellm.litellm_core_utils.core_helpers import ( + merge_metadata_preserving_deployment_model_info, +) from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.prompt_templates.common_utils import ( update_responses_input_with_model_file_ids, @@ -740,7 +743,11 @@ def responses( # Pre Call logging - preserve metadata for custom callbacks # When called from completion bridge (codex models), metadata is in litellm_metadata - metadata_for_callbacks = metadata or kwargs.get("litellm_metadata") or {} + metadata_for_callbacks = merge_metadata_preserving_deployment_model_info( + litellm_metadata=kwargs.get("litellm_metadata"), + user_metadata=metadata, + model_info=kwargs.get("model_info"), + ) litellm_logging_obj.update_environment_variables( model=model, diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py index f4aeb27b31a..bf16dd96ba5 100644 --- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py +++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py @@ -13,6 +13,7 @@ from litellm.constants import SENTRY_DENYLIST, SENTRY_PII_DENYLIST from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging from litellm.litellm_core_utils.litellm_logging import set_callbacks +from litellm.types.llms.openai import ResponsesAPIResponse from litellm.types.utils import ModelResponse, TextCompletionResponse @@ -148,6 +149,53 @@ def test_use_custom_pricing_for_model(): assert use_custom_pricing_for_model(litellm_params) == True +def test_get_model_info_from_litellm_params_preserves_explicit_empty_metadata(): + from litellm.litellm_core_utils.core_helpers import ( + get_model_info_from_litellm_params, + ) + + litellm_params = { + "metadata": {"model_info": {}}, + "model_info": {"id": "top-level"}, + "litellm_metadata": {"model_info": {"id": "router"}}, + } + + assert get_model_info_from_litellm_params(litellm_params) == {} + + +def test_merge_metadata_preserving_deployment_model_info_keeps_router_model_info(): + from litellm.litellm_core_utils.core_helpers import ( + merge_metadata_preserving_deployment_model_info, + ) + + merged_metadata = merge_metadata_preserving_deployment_model_info( + litellm_metadata={"model_info": {"id": "router"}, "deployment": "test"}, + user_metadata={"model_info": {"id": "user"}, "user_field": "present"}, + model_info={"id": "top-level"}, + ) + + assert merged_metadata["model_info"] == {"id": "router"} + assert merged_metadata["deployment"] == "test" + assert merged_metadata["user_field"] == "present" + + +def test_merge_metadata_preserving_deployment_model_info_preserves_explicit_none(): + from litellm.litellm_core_utils.core_helpers import ( + merge_metadata_preserving_deployment_model_info, + ) + + merged_metadata = merge_metadata_preserving_deployment_model_info( + litellm_metadata={"model_info": None, "deployment": "test"}, + user_metadata={"model_info": {"id": "user"}, "user_field": "present"}, + model_info={"id": "top-level"}, + ) + + assert "model_info" in merged_metadata + assert merged_metadata["model_info"] is None + assert merged_metadata["deployment"] == "test" + assert merged_metadata["user_field"] == "present" + + def test_logging_prevent_double_logging(logging_obj): """ When using a bridge, log only once from the underlying bridge call. @@ -538,6 +586,115 @@ def _guardrail_logging_hook(kwargs, result, call_type): guardrail.logging_hook.assert_called_once() assert logging_obj.model_call_details.get("guardrail_hook_ran") is True +def test_process_hidden_params_and_response_cost_uses_router_model_id_for_aresponses( + logging_obj, +): + logging_obj.stream = False + logging_obj.call_type = "aresponses" + logging_obj.model_call_details["litellm_params"] = { + "metadata": {"model_info": {"id": "deployment-custom-pricing-test"}} + } + + response = ResponsesAPIResponse( + id="resp_test", + object="response", + created_at=1741476542, + status="completed", + model="gpt-4o", + output=[], + parallel_tool_calls=True, + usage={"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + text={"format": {"type": "text"}}, + error=None, + incomplete_details=None, + instructions=None, + metadata={}, + temperature=1.0, + tool_choice="auto", + tools=[], + top_p=1.0, + max_output_tokens=None, + previous_response_id=None, + reasoning={"effort": None, "summary": None}, + truncation="disabled", + user=None, + ) + + logging_obj._response_cost_calculator = MagicMock(return_value=123.0) + logging_obj._build_standard_logging_payload = MagicMock(return_value={}) + + import datetime + + start_time = datetime.datetime.now() + end_time = datetime.datetime.now() + + logging_obj._process_hidden_params_and_response_cost( + logging_result=response, + start_time=start_time, + end_time=end_time, + ) + + logging_obj._response_cost_calculator.assert_called_once() + assert logging_obj._response_cost_calculator.call_args.kwargs["router_model_id"] == ( + "deployment-custom-pricing-test" + ) + + +def test_streaming_response_cost_uses_router_model_id_for_aresponses_websocket( + logging_obj, +): + import datetime + + logging_obj.stream = True + logging_obj.call_type = "_aresponses_websocket" + logging_obj.model_call_details["litellm_params"] = { + "metadata": {"model_info": {"id": "deployment-custom-pricing-test"}} + } + logging_obj.sync_streaming_chunks = [{"choices": []}] + + response = ResponsesAPIResponse( + id="resp_test", + object="response", + created_at=1741476542, + status="completed", + model="gpt-4o", + output=[], + parallel_tool_calls=True, + usage={"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + text={"format": {"type": "text"}}, + error=None, + incomplete_details=None, + instructions=None, + metadata={}, + temperature=1.0, + tool_choice="auto", + tools=[], + top_p=1.0, + max_output_tokens=None, + previous_response_id=None, + reasoning={"effort": None, "summary": None}, + truncation="disabled", + user=None, + ) + + logging_obj._get_assembled_streaming_response = MagicMock(return_value=response) + logging_obj._response_cost_calculator = MagicMock(return_value=123.0) + logging_obj._build_standard_logging_payload = MagicMock(return_value={}) + + start_time = datetime.datetime.now() + end_time = datetime.datetime.now() + + logging_obj.success_handler( + result=response, + start_time=start_time, + end_time=end_time, + cache_hit=False, + ) + + logging_obj._response_cost_calculator.assert_called_once() + assert logging_obj._response_cost_calculator.call_args.kwargs["router_model_id"] == ( + "deployment-custom-pricing-test" + ) def test_get_user_agent_tags(): from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup diff --git a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py index f145cfef16d..ac7a7963323 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py @@ -1,21 +1,87 @@ +import asyncio import json import os import sys from datetime import datetime -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from unittest.mock import AsyncMock, MagicMock, patch +import httpx import pytest sys.path.insert( 0, os.path.abspath("../../..") ) # Adds the parent directory to the system path +import litellm +from litellm import Router +from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import ( AnthropicPassthroughLoggingHandler, ) +CUSTOM_INPUT_COST = 0.50 +CUSTOM_OUTPUT_COST = 1.00 +DEPLOYMENT_MODEL_ID = "deployment-custom-pricing-test" + + +class CostCapturingCallback(CustomLogger): + def __init__(self): + super().__init__() + self.response_cost: Optional[float] = None + self.event = asyncio.Event() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.response_cost = kwargs.get("response_cost") + self.event.set() + + +class MockStreamingHTTPResponse: + def __init__(self, lines, status_code=200, headers=None): + self.status_code = status_code + self._lines = [line.encode("utf-8") for line in lines] + self.headers = httpx.Headers(headers or {"content-type": "text/event-stream"}) + + async def aiter_lines(self): + for line in self._lines: + yield line.decode("utf-8") + + async def aiter_bytes(self): + for line in self._lines: + yield line + b"\n" + + def raise_for_status(self): + return None + + +def _make_router_with_custom_pricing( + backend_model: str, api_key: str = "fake-key" +) -> Router: + return Router( + model_list=[ + { + "model_name": "test-custom-pricing", + "litellm_params": { + "model": backend_model, + "api_key": api_key, + }, + "model_info": { + "id": DEPLOYMENT_MODEL_ID, + "input_cost_per_token": CUSTOM_INPUT_COST, + "output_cost_per_token": CUSTOM_OUTPUT_COST, + }, + }, + ], + ) + + +@pytest.fixture(autouse=True) +def cleanup_custom_pricing_state(): + yield + litellm.callbacks = [] + litellm.model_cost.pop(DEPLOYMENT_MODEL_ID, None) + class TestAnthropicLoggingHandlerModelFallback: """Test the model fallback logic in the anthropic passthrough logging handler.""" @@ -244,6 +310,61 @@ def test_cost_calculation_with_azure_ai_custom_llm_provider( assert call_kwargs["model"] == "azure_ai/claude-sonnet-4-5_gb_20250929" assert call_kwargs["custom_llm_provider"] == "azure_ai" + @patch("litellm.completion_cost") + def test_metadata_merge_does_not_overwrite_existing_litellm_params( + self, mock_completion_cost + ): + mock_completion_cost.return_value = 123.0 + + logging_obj = MagicMock() + logging_obj.model_call_details = { + "custom_llm_provider": "anthropic", + "litellm_params": { + "metadata": { + "model_info": { + "id": "deployment-custom-pricing-test", + "input_cost_per_token": 0.5, + "output_cost_per_token": 1.0, + }, + "new_field": "from-logging-obj", + "shared_field": "from-logging-obj", + }, + "stream_response": {"should": "not-overwrite"}, + }, + } + logging_obj.litellm_call_id = "call-test" + + model_response = litellm.ModelResponse() + model_response.usage = litellm.Usage( + prompt_tokens=100, completion_tokens=50, total_tokens=150 + ) # type: ignore + + kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=model_response, + model="claude-sonnet-4-20250514", + kwargs={ + "litellm_params": { + "metadata": { + "existing_field": "preserved", + "shared_field": "from-existing", + }, + "stream_response": {"keep": "existing"}, + } + }, + start_time=datetime.now(), + end_time=datetime.now(), + logging_obj=logging_obj, + ) + + assert kwargs["litellm_params"]["metadata"]["existing_field"] == "preserved" + assert kwargs["litellm_params"]["metadata"]["new_field"] == "from-logging-obj" + assert kwargs["litellm_params"]["metadata"]["shared_field"] == "from-existing" + assert ( + kwargs["litellm_params"]["metadata"]["model_info"]["id"] + == "deployment-custom-pricing-test" + ) + assert kwargs["litellm_params"]["stream_response"] == {"keep": "existing"} + @patch("litellm.completion_cost") def test_cost_calculation_without_custom_llm_provider(self, mock_completion_cost): """Test that cost calculation works without custom_llm_provider (standard Anthropic)""" @@ -318,6 +439,97 @@ def test_cost_calculation_does_not_duplicate_provider_prefix( assert call_kwargs["custom_llm_provider"] == "azure_ai" +class TestAnthropicPassthroughLoggingPayload: + @staticmethod + def _mock_streaming_events() -> List[str]: + return [ + "event: message_start", + f'data: {{"type":"message_start","message":{{"id":"msg_test","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","stop_reason":null,"stop_sequence":null,"usage":{{"input_tokens":100,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}}}}}', + "", + "event: content_block_start", + 'data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}', + "", + "event: content_block_delta", + 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello!"}}', + "", + "event: content_block_stop", + 'data: {"type":"content_block_stop","index":0}', + "", + "event: message_delta", + 'data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":50}}', + "", + "event: message_stop", + 'data: {"type":"message_stop"}', + ] + + async def _assert_streaming_custom_pricing(self, metadata: Optional[dict] = None): + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing( + "anthropic/claude-sonnet-4-20250514" + ) + mock_stream_response = MockStreamingHTTPResponse( + self._mock_streaming_events(), + headers={ + "content-type": "text/event-stream", + "request-id": "req_test", + }, + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = mock_stream_response + + response = await router.aanthropic_messages( + model="test-custom-pricing", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100, + stream=True, + metadata=metadata, + ) + + async for _chunk in response: + pass + + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ) + finally: + litellm.callbacks = [] + + @pytest.mark.asyncio + async def test_streaming_messages_uses_custom_pricing(self): + await self._assert_streaming_custom_pricing() + + @pytest.mark.asyncio + async def test_streaming_messages_with_user_metadata_uses_custom_pricing(self): + await self._assert_streaming_custom_pricing(metadata={"user_field": "present"}) + + @pytest.mark.asyncio + async def test_streaming_messages_with_user_model_info_ignored_for_pricing(self): + await self._assert_streaming_custom_pricing( + metadata={ + "user_field": "present", + "model_info": { + "id": "user-garbage", + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + }, + } + ) + + class TestAnthropicBatchPassthroughCostTracking: """Test cases for Anthropic batch passthrough cost tracking functionality""" diff --git a/tests/test_litellm/test_custom_pricing_metadata_propagation.py b/tests/test_litellm/test_custom_pricing_metadata_propagation.py new file mode 100644 index 00000000000..81008fc3c68 --- /dev/null +++ b/tests/test_litellm/test_custom_pricing_metadata_propagation.py @@ -0,0 +1,799 @@ +""" +Test that custom pricing from model_info is correctly propagated through +the metadata flow for all API endpoints. + +Background: +- PR #20679 correctly strips custom pricing from the shared litellm.model_cost key + to prevent cross-deployment pollution. +- The deployment's model_info retains custom pricing and must reach cost calculation + via litellm_params["metadata"]["model_info"]. +- /v1/chat/completions works because it explicitly extracts pricing into litellm_params. +- /v1/responses and /v1/messages fail because model_info with custom pricing + does not reach the logging object's litellm_params["metadata"]. + +These tests verify that use_custom_pricing_for_model() returns True when called +with the litellm_params as constructed by each code path. +""" + +import asyncio +import json +import os +import sys +from typing import Optional +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import httpx +import pytest + +sys.path.insert(0, os.path.abspath("../../..")) + +import litellm +from litellm import Router +from litellm.litellm_core_utils.core_helpers import ( + get_model_info_from_litellm_params, + merge_metadata_preserving_deployment_model_info, +) +from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + use_custom_pricing_for_model, +) +from litellm.litellm_core_utils.litellm_logging import CustomLogger + + +# --------------------------------------------------------------------------- +# Custom callback to capture response cost from the logging flow +# --------------------------------------------------------------------------- +class CostCapturingCallback(CustomLogger): + """Captures response_cost from the async success callback kwargs.""" + + def __init__(self): + super().__init__() + self.response_cost: Optional[float] = None + self.custom_pricing: Optional[bool] = None + self.event = asyncio.Event() + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.response_cost = kwargs.get("response_cost") + # Extract custom_pricing from standard_logging_object if available + slo = kwargs.get("standard_logging_object", {}) + if slo: + self.custom_pricing = slo.get("custom_pricing") + self.event.set() + + def log_success_event(self, kwargs, response_obj, start_time, end_time): + self.response_cost = kwargs.get("response_cost") + slo = kwargs.get("standard_logging_object", {}) + if slo: + self.custom_pricing = slo.get("custom_pricing") + + +# --------------------------------------------------------------------------- +# Fixtures: Router with custom pricing deployment +# --------------------------------------------------------------------------- +CUSTOM_INPUT_COST = 0.50 # $0.50 per token — absurdly high, easy to spot +CUSTOM_OUTPUT_COST = 1.00 + + +DEPLOYMENT_MODEL_ID = "deployment-custom-pricing-test" + + +def _make_router_with_custom_pricing(backend_model: str, api_key: str = "fake-key"): + """Create a Router with a single deployment that has custom pricing.""" + return Router( + model_list=[ + { + "model_name": "test-custom-pricing", + "litellm_params": { + "model": backend_model, + "api_key": api_key, + }, + "model_info": { + "id": DEPLOYMENT_MODEL_ID, + "input_cost_per_token": CUSTOM_INPUT_COST, + "output_cost_per_token": CUSTOM_OUTPUT_COST, + }, + }, + ], + ) + +@pytest.fixture(autouse=True) +def cleanup_model_cost(): + """Remove callback and pricing state between tests.""" + yield + litellm.callbacks = [] + litellm.model_cost.pop(DEPLOYMENT_MODEL_ID, None) + + +# --------------------------------------------------------------------------- +# Mock HTTP response helpers +# --------------------------------------------------------------------------- +class MockHTTPResponse: + """Mimics httpx.Response for non-streaming and streaming.""" + + def __init__(self, json_data, status_code=200, headers=None): + self._json_data = json_data + self.status_code = status_code + self.text = json.dumps(json_data) + self.headers = httpx.Headers(headers or {"content-type": "application/json"}) + + def json(self): + return self._json_data + + def raise_for_status(self): + if self.status_code >= 400: + raise httpx.HTTPStatusError( + "error", request=Mock(), response=self + ) + + async def aiter_bytes(self): + yield self.text.encode("utf-8") + + async def aiter_lines(self): + for line in self.text.split("\n"): + yield line + + def iter_lines(self): + for line in self.text.split("\n"): + yield line + + +class MockStreamingHTTPResponse: + """Mimics httpx.Response for streaming (SSE).""" + + def __init__(self, sse_lines: list[str], status_code=200, headers=None): + self._sse_lines = sse_lines + self.status_code = status_code + self.headers = httpx.Headers( + headers or {"content-type": "text/event-stream"} + ) + self.text = "\n".join(sse_lines) + + def json(self): + return json.loads(self.text) + + def raise_for_status(self): + if self.status_code >= 400: + raise httpx.HTTPStatusError( + "error", request=Mock(), response=self + ) + + def iter_lines(self): + for line in self._sse_lines: + yield line + + async def aiter_lines(self): + for line in self._sse_lines: + yield line + + async def aiter_bytes(self): + for line in self._sse_lines: + yield (line + "\n").encode("utf-8") + + +# --------------------------------------------------------------------------- +# Standard mock response payloads +# --------------------------------------------------------------------------- +RESPONSES_API_MOCK = { + "id": "resp_test_custom_pricing", + "object": "response", + "created_at": 1741476542, + "status": "completed", + "model": "gpt-4o", + "output": [ + { + "type": "message", + "id": "msg_test", + "status": "completed", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "Hello!", "annotations": []} + ], + } + ], + "parallel_tool_calls": True, + "usage": { + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + "output_tokens_details": {"reasoning_tokens": 0}, + }, + "text": {"format": {"type": "text"}}, + "error": None, + "incomplete_details": None, + "instructions": None, + "metadata": {}, + "temperature": 1.0, + "tool_choice": "auto", + "tools": [], + "top_p": 1.0, + "max_output_tokens": None, + "previous_response_id": None, + "reasoning": {"effort": None, "summary": None}, + "truncation": "disabled", + "user": None, +} + +ANTHROPIC_MESSAGES_MOCK = { + "id": "msg_test_custom_pricing", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + "model": "claude-sonnet-4-20250514", + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": { + "input_tokens": 100, + "output_tokens": 50, + }, +} + + +# ========================================================================= +# UNIT TESTS: Metadata propagation to use_custom_pricing_for_model +# ========================================================================= + + +class TestRouterMetadataPropagation: + """ + Verify that Router._update_kwargs_with_deployment places model_info + with custom pricing into the correct metadata location. + """ + + def test_chat_completions_puts_model_info_in_metadata(self): + """ + For chat/completions, function_name is None or "acompletion" etc., + so metadata_variable_name = "metadata". model_info should be in + kwargs["metadata"]["model_info"]. + """ + router = _make_router_with_custom_pricing("openai/gpt-4o") + deployment = router.model_list[0] + + kwargs: dict = {} + router._update_kwargs_with_deployment( + deployment=deployment, kwargs=kwargs, function_name="acompletion" + ) + + # model_info should be in kwargs["metadata"]["model_info"] + assert "metadata" in kwargs + model_info = kwargs["metadata"].get("model_info", {}) + assert model_info.get("input_cost_per_token") == CUSTOM_INPUT_COST + assert model_info.get("output_cost_per_token") == CUSTOM_OUTPUT_COST + + # use_custom_pricing_for_model should return True when litellm_params + # includes this metadata + litellm_params = {"metadata": kwargs["metadata"]} + assert use_custom_pricing_for_model(litellm_params) is True + + def test_generic_api_call_puts_model_info_in_litellm_metadata(self): + """ + For _ageneric_api_call_with_fallbacks (used by /v1/responses and /v1/messages), + metadata_variable_name = "litellm_metadata". model_info should be in + kwargs["litellm_metadata"]["model_info"]. + """ + router = _make_router_with_custom_pricing("openai/gpt-4o") + deployment = router.model_list[0] + + kwargs: dict = {} + router._update_kwargs_with_deployment( + deployment=deployment, + kwargs=kwargs, + function_name="_ageneric_api_call_with_fallbacks", + ) + + # model_info should be in kwargs["litellm_metadata"]["model_info"] + assert "litellm_metadata" in kwargs + model_info = kwargs["litellm_metadata"].get("model_info", {}) + assert model_info.get("input_cost_per_token") == CUSTOM_INPUT_COST + assert model_info.get("output_cost_per_token") == CUSTOM_OUTPUT_COST + + def test_responses_api_metadata_for_callbacks_gets_model_info(self): + """ + In responses/main.py, metadata_for_callbacks should merge + litellm_metadata (which has model_info) with explicit metadata. + """ + router = _make_router_with_custom_pricing("openai/gpt-4o") + deployment = router.model_list[0] + + kwargs: dict = {} + router._update_kwargs_with_deployment( + deployment=deployment, + kwargs=kwargs, + function_name="_ageneric_api_call_with_fallbacks", + ) + + metadata = {"user_field": "present"} + metadata_for_callbacks = merge_metadata_preserving_deployment_model_info( + kwargs.get("litellm_metadata"), metadata + ) + + model_info = metadata_for_callbacks.get("model_info", {}) + assert model_info.get("input_cost_per_token") == CUSTOM_INPUT_COST, ( + "metadata_for_callbacks should contain model_info with custom pricing " + "when explicit metadata is also passed" + ) + assert metadata_for_callbacks["user_field"] == "present" + + litellm_params = {"metadata": metadata_for_callbacks} + assert use_custom_pricing_for_model(litellm_params) is True + + def test_responses_api_user_model_info_does_not_override_deployment(self): + """ + User metadata should not overwrite router-provided model_info for + responses callback pricing calculation. + """ + router = _make_router_with_custom_pricing("openai/gpt-4o") + deployment = router.model_list[0] + + kwargs: dict = {} + router._update_kwargs_with_deployment( + deployment=deployment, + kwargs=kwargs, + function_name="_ageneric_api_call_with_fallbacks", + ) + + user_metadata = { + "user_field": "present", + "model_info": {"id": "user-supplied", "input_cost_per_token": 0.0}, + } + metadata_for_callbacks = merge_metadata_preserving_deployment_model_info( + kwargs.get("litellm_metadata"), user_metadata + ) + + model_info = metadata_for_callbacks.get("model_info", {}) + assert model_info.get("id") == DEPLOYMENT_MODEL_ID + assert model_info.get("input_cost_per_token") == CUSTOM_INPUT_COST + assert metadata_for_callbacks["user_field"] == "present" + + def test_messages_api_user_model_info_does_not_override_deployment(self): + """ + For /v1/messages, the handler should preserve router-provided model_info + even if the request metadata includes its own model_info payload. + """ + router = _make_router_with_custom_pricing("anthropic/claude-sonnet-4-20250514") + deployment = router.model_list[0] + + kwargs: dict = {} + router._update_kwargs_with_deployment( + deployment=deployment, + kwargs=kwargs, + function_name="_ageneric_api_call_with_fallbacks", + ) + + kwargs["metadata"] = { + "user_field": "present", + "model_info": {"id": "user-supplied", "output_cost_per_token": 0.0}, + } + metadata_from_handler = merge_metadata_preserving_deployment_model_info( + kwargs.get("litellm_metadata"), kwargs.get("metadata") + ) + litellm_params = {"metadata": metadata_from_handler} + + model_info = metadata_from_handler.get("model_info", {}) + assert model_info.get("id") == DEPLOYMENT_MODEL_ID + assert model_info.get("output_cost_per_token") == CUSTOM_OUTPUT_COST + assert metadata_from_handler["user_field"] == "present" + assert use_custom_pricing_for_model(litellm_params) is True + + def test_empty_deployment_model_info_still_overrides_user_metadata(self): + """ + An explicitly empty deployment model_info should remain explicit during + merge rather than being treated like a missing value. + """ + metadata_for_callbacks = merge_metadata_preserving_deployment_model_info( + {"model_info": {}}, + { + "user_field": "present", + "model_info": {"id": "user-supplied", "input_cost_per_token": 0.0}, + }, + ) + + assert metadata_for_callbacks["model_info"] == {} + assert metadata_for_callbacks["user_field"] == "present" + + def test_get_model_info_preserves_explicit_empty_metadata_model_info(self): + """Explicit empty metadata.model_info should win over fallback locations.""" + model_info = get_model_info_from_litellm_params( + { + "metadata": {"model_info": {}}, + "model_info": {"id": "top-level"}, + "litellm_metadata": {"model_info": {"id": "litellm-metadata"}}, + } + ) + + assert model_info == {} + + def test_use_custom_pricing_detects_top_level_model_info(self): + """Custom pricing detection should work when model_info is top-level.""" + litellm_params = { + "metadata": {"user_field": "present"}, + "model_info": { + "id": DEPLOYMENT_MODEL_ID, + "input_cost_per_token": CUSTOM_INPUT_COST, + "output_cost_per_token": CUSTOM_OUTPUT_COST, + }, + } + + assert use_custom_pricing_for_model(litellm_params) is True + + +# ========================================================================= +# INTEGRATION TESTS: Full Router → cost calculation with HTTP mocking +# ========================================================================= + + +class TestResponsesAPICustomPricingCost: + """ + Test that /v1/responses (via router.aresponses) uses custom pricing + for cost calculation when model_info has custom pricing fields. + """ + + @pytest.mark.asyncio + async def test_nonstreaming_responses_uses_custom_pricing(self): + """Non-streaming /v1/responses should use custom pricing for cost.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing("openai/gpt-4o") + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = MockHTTPResponse(RESPONSES_API_MOCK) + + response = await router.aresponses( + model="test-custom-pricing", + input="Hello, how are you?", + ) + + # Wait for async callback + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None, ( + "response_cost should be set in the callback" + ) + + # With 100 input + 50 output tokens at custom pricing: + # expected = 100 * 0.50 + 50 * 1.00 = 50 + 50 = 100.0 + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ), ( + f"Cost should use custom pricing ({expected_custom_cost}), " + f"got {cost_callback.response_cost}" + ) + finally: + litellm.callbacks = [] + + @pytest.mark.asyncio + async def test_streaming_responses_uses_custom_pricing(self): + """Streaming /v1/responses should use custom pricing for cost.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing("openai/gpt-4o") + + # SSE events for streaming responses API + sse_events = [ + 'data: {"type":"response.created","response":{"id":"resp_test","object":"response","created_at":1741476542,"status":"in_progress","model":"gpt-4o","output":[],"usage":null}}', + 'data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","id":"msg_test","status":"in_progress","role":"assistant","content":[]}}', + 'data: {"type":"response.content_part.added","output_index":0,"content_index":0,"part":{"type":"output_text","text":"","annotations":[]}}', + 'data: {"type":"response.output_text.delta","output_index":0,"content_index":0,"delta":"Hello!"}', + 'data: {"type":"response.output_text.done","output_index":0,"content_index":0,"text":"Hello!"}', + 'data: {"type":"response.content_part.done","output_index":0,"content_index":0,"part":{"type":"output_text","text":"Hello!","annotations":[]}}', + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"message","id":"msg_test","status":"completed","role":"assistant","content":[{"type":"output_text","text":"Hello!","annotations":[]}]}}', + f'data: {{"type":"response.completed","response":{json.dumps(RESPONSES_API_MOCK)}}}', + "data: [DONE]", + ] + + mock_stream_response = MockStreamingHTTPResponse(sse_events) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = mock_stream_response + + response = await router.aresponses( + model="test-custom-pricing", + input="Hello, how are you?", + stream=True, + ) + + # Consume the stream + async for chunk in response: + pass + + # Wait for async callback + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None, ( + "response_cost should be set in the callback" + ) + + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ), ( + f"Streaming cost should use custom pricing ({expected_custom_cost}), " + f"got {cost_callback.response_cost}" + ) + finally: + litellm.callbacks = [] + + def test_sync_streaming_responses_uses_custom_pricing(self): + """Sync streaming /v1/responses should use custom pricing for cost.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing("openai/gpt-4o") + + sse_events = [ + 'data: {"type":"response.created","response":{"id":"resp_test","object":"response","created_at":1741476542,"status":"in_progress","model":"gpt-4o","output":[],"usage":null}}', + 'data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","id":"msg_test","status":"in_progress","role":"assistant","content":[]}}', + 'data: {"type":"response.content_part.added","output_index":0,"content_index":0,"part":{"type":"output_text","text":"","annotations":[]}}', + 'data: {"type":"response.output_text.delta","output_index":0,"content_index":0,"delta":"Hello!"}', + 'data: {"type":"response.output_text.done","output_index":0,"content_index":0,"text":"Hello!"}', + 'data: {"type":"response.content_part.done","output_index":0,"content_index":0,"part":{"type":"output_text","text":"Hello!","annotations":[]}}', + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"message","id":"msg_test","status":"completed","role":"assistant","content":[{"type":"output_text","text":"Hello!","annotations":[]}]}}', + f'data: {{"type":"response.completed","response":{json.dumps(RESPONSES_API_MOCK)}}}', + "data: [DONE]", + ] + + mock_stream_response = MockStreamingHTTPResponse(sse_events) + + with patch( + "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", + new_callable=MagicMock, + ) as mock_post: + mock_post.return_value = mock_stream_response + + response = router.responses( + model="test-custom-pricing", + input="Hello, how are you?", + stream=True, + ) + + for _chunk in response: + pass + + assert cost_callback.response_cost is not None, ( + "response_cost should be set in the callback" + ) + + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ), ( + f"Sync streaming cost should use custom pricing ({expected_custom_cost}), " + f"got {cost_callback.response_cost}" + ) + finally: + litellm.callbacks = [] + + @pytest.mark.asyncio + async def test_streaming_responses_with_user_metadata_uses_custom_pricing(self): + """Streaming /v1/responses should preserve custom pricing when metadata is also passed.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing("openai/gpt-4o") + + sse_events = [ + 'data: {"type":"response.created","response":{"id":"resp_test","object":"response","created_at":1741476542,"status":"in_progress","model":"gpt-4o","output":[],"usage":null}}', + 'data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","id":"msg_test","status":"in_progress","role":"assistant","content":[]}}', + 'data: {"type":"response.content_part.added","output_index":0,"content_index":0,"part":{"type":"output_text","text":"","annotations":[]}}', + 'data: {"type":"response.output_text.delta","output_index":0,"content_index":0,"delta":"Hello!"}', + 'data: {"type":"response.output_text.done","output_index":0,"content_index":0,"text":"Hello!"}', + 'data: {"type":"response.content_part.done","output_index":0,"content_index":0,"part":{"type":"output_text","text":"Hello!","annotations":[]}}', + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"message","id":"msg_test","status":"completed","role":"assistant","content":[{"type":"output_text","text":"Hello!","annotations":[]}]}}', + f'data: {{"type":"response.completed","response":{json.dumps(RESPONSES_API_MOCK)}}}', + "data: [DONE]", + ] + + mock_stream_response = MockStreamingHTTPResponse(sse_events) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = mock_stream_response + + response = await router.aresponses( + model="test-custom-pricing", + input="Hello, how are you?", + stream=True, + metadata={"user_field": "present"}, + ) + + async for chunk in response: + pass + + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ) + finally: + litellm.callbacks = [] + + @pytest.mark.asyncio + async def test_streaming_responses_with_user_model_info_ignored_for_pricing(self): + """Streaming /v1/responses should ignore user-supplied metadata.model_info.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing("openai/gpt-4o") + + sse_events = [ + 'data: {"type":"response.created","response":{"id":"resp_test","object":"response","created_at":1741476542,"status":"in_progress","model":"gpt-4o","output":[],"usage":null}}', + 'data: {"type":"response.output_item.added","output_index":0,"item":{"type":"message","id":"msg_test","status":"in_progress","role":"assistant","content":[]}}', + 'data: {"type":"response.content_part.added","output_index":0,"content_index":0,"part":{"type":"output_text","text":"","annotations":[]}}', + 'data: {"type":"response.output_text.delta","output_index":0,"content_index":0,"delta":"Hello!"}', + 'data: {"type":"response.output_text.done","output_index":0,"content_index":0,"text":"Hello!"}', + 'data: {"type":"response.content_part.done","output_index":0,"content_index":0,"part":{"type":"output_text","text":"Hello!","annotations":[]}}', + 'data: {"type":"response.output_item.done","output_index":0,"item":{"type":"message","id":"msg_test","status":"completed","role":"assistant","content":[{"type":"output_text","text":"Hello!","annotations":[]}]}}', + f'data: {{"type":"response.completed","response":{json.dumps(RESPONSES_API_MOCK)}}}', + "data: [DONE]", + ] + + mock_stream_response = MockStreamingHTTPResponse(sse_events) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = mock_stream_response + + response = await router.aresponses( + model="test-custom-pricing", + input="Hello, how are you?", + stream=True, + metadata={ + "user_field": "present", + "model_info": { + "id": "user-garbage", + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + }, + }, + ) + + async for chunk in response: + pass + + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ) + finally: + litellm.callbacks = [] + + +class TestAnthropicMessagesCustomPricingCost: + """ + Test that /v1/messages (Anthropic via router) uses custom pricing + for cost calculation when model_info has custom pricing fields. + """ + + @pytest.mark.asyncio + async def test_nonstreaming_messages_uses_custom_pricing(self): + """Non-streaming /v1/messages should use custom pricing for cost.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing( + "anthropic/claude-sonnet-4-20250514" + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = MockHTTPResponse( + ANTHROPIC_MESSAGES_MOCK, + headers={ + "content-type": "application/json", + "request-id": "req_test", + }, + ) + + response = await router.aanthropic_messages( + model="test-custom-pricing", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100, + ) + + # Wait for async callback + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None, ( + "response_cost should be set in the callback" + ) + + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ), ( + f"Cost should use custom pricing ({expected_custom_cost}), " + f"got {cost_callback.response_cost}" + ) + finally: + litellm.callbacks = [] + + @pytest.mark.asyncio + async def test_nonstreaming_messages_with_user_metadata_uses_custom_pricing(self): + """Non-streaming /v1/messages should preserve custom pricing when metadata is also passed.""" + cost_callback = CostCapturingCallback() + litellm.callbacks = [cost_callback] + + try: + router = _make_router_with_custom_pricing( + "anthropic/claude-sonnet-4-20250514" + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + new_callable=AsyncMock, + ) as mock_post: + mock_post.return_value = MockHTTPResponse( + ANTHROPIC_MESSAGES_MOCK, + headers={ + "content-type": "application/json", + "request-id": "req_test", + }, + ) + + response = await router.aanthropic_messages( + model="test-custom-pricing", + messages=[{"role": "user", "content": "Hello!"}], + max_tokens=100, + metadata={"user_field": "present"}, + ) + + assert response is not None + + try: + await asyncio.wait_for(cost_callback.event.wait(), timeout=10.0) + except asyncio.TimeoutError: + pass + + assert cost_callback.response_cost is not None, ( + "response_cost should be set in the callback" + ) + + expected_custom_cost = 100 * CUSTOM_INPUT_COST + 50 * CUSTOM_OUTPUT_COST + assert cost_callback.response_cost == pytest.approx( + expected_custom_cost, rel=0.01 + ), ( + f"Cost should use custom pricing ({expected_custom_cost}), " + f"got {cost_callback.response_cost}" + ) + finally: + litellm.callbacks = []