From e16afd39becb13e58f1e72145b68a824fa5335dd Mon Sep 17 00:00:00 2001
From: Dmitriy Alergant <93501479+DmitriyAlergant@users.noreply.github.com>
Date: Mon, 13 Oct 2025 00:59:22 -0400
Subject: [PATCH] Respect Anthropic API base overrides in token counter

---
 litellm/proxy/utils.py          |  13 +-
 .../test_proxy_token_counter.py | 121 +++++++++++++++++-
 2 files changed, 129 insertions(+), 5 deletions(-)

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index 8aa2f4075552..8ca1bffea8f9 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -3753,16 +3753,25 @@ async def count_tokens_with_anthropic_api(
 
         # Get Anthropic API key from deployment config
         anthropic_api_key = None
+        anthropic_api_base: Optional[str] = None
         if deployment is not None:
-            anthropic_api_key = deployment.get("litellm_params", {}).get("api_key")
+            litellm_params = deployment.get("litellm_params", {})
+            anthropic_api_key = litellm_params.get("api_key")
+            anthropic_api_base = litellm_params.get("api_base")
 
         # Fallback to environment variable
        if not anthropic_api_key:
             anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
+        if not anthropic_api_base:
+            anthropic_api_base = os.getenv("ANTHROPIC_API_BASE")
 
         if anthropic_api_key and messages:
             # Call Anthropic API directly for more accurate token counting
-            client = anthropic.Anthropic(api_key=anthropic_api_key)
+            client_kwargs: Dict[str, Any] = {"api_key": anthropic_api_key}
+            if anthropic_api_base:
+                client_kwargs["base_url"] = anthropic_api_base
+
+            client = anthropic.Anthropic(**client_kwargs)
 
             # Call with explicit parameters to satisfy type checking
             # Type ignore for now since messages come from generic dict input
diff --git a/tests/proxy_unit_tests/test_proxy_token_counter.py b/tests/proxy_unit_tests/test_proxy_token_counter.py
index 534569f15afb..fe5a79cfb85e 100644
--- a/tests/proxy_unit_tests/test_proxy_token_counter.py
+++ b/tests/proxy_unit_tests/test_proxy_token_counter.py
@@ -3,6 +3,9 @@
 
 import sys, os
 
+from pathlib import Path
+from contextlib import contextmanager
+import importlib
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -10,9 +13,13 @@
 
 
 # this file is to test litellm/proxy
-sys.path.insert(
-    0, os.path.abspath("../..")
-)  # Adds the parent directory to the system path
+repo_root = Path(__file__).resolve().parents[2]
+repo_root_str = str(repo_root)
+if repo_root_str not in sys.path:
+    sys.path.insert(0, repo_root_str)
+for module_name in list(sys.modules.keys()):
+    if module_name == "litellm" or module_name.startswith("litellm."):
+        sys.modules.pop(module_name)
 import pytest, logging
 import litellm
 from litellm.proxy.proxy_server import token_counter
@@ -517,6 +524,114 @@ async def test_factory_anthropic_endpoint_calls_anthropic_counter():
     mock_anthropic_count.assert_called_once()
 
 
+def _setup_mock_anthropic_client(mock_client_instance, input_tokens: int = 123):
+    from types import SimpleNamespace
+    from unittest.mock import MagicMock
+
+    mock_client_instance.beta = MagicMock()
+    mock_client_instance.beta.messages = MagicMock()
+    mock_client_instance.beta.messages.count_tokens.return_value = SimpleNamespace(
+        input_tokens=input_tokens
+    )
+
+
+@contextmanager
+def _local_proxy_utils_module():
+    original_modules = {}
+    for module_name in list(sys.modules.keys()):
+        if module_name == "litellm" or module_name.startswith("litellm."):
+            original_modules[module_name] = sys.modules.pop(module_name)
+    try:
+        proxy_utils_module = importlib.import_module("litellm.proxy.utils")
+        yield proxy_utils_module
+    finally:
+        for module_name in list(sys.modules.keys()):
+            if module_name == "litellm" or module_name.startswith("litellm."):
+                sys.modules.pop(module_name)
+        sys.modules.update(original_modules)
+
+
+def test_count_tokens_with_anthropic_api_respects_deployment_api_base():
+    """Ensure Anthropic client honors deployment-specific api_base."""
+    import asyncio
+    from types import SimpleNamespace
+    from unittest.mock import MagicMock, patch
+    import sys
+
+    messages = [{"role": "user", "content": "Hello"}]
+    deployment = {
+        "litellm_params": {
+            "api_key": "test-key",
+            "api_base": "https://custom.anthropic.example",
+        }
+    }
+
+    mock_client_instance = MagicMock()
+    _setup_mock_anthropic_client(mock_client_instance)
+    mock_anthropic_class = MagicMock(return_value=mock_client_instance)
+
+    anthropic_module = SimpleNamespace(Anthropic=mock_anthropic_class)
+
+    with _local_proxy_utils_module() as proxy_utils_module:
+        with patch.dict(sys.modules, {"anthropic": anthropic_module}):
+            result = asyncio.run(
+                proxy_utils_module.count_tokens_with_anthropic_api(
+                    model_to_use="claude-sonnet-4",
+                    messages=messages,
+                    deployment=deployment,
+                )
+            )
+
+    mock_anthropic_class.assert_called_once_with(
+        api_key="test-key", base_url="https://custom.anthropic.example"
+    )
+    mock_client_instance.beta.messages.count_tokens.assert_called_once_with(
+        model="claude-sonnet-4",
+        messages=messages,
+        betas=["token-counting-2024-11-01"],
+    )
+    assert result == {
+        "total_tokens": 123,
+        "tokenizer_used": "anthropic_api",
+    }
+
+
+def test_count_tokens_with_anthropic_api_respects_env_api_base(monkeypatch):
+    """Ensure Anthropic client honors ANTHROPIC_API_BASE env fallback."""
+    import asyncio
+    from types import SimpleNamespace
+    from unittest.mock import MagicMock, patch
+    import sys
+
+    messages = [{"role": "user", "content": "Hi"}]
+    deployment = {"litellm_params": {"api_key": "test-key"}}
+    monkeypatch.setenv("ANTHROPIC_API_BASE", "https://env.anthropic.example")
+
+    mock_client_instance = MagicMock()
+    _setup_mock_anthropic_client(mock_client_instance, input_tokens=456)
+    mock_anthropic_class = MagicMock(return_value=mock_client_instance)
+
+    anthropic_module = SimpleNamespace(Anthropic=mock_anthropic_class)
+
+    with _local_proxy_utils_module() as proxy_utils_module:
+        with patch.dict(sys.modules, {"anthropic": anthropic_module}):
+            result = asyncio.run(
+                proxy_utils_module.count_tokens_with_anthropic_api(
+                    model_to_use="claude-sonnet-4",
+                    messages=messages,
+                    deployment=deployment,
+                )
+            )
+
+    mock_anthropic_class.assert_called_once_with(
+        api_key="test-key", base_url="https://env.anthropic.example"
+    )
+    assert result == {
+        "total_tokens": 456,
+        "tokenizer_used": "anthropic_api",
+    }
+
+
 @pytest.mark.asyncio
 async def test_factory_gpt4_endpoint_does_not_call_anthropic_counter():
     """Test that /v1/messages/count_tokens with GPT-4 does NOT use Anthropic counter."""
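
Note on usage (not part of the patch): a minimal sketch of the behavior this change enables, assuming a deployment whose litellm_params carry an api_base override. The key, URL, and model name below are placeholder values, and running this against a real endpoint would issue an HTTP request to that base URL. The function name, keyword arguments, and result shape are taken from the patched code and tests above.

    import asyncio

    from litellm.proxy.utils import count_tokens_with_anthropic_api

    # Placeholder deployment config: with this patch, api_base is forwarded to
    # anthropic.Anthropic(base_url=...) instead of being ignored.
    deployment = {
        "litellm_params": {
            "api_key": "sk-ant-example",
            "api_base": "https://custom.anthropic.example",
        }
    }

    result = asyncio.run(
        count_tokens_with_anthropic_api(
            model_to_use="claude-sonnet-4",
            messages=[{"role": "user", "content": "Hello"}],
            deployment=deployment,
        )
    )
    # Expected result shape (per the tests above):
    # {"total_tokens": <int>, "tokenizer_used": "anthropic_api"}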