Merged
118 changes: 75 additions & 43 deletions litellm/integrations/prometheus.py
@@ -229,14 +229,18 @@ def __init__( # noqa: PLR0915
self.litellm_remaining_api_key_requests_for_model = self._gauge_factory(
"litellm_remaining_api_key_requests_for_model",
"Remaining Requests API Key can make for model (model based rpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
labelnames=self.get_labels_for_metric(
"litellm_remaining_api_key_requests_for_model"
),
)

# Remaining MODEL TPM limit for API Key
self.litellm_remaining_api_key_tokens_for_model = self._gauge_factory(
"litellm_remaining_api_key_tokens_for_model",
"Remaining Tokens API Key can make for model (model based tpm limit on key)",
labelnames=["hashed_api_key", "api_key_alias", "model"],
labelnames=self.get_labels_for_metric(
"litellm_remaining_api_key_tokens_for_model"
),
)

########################################
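This hunk swaps hardcoded label lists for a centralized lookup. A minimal sketch of how a `get_labels_for_metric`-style registry keeps gauge definitions and `.labels(...)` call sites in sync — the mapping below is illustrative, not litellm's actual label configuration:

```python
from typing import Dict, List

# Hypothetical registry; litellm's real label sets are defined per metric
# and can be narrowed via the prometheus callback settings.
_METRIC_LABELS: Dict[str, List[str]] = {
    "litellm_remaining_api_key_requests_for_model": [
        "hashed_api_key", "api_key_alias", "model", "model_id",
    ],
    "litellm_remaining_api_key_tokens_for_model": [
        "hashed_api_key", "api_key_alias", "model", "model_id",
    ],
}

def get_labels_for_metric(metric_name: str) -> List[str]:
    # One source of truth: adding a label (e.g. model_id) here updates the
    # labelnames every consumer sees.
    return _METRIC_LABELS.get(metric_name, [])
```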
@@ -373,15 +377,9 @@ def __init__( # noqa: PLR0915
self.litellm_llm_api_failed_requests_metric = self._counter_factory(
name="litellm_llm_api_failed_requests_metric",
documentation="deprecated - use litellm_proxy_failed_requests_metric",
labelnames=[
"end_user",
"hashed_api_key",
"api_key_alias",
"model",
"team",
"team_alias",
"user",
],
labelnames=self.get_labels_for_metric(
"litellm_llm_api_failed_requests_metric"
),
)

self.litellm_requests_metric = self._counter_factory(
@@ -954,6 +952,8 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time
route=standard_logging_payload["metadata"].get(
"user_api_key_request_route"
),
client_ip=standard_logging_payload["metadata"].get("requester_ip_address"),
user_agent=standard_logging_payload["metadata"].get("user_agent"),
)

if (
@@ -1011,6 +1011,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time
user_api_key_alias=user_api_key_alias,
kwargs=kwargs,
metadata=_metadata,
model_id=enum_values.model_id,
)

# set latency metrics
@@ -1245,6 +1246,7 @@ def _set_virtual_key_rate_limit_metrics(
user_api_key_alias: Optional[str],
kwargs: dict,
metadata: dict,
model_id: Optional[str] = None,
):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
@@ -1266,11 +1268,11 @@
)

self.litellm_remaining_api_key_requests_for_model.labels(
user_api_key, user_api_key_alias, model_group
user_api_key, user_api_key_alias, model_group, model_id
).set(remaining_requests)

self.litellm_remaining_api_key_tokens_for_model.labels(
user_api_key, user_api_key_alias, model_group
user_api_key, user_api_key_alias, model_group, model_id
).set(remaining_tokens)
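Note for reviewers: prometheus_client matches `.labels(...)` arguments to `labelnames` positionally, which is why the gauge definitions and these call sites must change together. A self-contained sketch of that contract (metric description and values invented):

```python
from prometheus_client import Gauge

# labelnames order defines the positional order .labels() expects; adding
# model_id to the definition forces every call site to pass a fourth value.
remaining_requests = Gauge(
    "litellm_remaining_api_key_requests_for_model",
    "Remaining requests an API key can make for a model",
    labelnames=["hashed_api_key", "api_key_alias", "model", "model_id"],
)
remaining_requests.labels("1a2b3c", "prod-key", "gpt-4o", "model-123").set(42)
```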

def _set_latency_metrics(
Expand Down Expand Up @@ -1365,14 +1367,14 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti
standard_logging_payload: StandardLoggingPayload = kwargs.get(
"standard_logging_object", {}
)

if self._should_skip_metrics_for_invalid_key(
kwargs=kwargs, standard_logging_payload=standard_logging_payload
):
return

model = kwargs.get("model", "")

litellm_params = kwargs.get("litellm_params", {}) or {}
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()

Expand All @@ -1396,6 +1398,7 @@ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_ti
user_api_team,
user_api_team_alias,
user_id,
standard_logging_payload.get("model_id", ""),
).inc()
self.set_llm_deployment_failure_metrics(kwargs)
except Exception as e:
@@ -1413,73 +1416,81 @@ def _extract_status_code(
) -> Optional[int]:
"""
Extract HTTP status code from various input formats for validation.

This is a centralized helper to extract status code from different
callback function signatures. Handles both ProxyException (uses 'code')
and standard exceptions (uses 'status_code').

Args:
kwargs: Dictionary potentially containing 'exception' key
enum_values: Object with 'status_code' attribute
exception: Exception object to extract status code from directly

Returns:
Status code as integer if found, None otherwise
"""
status_code = None

# Try from enum_values first (most common in our callbacks)
if enum_values and hasattr(enum_values, "status_code") and enum_values.status_code:
if (
enum_values
and hasattr(enum_values, "status_code")
and enum_values.status_code
):
try:
status_code = int(enum_values.status_code)
except (ValueError, TypeError):
pass

if not status_code and exception:
# ProxyException uses 'code' attribute, other exceptions may use 'status_code'
status_code = getattr(exception, "status_code", None) or getattr(exception, "code", None)
status_code = getattr(exception, "status_code", None) or getattr(
exception, "code", None
)
if status_code is not None:
try:
status_code = int(status_code)
except (ValueError, TypeError):
status_code = None

if not status_code and kwargs:
exception_in_kwargs = kwargs.get("exception")
if exception_in_kwargs:
status_code = getattr(exception_in_kwargs, "status_code", None) or getattr(exception_in_kwargs, "code", None)
status_code = getattr(
exception_in_kwargs, "status_code", None
) or getattr(exception_in_kwargs, "code", None)
if status_code is not None:
try:
status_code = int(status_code)
except (ValueError, TypeError):
status_code = None

return status_code
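A standalone mimic of the fallback chain above, showing why the `or` handles ProxyException-style objects (the fake exception class is an assumption; litellm's ProxyException stores the HTTP code in `code`):

```python
from typing import Optional

def extract_status_code(exception: Exception) -> Optional[int]:
    # Prefer status_code, fall back to code, coerce to int, swallow bad values.
    raw = getattr(exception, "status_code", None) or getattr(exception, "code", None)
    if raw is None:
        return None
    try:
        return int(raw)
    except (ValueError, TypeError):
        return None

class FakeProxyException(Exception):
    code = "401"  # string on purpose: the int() coercion above normalizes it

assert extract_status_code(FakeProxyException()) == 401
assert extract_status_code(ValueError("no code attrs")) is None
```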

def _is_invalid_api_key_request(
self,
status_code: Optional[int],
exception: Optional[Exception] = None,
) -> bool:
"""
Determine if a request has an invalid API key based on status code and exception.

This method prevents invalid authentication attempts from being recorded in
Prometheus metrics. A 401 status code is the definitive indicator of authentication
failure. Additionally, we check exception messages for authentication error patterns
to catch cases where the exception hasn't been converted to a ProxyException yet.

Args:
status_code: HTTP status code (401 indicates authentication error)
exception: Exception object to check for auth-related error messages

Returns:
True if the request has an invalid API key and metrics should be skipped,
False otherwise
"""
if status_code == 401:
return True

# Handle cases where AssertionError is raised before conversion to ProxyException
if exception is not None:
exception_str = str(exception).lower()
@@ -1492,9 +1503,9 @@ def _is_invalid_api_key_request(
]
if any(pattern in exception_str for pattern in auth_error_patterns):
return True

return False
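The substring check above is what lets pre-conversion AssertionErrors be filtered too. A tiny illustration — the real `auth_error_patterns` list is collapsed out of this diff, so the patterns below are assumptions:

```python
# Assumed patterns; the actual list lives in the collapsed hunk above.
auth_error_patterns = ["authentication error", "invalid api key"]

exc = AssertionError("Authentication Error: invalid key passed")
exception_str = str(exc).lower()
assert any(pattern in exception_str for pattern in auth_error_patterns)
```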

def _should_skip_metrics_for_invalid_key(
self,
kwargs: Optional[dict] = None,
@@ -1505,18 +1516,18 @@
) -> bool:
"""
Determine if Prometheus metrics should be skipped for invalid API key requests.

This is a centralized validation method that extracts status code and exception
information from various callback function signatures and determines if the request
represents an invalid API key attempt that should be filtered from metrics.

Args:
kwargs: Dictionary potentially containing exception and other data
user_api_key_dict: User API key authentication object (currently unused)
enum_values: Object with status_code attribute
standard_logging_payload: Standard logging payload dictionary
exception: Exception object to check directly

Returns:
True if metrics should be skipped (invalid key detected), False otherwise
"""
@@ -1525,17 +1536,17 @@ def _should_skip_metrics_for_invalid_key(
enum_values=enum_values,
exception=exception,
)

if exception is None and kwargs:
exception = kwargs.get("exception")

if self._is_invalid_api_key_request(status_code, exception=exception):
verbose_logger.debug(
"Skipping Prometheus metrics for invalid API key request: "
f"status_code={status_code}, exception={type(exception).__name__ if exception else None}"
)
return True

return False

async def async_post_call_failure_hook(
@@ -1576,6 +1587,10 @@ async def async_post_call_failure_hook(
litellm_params=request_data,
proxy_server_request=request_data.get("proxy_server_request", {}),
)
_metadata = request_data.get("metadata", {}) or {}
model_id = _metadata.get("model_info", {}).get("id") or request_data.get(
"model_info", {}
).get("id")
enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id,
user=user_api_key_dict.user_id,
@@ -1590,6 +1605,9 @@
exception_class=self._get_exception_class_name(original_exception),
tags=_tags,
route=user_api_key_dict.request_route,
client_ip=_metadata.get("requester_ip_address"),
user_agent=_metadata.get("user_agent"),
model_id=model_id,
)
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
@@ -1629,6 +1647,7 @@ async def async_post_call_success_hook(
):
return

_metadata = data.get("metadata", {}) or {}
enum_values = UserAPIKeyLabelValues(
end_user=user_api_key_dict.end_user_id,
hashed_api_key=user_api_key_dict.api_key,
@@ -1644,6 +1663,8 @@
litellm_params=data,
proxy_server_request=data.get("proxy_server_request", {}),
),
client_ip=_metadata.get("requester_ip_address"),
user_agent=_metadata.get("user_agent"),
Review comment: Success hook missing model_id unlike failure hook (Medium Severity)

The PR adds MODEL_ID to litellm_proxy_total_requests_metric labels and updates async_post_call_failure_hook to include model_id, but async_post_call_success_hook was not updated. This creates inconsistent metrics where failures have model_id values but successes have None.
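A minimal sketch of the parity fix the reviewer is suggesting, assuming `data` in the success hook carries `model_info` the same way `request_data` does in the failure hook above:

```python
from typing import Any, Dict, Optional

def resolve_model_id(data: Dict[str, Any]) -> Optional[str]:
    # Mirrors the failure hook's lookup: metadata.model_info.id first,
    # then a top-level model_info.id fallback.
    _metadata = data.get("metadata", {}) or {}
    return (
        _metadata.get("model_info", {}).get("id")
        or data.get("model_info", {}).get("id")
    )

# The success hook could then pass model_id=resolve_model_id(data) into
# UserAPIKeyLabelValues, matching async_post_call_failure_hook.
assert resolve_model_id({"metadata": {"model_info": {"id": "m-1"}}}) == "m-1"
```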

)
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
@@ -1684,7 +1705,7 @@ def set_llm_deployment_failure_metrics(self, request_kwargs: dict):
exception = request_kwargs.get("exception", None)

llm_provider = _litellm_params.get("custom_llm_provider", None)

if self._should_skip_metrics_for_invalid_key(
kwargs=request_kwargs,
standard_logging_payload=standard_logging_payload,
@@ -1716,6 +1737,10 @@ def set_llm_deployment_failure_metrics(self, request_kwargs: dict):
"user_api_key_team_alias"
],
tags=standard_logging_payload.get("request_tags", []),
client_ip=standard_logging_payload["metadata"].get(
"requester_ip_address"
),
user_agent=standard_logging_payload["metadata"].get("user_agent"),
)

"""
@@ -2263,7 +2288,10 @@ async def _initialize_api_key_budget_metrics(self):

async def fetch_keys(
page_size: int, page: int
) -> Tuple[List[Union[str, UserAPIKeyAuth, LiteLLM_DeletedVerificationToken]], Optional[int]]:
) -> Tuple[
List[Union[str, UserAPIKeyAuth, LiteLLM_DeletedVerificationToken]],
Optional[int],
]:
key_list_response = await _list_key_helper(
prisma_client=prisma_client,
page=page,
@@ -2379,12 +2407,16 @@ async def _initialize_user_and_team_count_metrics(self):
# Get total user count
total_users = await prisma_client.db.litellm_usertable.count()
self.litellm_total_users_metric.set(total_users)
verbose_logger.debug(f"Prometheus: set litellm_total_users to {total_users}")
verbose_logger.debug(
f"Prometheus: set litellm_total_users to {total_users}"
)

# Get total team count
total_teams = await prisma_client.db.litellm_teamtable.count()
self.litellm_teams_count_metric.set(total_teams)
verbose_logger.debug(f"Prometheus: set litellm_teams_count to {total_teams}")
verbose_logger.debug(
f"Prometheus: set litellm_teams_count to {total_teams}"
)
except Exception as e:
verbose_logger.exception(
f"Error initializing user/team count metrics: {str(e)}"
12 changes: 10 additions & 2 deletions litellm/litellm_core_utils/litellm_logging.py
@@ -335,7 +335,9 @@ def __init__(
self.start_time = start_time # log the call start time
self.call_type = call_type
self.litellm_call_id = litellm_call_id
self.litellm_trace_id: str = litellm_trace_id if litellm_trace_id else str(uuid.uuid4())
self.litellm_trace_id: str = (
litellm_trace_id if litellm_trace_id else str(uuid.uuid4())
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[
@@ -544,7 +546,10 @@ def update_environment_variables(
if "stream_options" in additional_params:
self.stream_options = additional_params["stream_options"]
## check if custom pricing set ##
if any(litellm_params.get(key) is not None for key in _CUSTOM_PRICING_KEYS & litellm_params.keys()):
if any(
litellm_params.get(key) is not None
for key in _CUSTOM_PRICING_KEYS & litellm_params.keys()
):
self.custom_pricing = True

if "custom_llm_provider" in self.model_call_details:
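The rewrapped condition relies on `dict.keys()` being a set-like view, so `&` intersects it with `_CUSTOM_PRICING_KEYS` and only keys actually present get probed. A quick illustration — the key names below are an assumed subset of litellm's real `_CUSTOM_PRICING_KEYS`:

```python
# Assumed subset of the real constant, for illustration only.
_CUSTOM_PRICING_KEYS = {"input_cost_per_token", "output_cost_per_token"}

litellm_params = {"input_cost_per_token": 1e-6, "model": "gpt-4o"}

# dict.keys() is a set-like view, so `&` yields only the pricing keys present.
custom_pricing = any(
    litellm_params.get(key) is not None
    for key in _CUSTOM_PRICING_KEYS & litellm_params.keys()
)
assert custom_pricing is True
```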
@@ -4453,6 +4458,7 @@ def get_standard_logging_metadata(
user_api_key_request_route=None,
spend_logs_metadata=None,
requester_ip_address=None,
user_agent=None,
requester_metadata=None,
prompt_management_metadata=prompt_management_metadata,
applied_guardrails=applied_guardrails,
@@ -5138,6 +5144,7 @@ def get_standard_logging_object_payload(
model_group=_model_group,
model_id=_model_id,
requester_ip_address=clean_metadata.get("requester_ip_address", None),
user_agent=clean_metadata.get("user_agent", None),
messages=StandardLoggingPayloadSetup.append_system_prompt_messages(
kwargs=kwargs, messages=kwargs.get("messages")
),
@@ -5203,6 +5210,7 @@ def get_standard_logging_metadata(
user_api_key_team_alias=None,
spend_logs_metadata=None,
requester_ip_address=None,
user_agent=None,
requester_metadata=None,
user_api_key_end_user_id=None,
prompt_management_metadata=None,
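Taken together, the litellm_logging.py changes thread `user_agent` (alongside the existing `requester_ip_address`) through the standard logging metadata so the Prometheus hooks can read both. A sketch of the metadata shape those hooks consume — keys from this diff, values invented:

```python
# Keys are the ones read by the Prometheus integration above; values invented.
standard_logging_metadata = {
    "user_api_key_request_route": "/chat/completions",
    "requester_ip_address": "203.0.113.7",
    "user_agent": "OpenAI/Python 1.40.0",
}

client_ip = standard_logging_metadata.get("requester_ip_address")
user_agent = standard_logging_metadata.get("user_agent")
assert (client_ip, user_agent) == ("203.0.113.7", "OpenAI/Python 1.40.0")
```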