diff --git a/litellm/integrations/prometheus_services.py b/litellm/integrations/prometheus_services.py
index a5f2f0b5c7..55ce758ece 100644
--- a/litellm/integrations/prometheus_services.py
+++ b/litellm/integrations/prometheus_services.py
@@ -105,6 +105,19 @@ def _get_service_metrics_initialize(
         return metrics
 
     def is_metric_registered(self, metric_name) -> bool:
+        """Check whether ``metric_name`` is already registered in ``self.REGISTRY``.
+
+        Uses a membership test on the registry's ``_names_to_collectors``
+        attribute when it exists; otherwise falls back to scanning
+        ``self.REGISTRY.collect()`` and comparing each ``metric.name``.
+        """
+        # Use _names_to_collectors (O(1)) instead of REGISTRY.collect() (O(n)) to avoid
+        # perf regression when a new Router is created per request (e.g. router_settings in DB).
+        # NOTE: _names_to_collectors is a private attribute of the registry, hence the
+        # getattr() guard and the collect() fallback below.
+        names_to_collectors = getattr(self.REGISTRY, "_names_to_collectors", None)
+        if names_to_collectors is not None:
+            return metric_name in names_to_collectors
         for metric in self.REGISTRY.collect():
             if metric_name == metric.name:
                 return True
diff --git a/tests/test_litellm/integrations/test_prometheus_services.py b/tests/test_litellm/integrations/test_prometheus_services.py
index b627d31fda..ff80d7d9f8 100644
--- a/tests/test_litellm/integrations/test_prometheus_services.py
+++ b/tests/test_litellm/integrations/test_prometheus_services.py
@@ -17,6 +17,41 @@
 )  # Adds the parent directory to the system path
 
 
+def test_is_metric_registered_does_not_use_registry_collect():
+    """is_metric_registered() must use _names_to_collectors, not REGISTRY.collect() (perf; #19921)."""
+    from prometheus_client import CollectorRegistry, Counter, Histogram
+
+    registry = CollectorRegistry()
+    for i in range(80):
+        Counter(
+            f"litellm_service_{i}_total_requests",
+            "Total requests",
+            labelnames=["service"],
+            registry=registry,
+        )
+        Histogram(
+            f"litellm_service_{i}_latency",
+            "Latency",
+            labelnames=["service"],
+            registry=registry,
+        )
+
+    pl = PrometheusServicesLogger()
+    pl.REGISTRY = registry
+
+    # Spy on collect() instead of timing the calls: a call count is
+    # deterministic, whereas a wall-clock threshold is flaky on loaded CI hosts.
+    with patch.object(registry, "collect", wraps=registry.collect) as collect_spy:
+        assert pl.is_metric_registered("litellm_service_0_latency") is True
+        assert pl.is_metric_registered("litellm_service_79_total_requests") is True
+        assert pl.is_metric_registered("litellm_service_not_registered") is False
+
+    assert collect_spy.call_count == 0, (
+        "is_metric_registered() must not use REGISTRY.collect() when "
+        f"_names_to_collectors is available; collect() called {collect_spy.call_count} times."
+    )
+
+
 def test_create_gauge_new():
     """Test creating a new gauge"""
     pl = PrometheusServicesLogger()