diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ab5d2ecdc71b..e32c9076630e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -50,7 +50,12 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig from vllm.distributed.kv_events import KVCacheEvent - from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats + from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, + ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.request import Request @@ -431,3 +436,17 @@ def build_kv_connector_stats( which can implement custom aggregation logic on the data dict. """ return None + + @classmethod + def build_prom_metrics( + cls, + metric_types: dict[type["PromMetric"], type["PromMetricT"]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> Optional["KVConnectorPromMetrics"]: + """ + Create a KVConnectorPromMetrics subclass which should register + per-connector Prometheus metrics and implement observe() to + expose connector transfer stats via Prometheus. + """ + return None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py index 21002fe572c5..5cb739c21a69 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/metrics.py @@ -1,13 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass, field -from typing import Any +from typing import Any, TypeAlias, TypeVar + +from prometheus_client import Counter, Gauge, Histogram from vllm.config.kv_transfer import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group from vllm.logger import init_logger +PromMetric: TypeAlias = Gauge | Counter | Histogram +PromMetricT = TypeVar("PromMetricT", bound=PromMetric) + logger = init_logger(__name__) @@ -102,3 +107,79 @@ def log(self, log_fn=logger.info): # Reset metrics for next interval self.reset() + + +class KVConnectorPromMetrics: + """ + A base class for per-connector Prometheus metric registration + and recording. + """ + + def __init__( + self, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + self._gauge_cls = metric_types[Gauge] + self._counter_cls = metric_types[Counter] + self._histogram_cls = metric_types[Histogram] + self._labelnames = labelnames + self._per_engine_labelvalues = per_engine_labelvalues + + def make_per_engine(self, metric: PromMetric) -> PromMetric: + """ + Create a per-engine child of a prometheus_client.Metric with + the appropriate labels set. The parent metric must be created + using the labelnames list. + """ + return { + idx: metric.labels(*labelvalues) + for idx, labelvalues in self._per_engine_labelvalues.items() + } + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + """ + Record the supplied transfer statistics to Prometheus metrics. These + statistics are engine-specific, and should be recorded to a metric + with the appropriate 'engine' label. These metric instances can be + created using the make_per_engine() helper method. + """ + raise NotImplementedError + + +class KVConnectorPrometheus: + """ + Support for registering per-connector Prometheus metrics, and + recording transfer statistics to those metrics. Uses + KVConnectorBase.build_prom_metrics(). + """ + + _gauge_cls = Gauge + _counter_cls = Counter + _histogram_cls = Histogram + + def __init__( + self, + kv_transfer_config: KVTransferConfig | None, + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + self.prom_metrics: KVConnectorPromMetrics | None = None + if kv_transfer_config and kv_transfer_config.kv_connector: + connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config) + metric_types = { + Gauge: self._gauge_cls, + Counter: self._counter_cls, + Histogram: self._histogram_cls, + } + self.prom_metrics = connector_cls.build_prom_metrics( + metric_types, + labelnames, + per_engine_labelvalues, + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + if self.prom_metrics is None: + return + self.prom_metrics.observe(transfer_stats_data, engine_idx) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 8c4c82f76ff2..1774834be8c5 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -30,7 +30,12 @@ KVConnectorMetadata, KVConnectorRole, ) -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorPromMetrics, + KVConnectorStats, + PromMetric, + PromMetricT, +) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -254,6 +259,15 @@ def build_kv_connector_stats( else NixlKVConnectorStats() ) + @classmethod + def build_prom_metrics( + cls, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ) -> KVConnectorPromMetrics: + return NixlPromMetrics(metric_types, labelnames, per_engine_labelvalues) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None assert isinstance(self._connector_metadata, NixlConnectorMetadata) @@ -1744,3 +1758,124 @@ def reduce(self) -> dict[str, int | float]: @property def num_successful_transfers(self) -> int: return len(self.data["transfer_duration"]) + + +class NixlPromMetrics(KVConnectorPromMetrics): + def __init__( + self, + metric_types: dict[type[PromMetric], type[PromMetricT]], + labelnames: list[str], + per_engine_labelvalues: dict[int, list[str]], + ): + super().__init__(metric_types, labelnames, per_engine_labelvalues) + + buckets = [ + 0.001, + 0.005, + 0.01, + 0.025, + 0.05, + 0.075, + 0.1, + 0.2, + 0.3, + 0.5, + 0.75, + 1.0, + 5.0, + ] + nixl_histogram_xfer_time = self._histogram_cls( + name="vllm:nixl_xfer_time_seconds", + documentation="Histogram of transfer duration for NIXL KV Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time) + nixl_histogram_post_time = self._histogram_cls( + name="vllm:nixl_post_time_seconds", + documentation="Histogram of transfer post time for NIXL KV" + " Cache transfers.", + buckets=buckets[1:], + labelnames=labelnames, + ) + self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time) + # uniform 2kb to 16gb range + buckets = [2**10 + i for i in range(1, 24, 2)] + nixl_histogram_bytes_transferred = self._histogram_cls( + name="vllm:nixl_bytes_transferred", + documentation="Histogram of bytes transferred per NIXL KV Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_bytes_transferred = self.make_per_engine( + nixl_histogram_bytes_transferred + ) + buckets = [ + 10, + 20, + 30, + 50, + 75, + 100, + 200, + 400, + 1000, + 2000, + 4000, + 10000, + 20000, + 50000, + ] + nixl_histogram_num_descriptors = self._histogram_cls( + name="vllm:nixl_num_descriptors", + documentation="Histogram of number of descriptors per NIXL" + " KV Cache transfers.", + buckets=buckets, + labelnames=labelnames, + ) + self.nixl_histogram_num_descriptors = self.make_per_engine( + nixl_histogram_num_descriptors + ) + counter_nixl_num_failed_transfers = self._counter_cls( + name="vllm:nixl_num_failed_transfers", + documentation="Number of failed NIXL KV Cache transfers.", + labelnames=labelnames, + ) + self.counter_nixl_num_failed_transfers = self.make_per_engine( + counter_nixl_num_failed_transfers + ) + counter_nixl_num_failed_notifications = self._counter_cls( + name="vllm:nixl_num_failed_notifications", + documentation="Number of failed NIXL KV Cache notifications.", + labelnames=labelnames, + ) + self.counter_nixl_num_failed_notifications = self.make_per_engine( + counter_nixl_num_failed_notifications + ) + + def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): + for prom_obj, list_item_key in zip( + [ + self.nixl_histogram_xfer_time, + self.nixl_histogram_post_time, + self.nixl_histogram_bytes_transferred, + self.nixl_histogram_num_descriptors, + ], + [ + "transfer_duration", + "post_duration", + "bytes_transferred", + "num_descriptors", + ], + ): + for list_item in transfer_stats_data[list_item_key]: + prom_obj[engine_idx].observe(list_item) + for counter_obj, counter_item_key in zip( + [ + self.counter_nixl_num_failed_transfers, + self.counter_nixl_num_failed_notifications, + ], + ["num_failed_transfers", "num_failed_notifications"], + ): + for list_item in transfer_stats_data[counter_item_key]: + counter_obj[engine_idx].inc(list_item) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 99e0b68be50a..7682c3401b39 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -10,7 +10,10 @@ from prometheus_client import Counter, Gauge, Histogram from vllm.config import SupportsMetricsInfo, VllmConfig -from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( + KVConnectorLogging, + KVConnectorPrometheus, +) from vllm.logger import init_logger from vllm.v1.engine import FinishReason from vllm.v1.metrics.prometheus import unregister_vllm_metrics @@ -308,6 +311,7 @@ class PrometheusStatLogger(AggregateStatLoggerBase): _counter_cls = Counter _histogram_cls = Histogram _spec_decoding_cls = SpecDecodingProm + _kv_connector_cls = KVConnectorPrometheus def __init__( self, vllm_config: VllmConfig, engine_indexes: list[int] | None = None @@ -327,12 +331,15 @@ def __init__( model_name = vllm_config.model_config.served_model_name max_model_len = vllm_config.model_config.max_model_len - spec_decode_labelvalues: dict[int, list[str]] = { + per_engine_labelvalues: dict[int, list[str]] = { idx: [model_name, str(idx)] for idx in engine_indexes } self.spec_decoding_prom = self._spec_decoding_cls( - vllm_config.speculative_config, labelnames, spec_decode_labelvalues + vllm_config.speculative_config, labelnames, per_engine_labelvalues + ) + self.kv_connector_prom = self._kv_connector_cls( + vllm_config.kv_transfer_config, labelnames, per_engine_labelvalues ) # @@ -804,104 +811,6 @@ def __init__( ], ) - # - # KVConnector metrics - # - self._nixl_metrics_enabled = False - if ( - kv_transfer_config := vllm_config.kv_transfer_config - ) and kv_transfer_config.kv_connector == "NixlConnector": - self._nixl_metrics_enabled = True - buckets = [ - 0.001, - 0.005, - 0.01, - 0.025, - 0.05, - 0.075, - 0.1, - 0.2, - 0.3, - 0.5, - 0.75, - 1.0, - 5.0, - ] - nixl_histogram_xfer_time = self._histogram_cls( - name="vllm:nixl_xfer_time_seconds", - documentation="Histogram of transfer duration for NIXL KV" - " Cache transfers.", - buckets=buckets, - labelnames=labelnames, - ) - self.nixl_histogram_xfer_time = make_per_engine( - nixl_histogram_xfer_time, engine_indexes, model_name - ) - nixl_histogram_post_time = self._histogram_cls( - name="vllm:nixl_post_time_seconds", - documentation="Histogram of transfer post time for NIXL KV" - " Cache transfers.", - buckets=buckets[1:], - labelnames=labelnames, - ) - self.nixl_histogram_post_time = make_per_engine( - nixl_histogram_post_time, engine_indexes, model_name - ) - # uniform 2kb to 16gb range - buckets = [2**10 + i for i in range(1, 24, 2)] - nixl_histogram_bytes_transferred = self._histogram_cls( - name="vllm:nixl_bytes_transferred", - documentation="Histogram of bytes transferred per NIXL KV" - " Cache transfers.", - buckets=buckets, - labelnames=labelnames, - ) - self.nixl_histogram_bytes_transferred = make_per_engine( - nixl_histogram_bytes_transferred, engine_indexes, model_name - ) - buckets = [ - 10, - 20, - 30, - 50, - 75, - 100, - 200, - 400, - 1000, - 2000, - 4000, - 10000, - 20000, - 50000, - ] - nixl_histogram_num_descriptors = self._histogram_cls( - name="vllm:nixl_num_descriptors", - documentation="Histogram of number of descriptors per NIXL" - " KV Cache transfers.", - buckets=buckets, - labelnames=labelnames, - ) - self.nixl_histogram_num_descriptors = make_per_engine( - nixl_histogram_num_descriptors, engine_indexes, model_name - ) - counter_nixl_num_failed_transfers = self._counter_cls( - name="vllm:nixl_num_failed_transfers", - documentation="Number of failed NIXL KV Cache transfers.", - labelnames=labelnames, - ) - self.counter_nixl_num_failed_transfers = make_per_engine( - counter_nixl_num_failed_transfers, engine_indexes, model_name - ) - counter_nixl_num_failed_notifications = self._counter_cls( - name="vllm:nixl_num_failed_notifications", - documentation="Number of failed NIXL KV Cache notifications.", - labelnames=labelnames, - ) - self.counter_nixl_num_failed_notifications = make_per_engine( - counter_nixl_num_failed_notifications, engine_indexes, model_name - ) - def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): metrics_info = config_obj.metrics_info() metrics_info["engine"] = "" @@ -967,35 +876,11 @@ def record( self.spec_decoding_prom.observe( scheduler_stats.spec_decoding_stats, engine_idx ) - # TODO factor this out into OOT metrics class - if self._nixl_metrics_enabled and ( - kv_stats := scheduler_stats.kv_connector_stats - ): - for prom_obj, list_item_key in zip( - [ - self.nixl_histogram_xfer_time, - self.nixl_histogram_post_time, - self.nixl_histogram_bytes_transferred, - self.nixl_histogram_num_descriptors, - ], - [ - "transfer_duration", - "post_duration", - "bytes_transferred", - "num_descriptors", - ], - ): - for list_item in kv_stats[list_item_key]: - prom_obj[engine_idx].observe(list_item) - for counter_obj, counter_item_key in zip( - [ - self.counter_nixl_num_failed_transfers, - self.counter_nixl_num_failed_notifications, - ], - ["num_failed_transfers", "num_failed_notifications"], - ): - for list_item in kv_stats[counter_item_key]: - counter_obj[engine_idx].inc(list_item) + + if scheduler_stats.kv_connector_stats is not None: + self.kv_connector_prom.observe( + scheduler_stats.kv_connector_stats, engine_idx + ) if mm_cache_stats is not None: self.counter_mm_cache_queries[engine_idx].inc(mm_cache_stats.queries) diff --git a/vllm/v1/metrics/ray_wrappers.py b/vllm/v1/metrics/ray_wrappers.py index b845852a0c0d..a319ffb1d257 100644 --- a/vllm/v1/metrics/ray_wrappers.py +++ b/vllm/v1/metrics/ray_wrappers.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorPrometheus from vllm.v1.metrics.loggers import PrometheusStatLogger from vllm.v1.spec_decode.metrics import SpecDecodingProm @@ -141,6 +142,18 @@ class RaySpecDecodingProm(SpecDecodingProm): _counter_cls = RayCounterWrapper +class RayKVConnectorPrometheus(KVConnectorPrometheus): + """ + RayKVConnectorPrometheus is used by RayMetrics to log Ray + metrics. Provides the same metrics as KV connectors but + uses Ray's util.metrics library. + """ + + _gauge_cls = RayGaugeWrapper + _counter_cls = RayCounterWrapper + _histogram_cls = RayHistogramWrapper + + class RayPrometheusStatLogger(PrometheusStatLogger): """RayPrometheusStatLogger uses Ray metrics instead.""" @@ -148,6 +161,7 @@ class RayPrometheusStatLogger(PrometheusStatLogger): _counter_cls = RayCounterWrapper _histogram_cls = RayHistogramWrapper _spec_decoding_cls = RaySpecDecodingProm + _kv_connector_cls = RayKVConnectorPrometheus @staticmethod def _unregister_vllm_metrics():