Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
366 changes: 188 additions & 178 deletions tests/v1/kv_connector/unit/test_lmcache_connector.py

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,12 +759,9 @@ def test_multi_connector_overrides_all_base_methods():
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
"""
# These are fine to inherit from KVConnectorBase_V1
# TODO(https://github.com/vllm-project/vllm/pull/31811): Remove
# get_kv_connector_kv_cache_events from INHERITED_OK once implemented.
INHERITED_OK = {
"role",
"has_connector_metadata",
"get_kv_connector_kv_cache_events",
}

base_members = {
Expand Down
35 changes: 0 additions & 35 deletions vllm/distributed/kv_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,41 +179,6 @@ def __repr__(self) -> str:
)


class KVConnectorKVEvents(ABC):
"""
Abstract base class for KV events.
Acts as a container for KV events from the connector.
"""

@abstractmethod
def add_events(self, events: list[KVCacheEvent]) -> None:
raise NotImplementedError

@abstractmethod
def aggregate(self) -> "KVConnectorKVEvents":
raise NotImplementedError

@abstractmethod
def increment_workers(self, count: int = 1) -> None:
raise NotImplementedError

@abstractmethod
def get_all_events(self) -> list[KVCacheEvent]:
raise NotImplementedError

@abstractmethod
def get_number_of_workers(self) -> int:
raise NotImplementedError

@abstractmethod
def clear_events(self) -> None:
raise NotImplementedError

def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
self.add_events(other.get_all_events())
return self


class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism
support.
Expand Down
15 changes: 0 additions & 15 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def update_finished_set(
finished_recving = set[str]()
aggregated_kv_connector_stats = None
aggregated_kv_connector_worker_meta = None
combined_kv_cache_events = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
Expand Down Expand Up @@ -139,19 +138,6 @@ def update_finished_set(
)
)

# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
combined_kv_cache_events = kv_output.kv_cache_events
elif kv_cache_events := kv_output.kv_cache_events:
assert isinstance(
combined_kv_cache_events,
type(kv_cache_events),
)
worker_kv_cache_events = kv_cache_events.get_all_events()
combined_kv_cache_events.add_events(worker_kv_cache_events)
combined_kv_cache_events.increment_workers(1)

invalid_block_ids |= kv_output.invalid_block_ids

# select output of the worker specified by output_rank
Expand All @@ -162,7 +148,6 @@ def update_finished_set(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
Expand Down
10 changes: 1 addition & 9 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent, KVConnectorKVEvents
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
KVConnectorStats,
Expand Down Expand Up @@ -412,14 +412,6 @@ def get_kv_connector_stats(self) -> "KVConnectorStats | None":
"""
return None

def get_kv_connector_kv_cache_events(self) -> "KVConnectorKVEvents | None":
"""
Get the KV connector kv cache events collected during the last interval.
This function should be called by the model runner every time after the
model execution and before cleanup.
"""
return None

