Skip to content

Commit f3cc4e3

Browse files
authored
Merge pull request vllm-project#4 from markmc/nixl-prometheus-abstraction
[NIXL][Metrics] Add abstraction for per-connector Prometheus metrics
2 parents f6b39fa + 4e87197 commit f3cc4e3

File tree

5 files changed

+267
-133
lines changed

5 files changed

+267
-133
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@
5050
from vllm.attention.backends.abstract import AttentionMetadata
5151
from vllm.config import VllmConfig
5252
from vllm.distributed.kv_events import KVCacheEvent
53-
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
53+
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
54+
KVConnectorPromMetrics,
55+
KVConnectorStats,
56+
PromMetric,
57+
PromMetricT,
58+
)
5459
from vllm.forward_context import ForwardContext
5560
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
5661
from vllm.v1.request import Request
@@ -431,3 +436,17 @@ def build_kv_connector_stats(
431436
which can implement custom aggregation logic on the data dict.
432437
"""
433438
return None
439+
440+
@classmethod
441+
def build_prom_metrics(
442+
cls,
443+
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
444+
labelnames: list[str],
445+
per_engine_labelvalues: dict[int, list[str]],
446+
) -> Optional["KVConnectorPromMetrics"]:
447+
"""
448+
Create a KVConnectorPromMetrics subclass which should register
449+
per-connector Prometheus metrics and implement observe() to
450+
expose connector transfer stats via Prometheus.
451+
"""
452+
return None

vllm/distributed/kv_transfer/kv_connector/v1/metrics.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
from dataclasses import dataclass, field
4-
from typing import Any
4+
from typing import Any, TypeAlias, TypeVar
5+
6+
from prometheus_client import Counter, Gauge, Histogram
57

68
from vllm.config.kv_transfer import KVTransferConfig
79
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
810
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
911
from vllm.logger import init_logger
1012

13+
PromMetric: TypeAlias = Gauge | Counter | Histogram
14+
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)
15+
1116
logger = init_logger(__name__)
1217

1318

@@ -102,3 +107,79 @@ def log(self, log_fn=logger.info):
102107

103108
# Reset metrics for next interval
104109
self.reset()
110+
111+
112+
class KVConnectorPromMetrics:
113+
"""
114+
A base class for per-connector Prometheus metric registration
115+
and recording.
116+
"""
117+
118+
def __init__(
119+
self,
120+
metric_types: dict[type[PromMetric], type[PromMetricT]],
121+
labelnames: list[str],
122+
per_engine_labelvalues: dict[int, list[str]],
123+
):
124+
self._gauge_cls = metric_types[Gauge]
125+
self._counter_cls = metric_types[Counter]
126+
self._histogram_cls = metric_types[Histogram]
127+
self._labelnames = labelnames
128+
self._per_engine_labelvalues = per_engine_labelvalues
129+
130+
def make_per_engine(self, metric: PromMetric) -> PromMetric:
131+
"""
132+
Create a per-engine child of a prometheus_client.Metric with
133+
the appropriate labels set. The parent metric must be created
134+
using the labelnames list.
135+
"""
136+
return {
137+
idx: metric.labels(*labelvalues)
138+
for idx, labelvalues in self._per_engine_labelvalues.items()
139+
}
140+
141+
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
142+
"""
143+
Record the supplied transfer statistics to Prometheus metrics. These
144+
statistics are engine-specific, and should be recorded to a metric
145+
with the appropriate 'engine' label. These metric instances can be
146+
created using the make_per_engine() helper method.
147+
"""
148+
raise NotImplementedError
149+
150+
151+
class KVConnectorPrometheus:
152+
"""
153+
Support for registering per-connector Prometheus metrics, and
154+
recording transfer statistics to those metrics. Uses
155+
KVConnectorBase.build_prom_metrics().
156+
"""
157+
158+
_gauge_cls = Gauge
159+
_counter_cls = Counter
160+
_histogram_cls = Histogram
161+
162+
def __init__(
163+
self,
164+
kv_transfer_config: KVTransferConfig | None,
165+
labelnames: list[str],
166+
per_engine_labelvalues: dict[int, list[str]],
167+
):
168+
self.prom_metrics: KVConnectorPromMetrics | None = None
169+
if kv_transfer_config and kv_transfer_config.kv_connector:
170+
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
171+
metric_types = {
172+
Gauge: self._gauge_cls,
173+
Counter: self._counter_cls,
174+
Histogram: self._histogram_cls,
175+
}
176+
self.prom_metrics = connector_cls.build_prom_metrics(
177+
metric_types,
178+
labelnames,
179+
per_engine_labelvalues,
180+
)
181+
182+
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
183+
if self.prom_metrics is None:
184+
return
185+
self.prom_metrics.observe(transfer_stats_data, engine_idx)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
KVConnectorMetadata,
3131
KVConnectorRole,
3232
)
33-
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
33+
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
34+
KVConnectorPromMetrics,
35+
KVConnectorStats,
36+
PromMetric,
37+
PromMetricT,
38+
)
3439
from vllm.distributed.parallel_state import (
3540
get_tensor_model_parallel_rank,
3641
get_tensor_model_parallel_world_size,
@@ -254,6 +259,15 @@ def build_kv_connector_stats(
254259
else NixlKVConnectorStats()
255260
)
256261

262+
@classmethod
263+
def build_prom_metrics(
264+
cls,
265+
metric_types: dict[type[PromMetric], type[PromMetricT]],
266+
labelnames: list[str],
267+
per_engine_labelvalues: dict[int, list[str]],
268+
) -> KVConnectorPromMetrics:
269+
return NixlPromMetrics(metric_types, labelnames, per_engine_labelvalues)
270+
257271
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
258272
assert self.connector_worker is not None
259273
assert isinstance(self._connector_metadata, NixlConnectorMetadata)
@@ -1744,3 +1758,124 @@ def reduce(self) -> dict[str, int | float]:
17441758
@property
17451759
def num_successful_transfers(self) -> int:
17461760
return len(self.data["transfer_duration"])
1761+
1762+
1763+
class NixlPromMetrics(KVConnectorPromMetrics):
1764+
def __init__(
1765+
self,
1766+
metric_types: dict[type[PromMetric], type[PromMetricT]],
1767+
labelnames: list[str],
1768+
per_engine_labelvalues: dict[int, list[str]],
1769+
):
1770+
super().__init__(metric_types, labelnames, per_engine_labelvalues)
1771+
1772+
buckets = [
1773+
0.001,
1774+
0.005,
1775+
0.01,
1776+
0.025,
1777+
0.05,
1778+
0.075,
1779+
0.1,
1780+
0.2,
1781+
0.3,
1782+
0.5,
1783+
0.75,
1784+
1.0,
1785+
5.0,
1786+
]
1787+
nixl_histogram_xfer_time = self._histogram_cls(
1788+
name="vllm:nixl_xfer_time_seconds",
1789+
documentation="Histogram of transfer duration for NIXL KV Cache transfers.",
1790+
buckets=buckets,
1791+
labelnames=labelnames,
1792+
)
1793+
self.nixl_histogram_xfer_time = self.make_per_engine(nixl_histogram_xfer_time)
1794+
nixl_histogram_post_time = self._histogram_cls(
1795+
name="vllm:nixl_post_time_seconds",
1796+
documentation="Histogram of transfer post time for NIXL KV"
1797+
" Cache transfers.",
1798+
buckets=buckets[1:],
1799+
labelnames=labelnames,
1800+
)
1801+
self.nixl_histogram_post_time = self.make_per_engine(nixl_histogram_post_time)
1802+
# uniform 2kb to 16gb range
1803+
buckets = [2**10 + i for i in range(1, 24, 2)]
1804+
nixl_histogram_bytes_transferred = self._histogram_cls(
1805+
name="vllm:nixl_bytes_transferred",
1806+
documentation="Histogram of bytes transferred per NIXL KV Cache transfers.",
1807+
buckets=buckets,
1808+
labelnames=labelnames,
1809+
)
1810+
self.nixl_histogram_bytes_transferred = self.make_per_engine(
1811+
nixl_histogram_bytes_transferred
1812+
)
1813+
buckets = [
1814+
10,
1815+
20,
1816+
30,
1817+
50,
1818+
75,
1819+
100,
1820+
200,
1821+
400,
1822+
1000,
1823+
2000,
1824+
4000,
1825+
10000,
1826+
20000,
1827+
50000,
1828+
]
1829+
nixl_histogram_num_descriptors = self._histogram_cls(
1830+
name="vllm:nixl_num_descriptors",
1831+
documentation="Histogram of number of descriptors per NIXL"
1832+
" KV Cache transfers.",
1833+
buckets=buckets,
1834+
labelnames=labelnames,
1835+
)
1836+
self.nixl_histogram_num_descriptors = self.make_per_engine(
1837+
nixl_histogram_num_descriptors
1838+
)
1839+
counter_nixl_num_failed_transfers = self._counter_cls(
1840+
name="vllm:nixl_num_failed_transfers",
1841+
documentation="Number of failed NIXL KV Cache transfers.",
1842+
labelnames=labelnames,
1843+
)
1844+
self.counter_nixl_num_failed_transfers = self.make_per_engine(
1845+
counter_nixl_num_failed_transfers
1846+
)
1847+
counter_nixl_num_failed_notifications = self._counter_cls(
1848+
name="vllm:nixl_num_failed_notifications",
1849+
documentation="Number of failed NIXL KV Cache notifications.",
1850+
labelnames=labelnames,
1851+
)
1852+
self.counter_nixl_num_failed_notifications = self.make_per_engine(
1853+
counter_nixl_num_failed_notifications
1854+
)
1855+
1856+
def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
1857+
for prom_obj, list_item_key in zip(
1858+
[
1859+
self.nixl_histogram_xfer_time,
1860+
self.nixl_histogram_post_time,
1861+
self.nixl_histogram_bytes_transferred,
1862+
self.nixl_histogram_num_descriptors,
1863+
],
1864+
[
1865+
"transfer_duration",
1866+
"post_duration",
1867+
"bytes_transferred",
1868+
"num_descriptors",
1869+
],
1870+
):
1871+
for list_item in transfer_stats_data[list_item_key]:
1872+
prom_obj[engine_idx].observe(list_item)
1873+
for counter_obj, counter_item_key in zip(
1874+
[
1875+
self.counter_nixl_num_failed_transfers,
1876+
self.counter_nixl_num_failed_notifications,
1877+
],
1878+
["num_failed_transfers", "num_failed_notifications"],
1879+
):
1880+
for list_item in transfer_stats_data[counter_item_key]:
1881+
counter_obj[engine_idx].inc(list_item)

0 commit comments

Comments
 (0)