Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
Expand Down Expand Up @@ -471,3 +476,18 @@ def build_kv_connector_stats(
which can implement custom aggregation logic on the data dict.
"""
return None

@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> Optional["KVConnectorPromMetrics"]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return None
89 changes: 87 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any
from typing import Any, TypeAlias, TypeVar

from vllm.config.kv_transfer import KVTransferConfig
from prometheus_client import Counter, Gauge, Histogram

from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_transfer_state import has_kv_transfer_group
from vllm.logger import init_logger

PromMetric: TypeAlias = Gauge | Counter | Histogram
PromMetricT = TypeVar("PromMetricT", bound=PromMetric)

logger = init_logger(__name__)


Expand Down Expand Up @@ -102,3 +107,83 @@ def log(self, log_fn=logger.info):

# Reset metrics for next interval
self.reset()


class KVConnectorPromMetrics:
"""
A base class for per-connector Prometheus metric registration
and recording.
"""

def __init__(
self,
vllm_config: VllmConfig,
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self._kv_transfer_config = vllm_config.kv_transfer_config
self._gauge_cls = metric_types[Gauge]
self._counter_cls = metric_types[Counter]
self._histogram_cls = metric_types[Histogram]
self._labelnames = labelnames
self._per_engine_labelvalues = per_engine_labelvalues

def make_per_engine(self, metric: PromMetric) -> PromMetric:
"""
Create a per-engine child of a prometheus_client.Metric with
the appropriate labels set. The parent metric must be created
using the labelnames list.
"""
return {
idx: metric.labels(*labelvalues)
for idx, labelvalues in self._per_engine_labelvalues.items()
}

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
"""
Record the supplied transfer statistics to Prometheus metrics. These
statistics are engine-specific, and should be recorded to a metric
with the appropriate 'engine' label. These metric instances can be
created using the make_per_engine() helper method.
"""
raise NotImplementedError


class KVConnectorPrometheus:
"""
Support for registering per-connector Prometheus metrics, and
recording transfer statistics to those metrics. Uses
KVConnectorBase.build_prom_metrics().
"""

_gauge_cls = Gauge
_counter_cls = Counter
_histogram_cls = Histogram

def __init__(
self,
vllm_config: VllmConfig,
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
):
self.prom_metrics: KVConnectorPromMetrics | None = None
kv_transfer_config = vllm_config.kv_transfer_config
if kv_transfer_config and kv_transfer_config.kv_connector:
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
metric_types = {
Gauge: self._gauge_cls,
Counter: self._counter_cls,
Histogram: self._histogram_cls,
}
self.prom_metrics = connector_cls.build_prom_metrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
)

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
if self.prom_metrics is None:
return
self.prom_metrics.observe(transfer_stats_data, engine_idx)
110 changes: 88 additions & 22 deletions vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@

from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
PromMetric,
PromMetricT,
)
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput
Expand Down Expand Up @@ -72,6 +78,27 @@ def __setitem__(self, connector_id: str, stats: KVConnectorStats):
self.data[connector_id] = stats


class MultiKVConnectorPromMetrics(KVConnectorPromMetrics):
def __init__(
self,
vllm_config: "VllmConfig",
metric_types: dict[type[PromMetric], type[PromMetricT]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
prom_metrics: dict[str, KVConnectorPromMetrics],
):
super().__init__(vllm_config, metric_types, labelnames, per_engine_labelvalues)
self._prom_metrics = prom_metrics

def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0):
for connector_id, stats_data in transfer_stats_data.items():
assert connector_id in self._prom_metrics, (
f"{connector_id} is not contained in the list of registered connectors "
f"with Prometheus metrics support: {self._prom_metrics.keys()}"
)
self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx)


class MultiConnector(KVConnectorBase_V1):
"""
A wrapper for using multiple KVConnectors at the same time.
Expand All @@ -84,19 +111,13 @@ class MultiConnector(KVConnectorBase_V1):

def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)

self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
ktcs = self._kv_transfer_config.kv_connector_extra_config.get("connectors")
assert ktcs is not None
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", self._kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
self._connectors.append(
KVConnectorFactory.create_connector(temp_config, role)
)
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)

# A mapping from request id to the index of the connector chosen to
Expand All @@ -109,6 +130,32 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}

@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
) -> list[tuple[type[KVConnectorBaseType], "VllmConfig"]]:
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
ret: list[tuple[type[KVConnectorBaseType], VllmConfig]] = []
for ktc in ktcs:
temp_config = copy.copy(vllm_config)
engine_id = ktc.get("engine_id", vllm_config.kv_transfer_config.engine_id)
temp_config.kv_transfer_config = KVTransferConfig(
**ktc, engine_id=engine_id
)
ret.append(
(
KVConnectorFactory.get_connector_class(
temp_config.kv_transfer_config
),
temp_config,
)
)
return ret

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
Expand Down Expand Up @@ -295,18 +342,12 @@ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
None if the connector does not require a specific layout.
"""
assert vllm_config.kv_transfer_config is not None
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"connectors"
)
assert ktcs is not None
layouts: set[str] = set()
temp_vllm_config = copy.copy(vllm_config)
for ktc in ktcs:
kv_transfer_config = KVTransferConfig(**ktc)
temp_vllm_config.kv_transfer_config = kv_transfer_config
connector_cls = KVConnectorFactory.get_connector_class(kv_transfer_config)
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
temp_vllm_config
temp_config
)
if required_kvcache_layout is not None:
layouts.add(required_kvcache_layout)
Expand Down Expand Up @@ -372,3 +413,28 @@ def get_kv_connector_stats(self) -> MultiKVConnectorStats | None:
stats_by_connector = MultiKVConnectorStats()
stats_by_connector[c.__class__.__name__] = stats
return stats_by_connector

@classmethod
def build_prom_metrics(
cls,
vllm_config: "VllmConfig",
metric_types: dict[type["PromMetric"], type["PromMetricT"]],
labelnames: list[str],
per_engine_labelvalues: dict[int, list[str]],
) -> KVConnectorPromMetrics:
prom_metrics: dict[str, KVConnectorPromMetrics] = {}
for connector_cls, temp_config in cls._get_connector_classes_and_configs(
vllm_config
):
connector_prom = connector_cls.build_prom_metrics(
temp_config, metric_types, labelnames, per_engine_labelvalues
)
if connector_prom is not None:
prom_metrics[connector_cls.__name__] = connector_prom
return MultiKVConnectorPromMetrics(
vllm_config,
metric_types,
labelnames,
per_engine_labelvalues,
prom_metrics,
)
Loading