def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
Get the KVConnector handshake metadata for this connector.
Expand Down
86 changes: 30 additions & 56 deletions vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import torch
Expand All @@ -9,13 +10,13 @@
from vllm.distributed.kv_events import (
BlockStored,
KVCacheEvent,
KVConnectorKVEvents,
KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
KVConnectorWorkerMetadata,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
Expand All @@ -31,42 +32,19 @@
logger = init_logger(__name__)


class LMCacheKVEvents(KVConnectorKVEvents):
"""
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
"""
@dataclass
class LMCacheWorkerMetadata(KVConnectorWorkerMetadata):
"""Worker metadata for LMCache connector."""

def __init__(self, num_workers: int) -> None:
self._aggregator = KVEventAggregator(num_workers)
kv_events: list[KVCacheEvent] = field(default_factory=list)
num_workers: int = 1

def add_events(self, events: list[KVCacheEvent]) -> None:
self._aggregator.add_events(events)

def aggregate(self) -> "LMCacheKVEvents":
"""
Aggregate KV events and retain only common events.
"""
common_events = self._aggregator.get_common_events()
self._aggregator.clear_events()
self._aggregator.add_events(common_events)
self._aggregator.reset_workers()
return self

def increment_workers(self, count: int = 1) -> None:
self._aggregator.increment_workers(count)

def get_all_events(self) -> list[KVCacheEvent]:
return self._aggregator.get_all_events()

def get_number_of_workers(self) -> int:
return self._aggregator.get_number_of_workers()

def clear_events(self) -> None:
self._aggregator.clear_events()
self._aggregator.reset_workers()

def __repr__(self) -> str:
return f"<LMCacheKVEvents events={self.get_all_events()}>"
def aggregate(self, other: KVConnectorWorkerMetadata) -> "LMCacheWorkerMetadata":
assert isinstance(other, LMCacheWorkerMetadata)
return LMCacheWorkerMetadata(
kv_events=self.kv_events + other.kv_events,
num_workers=self.num_workers + other.num_workers,
)


class LMCacheConnectorV1(KVConnectorBase_V1):
Expand Down Expand Up @@ -112,7 +90,8 @@ def __init__(

self._lmcache_engine = cls(vllm_config, role, self)

self._kv_cache_events: LMCacheKVEvents | None = None
# Accumulated worker metadata across steps (scheduler-side).
self._accumulated_worker_meta: LMCacheWorkerMetadata | None = None

# ==============================
# Worker-side methods
Expand Down Expand Up @@ -227,11 +206,10 @@ def get_block_ids_with_load_errors(self) -> set[int]:
# Fallback for older versions that don't support this method
return set()

def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
def build_connector_worker_meta(self) -> LMCacheWorkerMetadata | None:
"""
Get the KV connector kv cache events collected during the last interval.
Build worker metadata from this step.
"""

events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
if not events:
return None
Expand All @@ -249,9 +227,7 @@ def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
for e in events
]

lmcache_kv_events = LMCacheKVEvents(num_workers=1)
lmcache_kv_events.add_events(blocks)
return lmcache_kv_events
return LMCacheWorkerMetadata(kv_events=blocks, num_workers=1)

# ==============================
# Scheduler-side methods
Expand Down Expand Up @@ -308,17 +284,15 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
# Get the KV events
kv_cache_events = connector_output.kv_cache_events
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
worker_meta = connector_output.kv_connector_worker_meta
if not worker_meta or not isinstance(worker_meta, LMCacheWorkerMetadata):
return

if self._kv_cache_events is None:
self._kv_cache_events = kv_cache_events
if self._accumulated_worker_meta is None:
self._accumulated_worker_meta = worker_meta
else:
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
self._kv_cache_events.increment_workers(
kv_cache_events.get_number_of_workers()
self._accumulated_worker_meta = self._accumulated_worker_meta.aggregate(
worker_meta
)
return

Expand Down Expand Up @@ -346,9 +320,9 @@ def take_events(self) -> Iterable["KVCacheEvent"]:
Yields:
New KV cache events since the last call.
"""
if self._kv_cache_events is not None:
self._kv_cache_events.aggregate()
kv_cache_events = self._kv_cache_events.get_all_events()
yield from kv_cache_events
self._kv_cache_events.clear_events()
self._kv_cache_events = None
if self._accumulated_worker_meta is not None:
# Consensus aggregation: only keep events reported by all workers.
aggregator = KVEventAggregator(self._accumulated_worker_meta.num_workers)
aggregator.add_events(self._accumulated_worker_meta.kv_events)
yield from aggregator.get_common_events()
self._accumulated_worker_meta = None
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,6 @@ def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
return None
return MultiKVConnectorWorkerMetadata(metadata=tuple(metadata_list))

# TODO: Add a generic implementation of 'get_kv_connector_kv_cache_events'
# method for the MultiConnector. It should be able to get events from
# multiple connectors, handling the case where only a subset of the
# requested connectors implements the 'get_kv_connector_kv_cache_events'
# WIP: https://github.com/vllm-project/vllm/pull/31811

# ==============================
# Scheduler-side methods
# ==============================
Expand Down
4 changes: 0 additions & 4 deletions vllm/v1/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
from vllm.v1.core.sched.output import SchedulerOutput

if TYPE_CHECKING:
from vllm.distributed.kv_events import KVConnectorKVEvents
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
else:
KVConnectorStats = object
KVConnectorWorkerMetadata = object
KVConnectorKVEvents = object


class LogprobsLists(NamedTuple):
Expand Down Expand Up @@ -145,7 +143,6 @@
finished_sending: set[str] | None = None
finished_recving: set[str] | None = None
kv_connector_stats: KVConnectorStats | None = None
kv_cache_events: KVConnectorKVEvents | None = None
kv_connector_worker_meta: KVConnectorWorkerMetadata | None = None
# IDs of externally computed KV blocks that failed to load.
# Requests referencing these blocks should be rescheduled to recompute them
Expand All @@ -162,7 +159,6 @@
not self.finished_sending
and not self.finished_recving
and not self.kv_connector_stats
and not self.kv_cache_events
and not self.invalid_block_ids
and not self.kv_connector_worker_meta
)
Expand All @@ -182,7 +178,7 @@
)
kv_cache_events = _combine_non_none(
lambda x, y: x.merge(y),
[output.kv_cache_events for output in outputs],

Check failure on line 181 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

"KVConnectorOutput" has no attribute "kv_cache_events" [attr-defined]

Check failure on line 181 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

"KVConnectorOutput" has no attribute "kv_cache_events" [attr-defined]

Check failure on line 181 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

"KVConnectorOutput" has no attribute "kv_cache_events" [attr-defined]

Check failure on line 181 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

"KVConnectorOutput" has no attribute "kv_cache_events" [attr-defined]
)
invalid_block_ids = _combine_non_none(
set.union, [output.invalid_block_ids for output in outputs]
Expand All @@ -195,7 +191,7 @@
)
expected_finished_count = outputs[0].expected_finished_count

return cls(

Check failure on line 194 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unexpected keyword argument "kv_cache_events" for "KVConnectorOutput" [call-arg]

Check failure on line 194 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unexpected keyword argument "kv_cache_events" for "KVConnectorOutput" [call-arg]

Check failure on line 194 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unexpected keyword argument "kv_cache_events" for "KVConnectorOutput" [call-arg]

Check failure on line 194 in vllm/v1/outputs.py

View workflow job for this annotation

GitHub Actions / pre-commit

Unexpected keyword argument "kv_cache_events" for "KVConnectorOutput" [call-arg]
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_stats=kv_connector_stats,
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/worker/gpu/kv_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def post_forward(
)
output.invalid_block_ids = self.kv_connector.get_block_ids_with_load_errors()
output.kv_connector_stats = self.kv_connector.get_kv_connector_stats()
output.kv_cache_events = self.kv_connector.get_kv_connector_kv_cache_events()
if clear_metadata:
self.kv_connector.clear_connector_metadata()
return output
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/worker/kv_connector_model_runner_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def _get_kv_connector_output(
output.invalid_block_ids = kv_connector.get_block_ids_with_load_errors()

output.kv_connector_stats = kv_connector.get_kv_connector_stats()
output.kv_cache_events = kv_connector.get_kv_connector_kv_cache_events()
output.kv_connector_worker_meta = kv_connector.build_connector_worker_meta()

if not defer_finalize:
Expand Down
Loading