diff --git a/tests/v1/kv_connector/unit/test_lmcache_connector.py b/tests/v1/kv_connector/unit/test_lmcache_connector.py
index c3df2b68b1ff..3e722d7901aa 100644
--- a/tests/v1/kv_connector/unit/test_lmcache_connector.py
+++ b/tests/v1/kv_connector/unit/test_lmcache_connector.py
@@ -1,10 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
+import torch
 
 from vllm.distributed.kv_events import BlockStored
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    WorkerConnectorInitializationData,
+)
 from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector import (
     LMCacheConnectorV1,
     LMCacheKVEvents,
@@ -784,3 +788,109 @@ def test_lmcache_kv_events_aggregation(self):
         assert aggregated_events[0].block_hashes == ["hash_common"]
         assert aggregated_events[0].parent_block_hash == "parent_common"
         assert aggregated_events[0].token_ids == [1, 2, 3]
+
+
+class TestInitializeWorkerConnector:
+    """Test initialize_worker_connector for LMCache connectors."""
+
+    def test_delegates_model_to_engine(self):
+        """LMCacheConnectorV1 forwards model to engine.register_model."""
+        mock_engine = MagicMock()
+        mock_model = MagicMock(spec=torch.nn.Module)
+
+        connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
+        connector._lmcache_engine = mock_engine
+        connector.initialize_worker_connector(
+            WorkerConnectorInitializationData(model=mock_model)
+        )
+        mock_engine.register_model.assert_called_once_with(mock_model)
+
+    def test_skips_engine_without_register_model(self):
+        """No-op when engine lacks register_model."""
+        mock_engine = MagicMock(spec=[])
+        connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
+        connector._lmcache_engine = mock_engine
+        connector.initialize_worker_connector(
+            WorkerConnectorInitializationData(
+                model=MagicMock(spec=torch.nn.Module),
+            )
+        )
+
+    def test_skips_when_model_none(self):
+        """Does not call register_model when model is None."""
+        mock_engine = MagicMock()
+        connector = LMCacheConnectorV1.__new__(LMCacheConnectorV1)
+        connector._lmcache_engine = mock_engine
+        connector.initialize_worker_connector(WorkerConnectorInitializationData())
+        mock_engine.register_model.assert_not_called()
+
+    def test_impl_registers_with_vllm_model_tracker(self):
+        """Impl registers model with VLLMModelTracker when available."""
+        mock_tracker = MagicMock()
+        mock_model = MagicMock(spec=torch.nn.Module)
+
+        with patch.dict(
+            "sys.modules",
+            {
+                "lmcache": MagicMock(),
+                "lmcache.v1": MagicMock(),
+                "lmcache.v1.compute": MagicMock(),
+                "lmcache.v1.compute.models": MagicMock(),
+                "lmcache.v1.compute.models.utils": MagicMock(
+                    VLLMModelTracker=mock_tracker
+                ),
+            },
+        ):
+            from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import (  # noqa: E501
+                LMCacheConnectorV1Impl,
+            )
+
+            impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl)
+            impl.initialize_worker_connector(
+                WorkerConnectorInitializationData(model=mock_model)
+            )
+
+        mock_tracker.register_model.assert_called_once()
+        assert mock_tracker.register_model.call_args[0][1] is mock_model
+
+    def test_impl_graceful_on_import_error(self):
+        """Doesn't crash if LMCache CacheBlend not installed."""
+        from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import (  # noqa: E501
+            LMCacheConnectorV1Impl,
+        )
+
+        impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl)
+        with patch.dict(
"sys.modules", + {"lmcache.v1.compute.models.utils": None}, + ): + impl.initialize_worker_connector( + WorkerConnectorInitializationData( + model=MagicMock(spec=torch.nn.Module), + ) + ) + + def test_impl_skips_when_model_none(self): + """Does not call VLLMModelTracker when model is None.""" + mock_tracker = MagicMock() + + with patch.dict( + "sys.modules", + { + "lmcache": MagicMock(), + "lmcache.v1": MagicMock(), + "lmcache.v1.compute": MagicMock(), + "lmcache.v1.compute.models": MagicMock(), + "lmcache.v1.compute.models.utils": MagicMock( + VLLMModelTracker=mock_tracker + ), + }, + ): + from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.vllm_v1_adapter import ( # noqa: E501 + LMCacheConnectorV1Impl, + ) + + impl = LMCacheConnectorV1Impl.__new__(LMCacheConnectorV1Impl) + impl.initialize_worker_connector(WorkerConnectorInitializationData()) + + mock_tracker.register_model.assert_not_called() diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 671a80137b63..1ec9b0e49ceb 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -8,13 +8,17 @@ from unittest.mock import MagicMock import pytest +import torch from tests.v1.kv_connector.unit.utils import create_vllm_config from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + WorkerConnectorInitializationData, +) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiConnector, @@ -920,3 +924,27 @@ def assert_update_connector_output_called(mc: MultiConnector): mc.update_connector_output(kv_connector_output) assert_update_connector_output_called(mc) assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b + + +def test_initialize_worker_connector_delegates_to_sub_connectors(): + """MultiConnector.initialize_worker_connector fans out to all + sub-connectors and the base class default is a no-op.""" + # Verify base class no-op + base = KVConnectorBase_V1.__new__(KVConnectorBase_V1) + data = WorkerConnectorInitializationData( + model=MagicMock(spec=torch.nn.Module), + ) + assert base.initialize_worker_connector(data) is None + + # Verify dataclass defaults + empty_data = WorkerConnectorInitializationData() + assert empty_data.model is None + + # Verify MultiConnector delegates to each sub-connector + sub1 = MagicMock() + sub2 = MagicMock() + mc = MultiConnector.__new__(MultiConnector) + mc._connectors = [sub1, sub2] + mc.initialize_worker_connector(data) + sub1.initialize_worker_connector.assert_called_once_with(data) + sub2.initialize_worker_connector.assert_called_once_with(data) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index ef143cba7fb5..3163ec00557a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -43,6 +43,7 @@ import enum from abc import ABC, abstractmethod from collections.abc import Callable, Iterable +from dataclasses import dataclass, field from typing import 
TYPE_CHECKING, Any, Literal import torch @@ -81,6 +82,18 @@ logger = init_logger(__name__) +@dataclass +class WorkerConnectorInitializationData: + """Data passed to initialize_worker_connector(). + + Designed to be extended without breaking existing connectors: new optional + fields can be added here and connectors that don't need them simply ignore + the extra data. + """ + + model: torch.nn.Module | None = field(default=None) + + class SupportsHMA(ABC): """ The class that indicates the corresponding connector supports hybrid memory @@ -264,6 +277,26 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """ return + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + """ + Initialize per-worker connector state after model loading. + + Called once by the GPU model runner after the model and KV caches + are ready. The default implementation is a no-op; connectors that + need access to model weights (e.g. LMCache's CacheBlend selective + recomputation) should override this method. + + Args: + initialization_data: data bag containing optional fields such + as the loaded model (``initialization_data.model``). + New fields may be added in future versions without breaking + existing connectors. + """ + return + def register_cross_layers_kv_cache( self, kv_cache: torch.Tensor, attn_backend: type["AttentionBackend"] ): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 64aee2bd9c49..9e1d0600fa4c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -16,6 +16,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + WorkerConnectorInitializationData, ) from vllm.logger import init_logger from vllm.v1.attention.backend import AttentionMetadata @@ -133,6 +134,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): "please check and use the latest version" ) + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + """Pass initialization data to the underlying LMCache engine. + + Extracts ``model`` from *initialization_data* and forwards it to + LMCache's ``register_model`` when available (used by CacheBlend for + selective layer recomputation). 
+ """ + if initialization_data.model is not None and hasattr( + self._lmcache_engine, "register_model" + ): + self._lmcache_engine.register_model(initialization_data.model) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 35cd70606915..a8cc920bf74e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -41,6 +41,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + WorkerConnectorInitializationData, ) from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( ENGINE_NAME, @@ -794,6 +795,35 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): kvcaches = list(self.kv_caches.values()) self.lmcache_engine.post_init(kvcaches=kvcaches) + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + """Register model with LMCache's VLLMModelTracker for CacheBlend. + + CacheBlend's blender needs access to model weights for selective + layer recomputation. Called automatically by vLLM after model + loading. + """ + model = initialization_data.model + if model is not None: + try: + from lmcache.v1.compute.models.utils import VLLMModelTracker + + from vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils import ( # noqa: E501 + ENGINE_NAME, + ) + + VLLMModelTracker.register_model(ENGINE_NAME, model) + logger.info("Registered model with VLLMModelTracker") + except ImportError: + logger.debug("LMCache CacheBlend model registration not available") + except Exception: + logger.warning( + "Failed to register model with VLLMModelTracker", + exc_info=True, + ) + @_lmcache_nvtx_annotate def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: """Start loading the KV cache from the connector buffer to vLLM's diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index 5f14c733a8b0..57abdb66cab7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -16,6 +16,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + WorkerConnectorInitializationData, ) from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -513,6 +514,14 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): self.worker_adapter.register_kv_caches(kv_caches) return + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + """Delegate initialization data to the worker adapter.""" + if hasattr(self.worker_adapter, "initialize_worker_connector"): + self.worker_adapter.initialize_worker_connector(initialization_data) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 3888d2e0f44c..ce7a459fbd14 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -18,6 +18,7 @@ KVConnectorMetadata, KVConnectorRole, KVConnectorWorkerMetadata, + WorkerConnectorInitializationData, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -219,6 +220,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): for c in self._connectors: c.register_kv_caches(kv_caches) + def initialize_worker_connector( + self, + initialization_data: WorkerConnectorInitializationData, + ) -> None: + for c in self._connectors: + c.initialize_worker_connector(initialization_data) + # We must override the base class method here because we need to bind # the metadata to each connector in the order of the connectors in the # MultiKVConnectorMetadata. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8a43f43d0398..c5f50081ea74 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -37,6 +37,9 @@ from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + WorkerConnectorInitializationData, +) from vllm.distributed.parallel_state import ( get_dcp_group, get_pp_group, @@ -6803,6 +6806,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: else: kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) + kv_transfer_group.initialize_worker_connector( + WorkerConnectorInitializationData(model=self.model) + ) def _get_attention_kv_cache_gid(self) -> int: """Find the KV cache group index for attention layers."""
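
Usage note: the diff wires the new hook end to end. base.py defines
WorkerConnectorInitializationData and the no-op default, gpu_model_runner.py
calls the hook once after KV caches are registered, and each connector
overrides it as needed. A third-party connector can adopt the hook with a
single override. The sketch below is illustrative only: the class name
MyConnector and the _model attribute are hypothetical, while
KVConnectorBase_V1, WorkerConnectorInitializationData, and the method
signature come from base.py above.

import torch

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    WorkerConnectorInitializationData,
)


class MyConnector(KVConnectorBase_V1):
    # Sketch only: the other abstract KVConnectorBase_V1 methods that a
    # real connector must implement are omitted here.

    def initialize_worker_connector(
        self,
        initialization_data: WorkerConnectorInitializationData,
    ) -> None:
        # The dataclass defaults model to None, so guard before use; any
        # fields added to the dataclass later are simply ignored here,
        # which is the forward-compatibility contract stated in base.py.
        if initialization_data.model is not None:
            self._model: torch.nn.Module = initialization_data.model

Because the base-class default is a no-op and every dataclass field is
optional, connectors that never override the hook keep working unchanged,
which is what the new test in test_multi_connector.py asserts.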