Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3753,16 +3753,25 @@ async def count_tokens_with_anthropic_api(

# Get Anthropic API key from deployment config
anthropic_api_key = None
anthropic_api_base: Optional[str] = None
if deployment is not None:
anthropic_api_key = deployment.get("litellm_params", {}).get("api_key")
litellm_params = deployment.get("litellm_params", {})
anthropic_api_key = litellm_params.get("api_key")
anthropic_api_base = litellm_params.get("api_base")

# Fallback to environment variable
if not anthropic_api_key:
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
if not anthropic_api_base:
anthropic_api_base = os.getenv("ANTHROPIC_API_BASE")

if anthropic_api_key and messages:
# Call Anthropic API directly for more accurate token counting
client = anthropic.Anthropic(api_key=anthropic_api_key)
client_kwargs: Dict[str, Any] = {"api_key": anthropic_api_key}
if anthropic_api_base:
client_kwargs["base_url"] = anthropic_api_base

client = anthropic.Anthropic(**client_kwargs)

# Call with explicit parameters to satisfy type checking
# Type ignore for now since messages come from generic dict input
Expand Down
121 changes: 118 additions & 3 deletions tests/proxy_unit_tests/test_proxy_token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@


import sys, os
from pathlib import Path
from contextlib import contextmanager
import importlib
from dotenv import load_dotenv

load_dotenv()
import os

# this file is to test litellm/proxy

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
repo_root = Path(__file__).resolve().parents[2]
repo_root_str = str(repo_root)
if repo_root_str not in sys.path:
sys.path.insert(0, repo_root_str)
for module_name in list(sys.modules.keys()):
if module_name == "litellm" or module_name.startswith("litellm."):
sys.modules.pop(module_name)
import pytest, logging
import litellm
from litellm.proxy.proxy_server import token_counter
Expand Down Expand Up @@ -517,6 +524,114 @@ async def test_factory_anthropic_endpoint_calls_anthropic_counter():
mock_anthropic_count.assert_called_once()


def _setup_mock_anthropic_client(mock_client_instance, input_tokens: int = 123) -> None:
    """Wire ``beta.messages.count_tokens`` onto a stand-in Anthropic client.

    The client is modified in place: its ``beta`` attribute is replaced with a
    fresh ``MagicMock`` whose ``messages.count_tokens(...)`` returns an object
    exposing a single ``input_tokens`` attribute.

    Args:
        mock_client_instance: Object used in place of a real Anthropic client.
        input_tokens: Token count reported by the mocked ``count_tokens`` call.
    """
    from types import SimpleNamespace
    from unittest.mock import MagicMock

    # Build the beta.messages chain first, then attach it to the client.
    beta = MagicMock()
    beta.messages = MagicMock()
    beta.messages.count_tokens.return_value = SimpleNamespace(
        input_tokens=input_tokens
    )
    mock_client_instance.beta = beta


def _is_litellm_module(module_name: str) -> bool:
    """Return True for the ``litellm`` package itself or any of its submodules."""
    return module_name == "litellm" or module_name.startswith("litellm.")


@contextmanager
def _local_proxy_utils_module():
    """Yield a freshly imported ``litellm.proxy.utils`` and restore ``sys.modules``.

    On entry every cached ``litellm`` / ``litellm.*`` entry is removed from
    ``sys.modules`` (and remembered) so that ``litellm.proxy.utils`` is imported
    anew. On exit - including when the import or the ``with`` body raises - the
    modules loaded in the meantime are dropped and the remembered entries are
    put back.

    Yields:
        The module object returned by
        ``importlib.import_module("litellm.proxy.utils")``.
    """
    original_modules = {}
    for module_name in list(sys.modules):
        if _is_litellm_module(module_name):
            original_modules[module_name] = sys.modules.pop(module_name)
    try:
        yield importlib.import_module("litellm.proxy.utils")
    finally:
        # Discard whatever the fresh import loaded before restoring the
        # snapshot; pop with a default so a missing key cannot mask the
        # original exception.
        for module_name in list(sys.modules):
            if _is_litellm_module(module_name):
                sys.modules.pop(module_name, None)
        sys.modules.update(original_modules)


def test_count_tokens_with_anthropic_api_respects_deployment_api_base():
    """Ensure Anthropic client honors deployment-specific api_base.

    A different ``ANTHROPIC_API_BASE`` is placed in the environment for the
    duration of the call so the test also shows that the deployment value takes
    precedence over the environment fallback.
    """
    import asyncio
    import os
    import sys
    from types import SimpleNamespace
    from unittest.mock import MagicMock, patch

    messages = [{"role": "user", "content": "Hello"}]
    deployment = {
        "litellm_params": {
            "api_key": "test-key",
            "api_base": "https://custom.anthropic.example",
        }
    }

    mock_client_instance = MagicMock()
    _setup_mock_anthropic_client(mock_client_instance)
    mock_anthropic_class = MagicMock(return_value=mock_client_instance)

    # Stand-in for the `anthropic` package; presumably picked up by an import
    # performed inside the function under test.
    anthropic_module = SimpleNamespace(Anthropic=mock_anthropic_class)

    # Competing env value: must be ignored because the deployment sets api_base.
    competing_env = {"ANTHROPIC_API_BASE": "https://env.anthropic.example"}

    with patch.dict(os.environ, competing_env):
        with _local_proxy_utils_module() as proxy_utils_module:
            with patch.dict(sys.modules, {"anthropic": anthropic_module}):
                result = asyncio.run(
                    proxy_utils_module.count_tokens_with_anthropic_api(
                        model_to_use="claude-sonnet-4",
                        messages=messages,
                        deployment=deployment,
                    )
                )

    mock_anthropic_class.assert_called_once_with(
        api_key="test-key", base_url="https://custom.anthropic.example"
    )
    mock_client_instance.beta.messages.count_tokens.assert_called_once_with(
        model="claude-sonnet-4",
        messages=messages,
        betas=["token-counting-2024-11-01"],
    )
    assert result == {
        "total_tokens": 123,
        "tokenizer_used": "anthropic_api",
    }


def test_count_tokens_with_anthropic_api_respects_env_api_base(monkeypatch):
    """Ensure Anthropic client honors ANTHROPIC_API_BASE env fallback.

    The deployment supplies only an API key, so the base URL passed to the
    client has to come from the environment variable set below.
    """
    import asyncio
    import sys
    from types import SimpleNamespace
    from unittest.mock import MagicMock, patch

    messages = [{"role": "user", "content": "Hi"}]
    deployment = {"litellm_params": {"api_key": "test-key"}}
    monkeypatch.setenv("ANTHROPIC_API_BASE", "https://env.anthropic.example")

    mock_client_instance = MagicMock()
    _setup_mock_anthropic_client(mock_client_instance, input_tokens=456)
    mock_anthropic_class = MagicMock(return_value=mock_client_instance)

    # Stand-in for the `anthropic` package; presumably picked up by an import
    # performed inside the function under test.
    anthropic_module = SimpleNamespace(Anthropic=mock_anthropic_class)

    with _local_proxy_utils_module() as proxy_utils_module:
        with patch.dict(sys.modules, {"anthropic": anthropic_module}):
            result = asyncio.run(
                proxy_utils_module.count_tokens_with_anthropic_api(
                    model_to_use="claude-sonnet-4",
                    messages=messages,
                    deployment=deployment,
                )
            )

    mock_anthropic_class.assert_called_once_with(
        api_key="test-key", base_url="https://env.anthropic.example"
    )
    # Same request-shape check as the deployment api_base test, so a regression
    # in the forwarded arguments is caught on this path too.
    mock_client_instance.beta.messages.count_tokens.assert_called_once_with(
        model="claude-sonnet-4",
        messages=messages,
        betas=["token-counting-2024-11-01"],
    )
    assert result == {
        "total_tokens": 456,
        "tokenizer_used": "anthropic_api",
    }


@pytest.mark.asyncio
async def test_factory_gpt4_endpoint_does_not_call_anthropic_counter():
"""Test that /v1/messages/count_tokens with GPT-4 does NOT use Anthropic counter."""
Expand Down
Loading