Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
Expand Down Expand Up @@ -431,3 +436,17 @@ def build_kv_connector_stats(
which can implement custom aggregation logic on the data dict.
"""
return None

@classmethod
def build_prom_metrics(
cls,
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None
83 changes: 82 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeAlias, TypeVar

from prometheus_client import Counter, Gauge, Histogram

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger

PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)

logger = init_logger(__name__)


Expand Down Expand Up @@ -102,3 +107,79 @@ def log(self, log_fn=logger.info):

# Reset metrics for next interval
self.reset()


class KVConnectorPromMetrics:
"""
A base class for per-connector Prometheus metric registration
and recording.
"""

def __init__(
self,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues

def make_per_engine(self, metric: PromMetric) -> PromMetric:
"""
Create a per-engine child of a prometheus_client.Metric with
the appropriate labels set. The parent metric must be created
using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items()
}

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
"""
raise NotImplementedError


class KVConnectorPrometheus:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses
KVConnectorBase.build_prom_metrics().
"""

_gauge_cls = Gauge
_counter_cls = Counter
_histogram_cls = Histogram

def __init__(
self,
kv_transfer_config: KVTransferConfig | None,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Gauge: self._gauge_cls,
Counter: self._counter_cls,
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
metric_types,
labelnames,
per_engine_labelvalues,
)

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
if self.prom_metrics is None:
return
self.prom_metrics.observe(transfer_stats_data, engine_idx)
137 changes: 136 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -254,6 +259,15 @@ def build_kv_connector_stats(
else NixlKVConnectorStats()
)

@classmethod
def build_prom_metrics(
cls,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
return NixlPromMetrics(metric_types, labelnames, per_engine_labelvalues)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
Expand Down Expand Up @@ -1744,3 +1758,124 @@ def reduce(self) -> dict[str, int | float]:
@property
def num_successful_transfers(self) -> int:
return len(self.data["transfer_duration"])


class NixlPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
super().__init__(metric_types, labelnames, per_engine_labelvalues)

buckets = [
0.001,
0.005,
0.01,
0.025,
0.05,
0.075,
0.1,
0.2,
0.3,
0.5,
0.75,
1.0,
5.0,
]
nixl_histogram_xfer_time = self._histogram_cls(
name="vllm:nixl_xfer_time_seconds",
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
nixl_histogram_post_time = self._histogram_cls(
name="vllm:nixl_post_time_seconds",
documentation="Histogram of transfer post time for NIXL KV"
" Cache transfers.",
buckets=buckets[1:],
labelnames=labelnames,
)
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
# uniform 2kb to 16gb range
buckets = [2**10 + i for i in range(1, 24, 2)]
nixl_histogram_bytes_transferred = self._histogram_cls(
name="vllm:nixl_bytes_transferred",
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_bytes_transferred = self.make_per_engine(
nixl_histogram_bytes_transferred
)
buckets = [
10,
20,
30,
50,
75,
100,
200,
400,
1000,
2000,
4000,
10000,
20000,
50000,
]
nixl_histogram_num_descriptors = self._histogram_cls(
name="vllm:nixl_num_descriptors",
documentation="Histogram of number of descriptors per NIXL"
" KV Cache transfers.",
buckets=buckets,
labelnames=labelnames,
)
self.nixl_histogram_num_descriptors = self.make_per_engine(
nixl_histogram_num_descriptors
)
counter_nixl_num_failed_transfers = self._counter_cls(
name="vllm:nixl_num_failed_transfers",
documentation="Number of failed NIXL KV Cache transfers.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_transfers = self.make_per_engine(
counter_nixl_num_failed_transfers
)
counter_nixl_num_failed_notifications = self._counter_cls(
name="vllm:nixl_num_failed_notifications",
documentation="Number of failed NIXL KV Cache notifications.",
labelnames=labelnames,
)
self.counter_nixl_num_failed_notifications = self.make_per_engine(
counter_nixl_num_failed_notifications
)

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for prom_obj, list_item_key in zip(
[
self.nixl_histogram_xfer_time,
self.nixl_histogram_post_time,
self.nixl_histogram_bytes_transferred,
self.nixl_histogram_num_descriptors,
],
[
"transfer_duration",
"post_duration",
"bytes_transferred",
"num_descriptors",
],
):
for list_item in transfer_stats_data[list_item_key]:
prom_obj[engine_idx].observe(list_item)
for counter_obj, counter_item_key in zip(
[
self.counter_nixl_num_failed_transfers,
self.counter_nixl_num_failed_notifications,
],
["num_failed_transfers", "num_failed_notifications"],
):
for list_item in transfer_stats_data[counter_item_key]:
counter_obj[engine_idx].inc(list_item)
Loading