112 changes: 111 additions & 1 deletion tests/v1/kv_connector/unit/test_lmcache_connector.py
@@ -1,10 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
import torch

from vllm.distributed.kv_events import BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
WorkerConnectorInitializationData,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
LMCacheConnectorV1,
LMCacheKVEvents,
@@ -784,3 +788,109 @@ def test_lmcache_kv_events_aggregation(self):
assert aggregated_events[0].block_hashes == ["hash_common"]
assert aggregated_events[0].parent_block_hash == "parent_common"
assert aggregated_events[0].token_ids == [1, 2, 3]


class TestInitializeWorkerConnector:
"""Test initialize_worker_connector for LMCache connectors."""

def test_delegates_model_to_engine(self):
"""LMCacheConnectorV1 forwards model to engine.register_model."""
mock_engine = MagicMock()
mock_model = MagicMock(spec=torch.nn.Module)

connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
connector._lmcache_engine = mock_engine
connector.initialize_worker_connector(
WorkerConnectorInitializationData(model=mock_model)
)
mock_engine.register_model.assert_called_once_with(mock_model)

def test_skips_engine_without_register_model(self):
"""No-op when engine lacks register_model."""
mock_engine = MagicMock(spec=[])
connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
connector._lmcache_engine = mock_engine
connector.initialize_worker_connector(
WorkerConnectorInitializationData(
model=MagicMock(spec=torch.nn.Module),
)
)

def test_skips_when_model_none(self):
"""Does not call register_model when model is None."""
mock_engine = MagicMock()
connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
connector._lmcache_engine = mock_engine
connector.initialize_worker_connector(WorkerConnectorInitializationData())
mock_engine.register_model.assert_not_called()

def test_impl_registers_with_vllm_model_tracker(self):
"""Impl registers model with VLLMModelTracker when available."""
mock_tracker = MagicMock()
mock_model = MagicMock(spec=torch.nn.Module)

with patch.dict(
"sys.modules",
{
"lmcache": MagicMock(),
"lmcache.v1": MagicMock(),
"lmcache.v1.compute": MagicMock(),
"lmcache.v1.compute.models": MagicMock(),
"lmcache.v1.compute.models.utils": MagicMock(
VLLMModelTracker=mock_tracker
),
},
):
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import ( # noqa: E501
LMCacheConnectorV1Impl,
)

impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl)
impl.initialize_worker_connector(
WorkerConnectorInitializationData(model=mock_model)
)

mock_tracker.register_model.assert_called_once()
assert mock_tracker.register_model.call_args[0][1] is mock_model

def test_impl_graceful_on_import_error(self):
"""Doesn't crash if LMCache CacheBlend not installed."""
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import ( # noqa: E501
LMCacheConnectorV1Impl,
)

impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl)
with patch.dict(
"sys.modules",
{"lmcache.v1.compute.models.utils": None},
):
impl.initialize_worker_connector(
WorkerConnectorInitializationData(
model=MagicMock(spec=torch.nn.Module),
)
)

def test_impl_skips_when_model_none(self):
"""Does not call VLLMModelTracker when model is None."""
mock_tracker = MagicMock()

with patch.dict(
"sys.modules",
{
"lmcache": MagicMock(),
"lmcache.v1": MagicMock(),
"lmcache.v1.compute": MagicMock(),
"lmcache.v1.compute.models": MagicMock(),
"lmcache.v1.compute.models.utils": MagicMock(
VLLMModelTracker=mock_tracker
),
},
):
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import ( # noqa: E501
LMCacheConnectorV1Impl,
)

impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl)
impl.initialize_worker_connector(WorkerConnectorInitializationData())

mock_tracker.register_model.assert_not_called()
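
For third-party connectors, the pattern these tests exercise reduces to overriding the new hook. A minimal sketch, assuming a hypothetical connector class and engine attribute (MyConnector and self._engine are illustrative, not part of this change):

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    WorkerConnectorInitializationData,
)


class MyConnector(KVConnectorBase_V1):
    # ... abstract scheduler/worker methods omitted for brevity ...

    def initialize_worker_connector(
        self,
        initialization_data: WorkerConnectorInitializationData,
    ) -> None:
        # The model is optional; guard against None as the tests above do.
        if initialization_data.model is not None:
            # self._engine is a stand-in for whatever backend the connector
            # wraps; register_model here is a hypothetical engine API.
            self._engine.register_model(initialization_data.model)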
30 changes: 29 additions & 1 deletion tests/v1/kv_connector/unit/test_multi_connector.py
@@ -8,13 +8,17 @@
from unittest.mock import MagicMock

import pytest
import torch

from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
WorkerConnectorInitializationData,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
@@ -920,3 +924,27 @@ def assert_update_connector_output_called(mc: MultiConnector):
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b


def test_initialize_worker_connector_delegates_to_sub_connectors():
"""MultiConnector.initialize_worker_connector fans out to all
sub-connectors and the base class default is a no-op."""
# Verify base class no-op
base = KVConnectorBase_V1.__new__(KVConnectorBase_V1)
data = WorkerConnectorInitializationData(
model=MagicMock(spec=torch.nn.Module),
)
assert base.initialize_worker_connector(data) is None

# Verify dataclass defaults
empty_data = WorkerConnectorInitializationData()
assert empty_data.model is None

# Verify MultiConnector delegates to each sub-connector
sub1 = MagicMock()
sub2 = MagicMock()
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [sub1, sub2]
mc.initialize_worker_connector(data)
sub1.initialize_worker_connector.assert_called_once_with(data)
sub2.initialize_worker_connector.assert_called_once_with(data)
33 changes: 33 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -43,6 +43,7 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import torch
@@ -81,6 +82,18 @@
logger = init_logger(__name__)


@dataclass
class WorkerConnectorInitializationData:
"""Data passed to initialize_worker_connector().

Designed to be extended without breaking existing connectors: new optional
fields can be added here and connectors that don't need them simply ignore
the extra data.
"""

model: torch.nn.Module | None = field(default=None)

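# Illustrative extension sketch (hypothetical field, not part of this change):
# a later revision could add, say,
#
#     kv_cache_config: "KVCacheConfig | None" = field(default=None)
#
# Because every field carries a default, connectors written before the field
# existed keep constructing and consuming the dataclass unchanged; they simply
# never read it.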

class SupportsHMA(ABC):
"""
The class that indicates the corresponding connector supports hybrid memory
@@ -264,6 +277,26 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
return

def initialize_worker_connector(
self,
initialization_data: WorkerConnectorInitializationData,
) -> None:
"""
Initialize per-worker connector state after model loading.

Called once by the GPU model runner after the model and KV caches
are ready. The default implementation is a no-op; connectors that
need access to model weights (e.g. LMCache's CacheBlend selective
recomputation) should override this method.

Args:
initialization_data: data bag containing optional fields such
as the loaded model (``initialization_data.model``).
New fields may be added in future versions without breaking
existing connectors.
"""
return

def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"]
):
@@ -16,6 +16,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
WorkerConnectorInitializationData,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
@@ -133,6 +134,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"please check and use the latest version"
)

def initialize_worker_connector(
self,
initialization_data: WorkerConnectorInitializationData,
) -> None:
"""Pass initialization data to the underlying LMCache engine.

Extracts ``model`` from ``initialization_data`` and forwards it to
LMCache's ``register_model`` when available (used by CacheBlend for
selective layer recomputation).
"""
if initialization_data.model is not None and hasattr(
self._lmcache_engine, "register_model"
):
self._lmcache_engine.register_model(initialization_data.model)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
@@ -41,6 +41,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
WorkerConnectorInitializationData,
)
from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import (
ENGINE_NAME,
@@ -794,6 +795,35 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
kvcaches = list(self.kv_caches.values())
self.lmcache_engine.post_init(kvcaches=kvcaches)

def initialize_worker_connector(
self,
initialization_data: WorkerConnectorInitializationData,
) -> None:
"""Register model with LMCache's VLLMModelTracker for CacheBlend.

CacheBlend's blender needs access to model weights for selective
layer recomputation. Called automatically by vLLM after model
loading.
"""
model = initialization_data.model
if model is not None:
try:
from lmcache.v1.compute.models.utils import VLLMModelTracker

from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( # noqa: E501
ENGINE_NAME,
)

VLLMModelTracker.register_model(ENGINE_NAME, model)
logger.info("Registered model with VLLMModelTracker")
except ImportError:
logger.debug("LMCache CacheBlend model registration not available")
except Exception:
logger.warning(
"Failed to register model with VLLMModelTracker",
exc_info=True,
)

@_lmcache_nvtx_annotate
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
"""Start loading the KV cache from the connector buffer to vLLM's
@@ -16,6 +16,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
WorkerConnectorInitializationData,
)
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@@ -513,6 +514,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.worker_adapter.register_kv_caches(kv_caches)
return

def initialize_worker_connector(
self,
initialization_data: WorkerConnectorInitializationData,
) -> None:
"""Delegate initialization data to the worker adapter."""
if hasattr(self.worker_adapter, "initialize_worker_connector"):
self.worker_adapter.initialize_worker_connector(initialization_data)

def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""
Start loading the KV cache from the connector to vLLM's paged
@@ -18,6 +18,7 @@
KVConnectorMetadata,
KVConnectorRole,
KVConnectorWorkerMetadata,
WorkerConnectorInitializationData,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
KVConnectorPromMetrics,
@@ -219,6 +220,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)

def initialize_worker_connector(
self,
initialization_data: WorkerConnectorInitializationData,
) -> None:
for c in self._connectors:
c.initialize_worker_connector(initialization_data)

# We must override the base class method here because we need to bind
# the metadata to each connector in the order of the connectors in the
# MultiKVConnectorMetadata.
6 changes: 6 additions & 0 deletions vllm/v1/worker/gpu_model_runner.py
@@ -37,6 +37,9 @@
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
WorkerConnectorInitializationData,
)
from vllm.distributed.parallel_state import (
get_dcp_group,
get_pp_group,
@@ -6803,6 +6806,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
else:
kv_transfer_group.register_kv_caches(kv_caches)
kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks)
kv_transfer_group.initialize_worker_connector(
WorkerConnectorInitializationData(model=self.model)
)

def _get_attention_kv_cache_gid(self) -> int:
"""Find the KV cache group index for attention layers."""
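
Taken together, the pieces wire the hook end to end: the GPU model runner builds a WorkerConnectorInitializationData carrying the loaded model, MultiConnector fans it out, and the LMCache connectors forward the model to their engine. A minimal runnable sketch of the fan-out step, using mocks in the style of the tests above (the mock sub-connector stands in for e.g. LMCacheConnectorV1):

from unittest.mock import MagicMock

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    WorkerConnectorInitializationData,
)
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
    MultiConnector,
)

sub = MagicMock()
mc = MultiConnector.__new__(MultiConnector)  # skip __init__, as in the tests
mc._connectors = [sub]

data = WorkerConnectorInitializationData(model=MagicMock(spec=torch.nn.Module))
mc.initialize_worker_connector(data)  # fans out to every sub-connector
sub.initialize_worker_connector.assert_called_once_with(data)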