diff --git a/litellm/utils.py b/litellm/utils.py index ad9e36795af..3f66c9c4334 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5135,10 +5135,15 @@ def _invalidate_model_cost_lowercase_map() -> None: """Invalidate the case-insensitive lookup map for model_cost. Call this whenever litellm.model_cost is modified to ensure the map is rebuilt. + Also clears related LRU caches that depend on model_cost data. """ global _model_cost_lowercase_map _model_cost_lowercase_map = None + # Clear LRU caches that depend on model_cost data + get_model_info.cache_clear() + _cached_get_model_info_helper.cache_clear() + def _rebuild_model_cost_lowercase_map() -> Dict[str, str]: """Rebuild the case-insensitive lookup map from the current model_cost. @@ -5352,6 +5357,7 @@ def _get_max_position_embeddings(model_name: str) -> Optional[int]: return None +@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE) def _cached_get_model_info_helper( model: str, custom_llm_provider: Optional[str] ) -> ModelInfoBase: @@ -5699,6 +5705,7 @@ def _get_model_info_helper( # noqa: PLR0915 ) +@lru_cache(maxsize=DEFAULT_MAX_LRU_CACHE_SIZE) def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> ModelInfo: """ Get a dict for the maximum tokens (context window), input_cost_per_token, output_cost_per_token for a given model. diff --git a/tests/test_litellm/test_cost_calculator.py b/tests/test_litellm/test_cost_calculator.py index 4d6599fc1b5..b59c9341e02 100644 --- a/tests/test_litellm/test_cost_calculator.py +++ b/tests/test_litellm/test_cost_calculator.py @@ -70,8 +70,6 @@ class MockResponse(BaseModel): def test_cost_calculator_with_usage(monkeypatch): - from litellm import get_model_info - os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True" litellm.model_cost = litellm.get_model_cost_map(url="") @@ -123,6 +121,10 @@ def test_cost_calculator_with_usage(monkeypatch): }, ) + # Invalidate caches after modifying litellm.model_cost + from litellm.utils import _invalidate_model_cost_lowercase_map + _invalidate_model_cost_lowercase_map() + result = response_cost_calculator( response_object=mr, model="",