Make prefer_cross_layer_blocks a property and aggregate it in MultiConnector.

Summary of the hunks below:
  * base.py: the class-level `prefer_cross_layer_blocks: ClassVar[bool] = False`
    becomes a property that returns False; the unused ClassVar import is
    dropped.
  * offloading_connector.py: the override becomes a property that returns
    True; the unused ClassVar import is dropped.
  * multi_connector.py: adds a `prefer_cross_layer_blocks` property that is
    False when there are no sub-connectors and otherwise True only if every
    sub-connector reports True; adds `register_cross_layers_kv_cache`, which
    forwards the call to every sub-connector; imports AttentionBackend for
    the new annotation.
  * test_multi_connector.py: gives MockConnector no-op implementations of
    the connector hooks, adds MockCrossLayerConnector (property returns
    True), and adds tests for the all-True and mixed cases.

NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Line counts were checked against every hunk
header; the indentation of context lines (in particular the closing `)`
in the `-164,6` hunk of multi_connector.py) is inferred and should be
confirmed against the target tree before applying.
NOTE(review): with the attribute now a property, reading it from the class
instead of an instance yields a property object, which is truthy. Confirm
that no caller reads `SomeConnector.prefer_cross_layer_blocks` directly.

diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py
index 9b6d52e7c294..783678e9cefd 100644
--- a/tests/v1/kv_connector/unit/test_multi_connector.py
+++ b/tests/v1/kv_connector/unit/test_multi_connector.py
@@ -49,6 +49,33 @@ def build_kv_connector_stats(
     ) -> KVConnectorStats | None:
         return MockConnectorStats(data=data) if data is not None else None
 
+    def start_load_kv(self, forward_context, **kwargs):
+        pass
+
+    def wait_for_layer_load(self, layer_name):
+        pass
+
+    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
+        pass
+
+    def wait_for_save(self):
+        pass
+
+    def build_connector_meta(self, scheduler_output):
+        return None
+
+    def get_num_new_matched_tokens(self, request, num_computed_tokens):
+        return (0, False)
+
+    def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
+        pass
+
+
+class MockCrossLayerConnector(MockConnector):
+    @property
+    def prefer_cross_layer_blocks(self) -> bool:
+        return True
+
 
 # Register the mock connector
 KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
@@ -601,3 +628,21 @@ def test_is_empty_with_multiple_connectors(self):
         # One non-empty
         stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
         assert not stats.is_empty()
+
+
+class TestMultiConnectorPreferCrossLayerBlocks:
+    def test_all_connectors_prefer_cross_layer_blocks(self):
+        mc = MultiConnector.__new__(MultiConnector)
+        mc._connectors = [
+            MockCrossLayerConnector.__new__(MockCrossLayerConnector),
+            MockCrossLayerConnector.__new__(MockCrossLayerConnector),
+        ]
+        assert mc.prefer_cross_layer_blocks is True
+
+    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
+        mc = MultiConnector.__new__(MultiConnector)
+        mc._connectors = [
+            MockCrossLayerConnector.__new__(MockCrossLayerConnector),
+            MockConnector.__new__(MockConnector),  # default False
+        ]
+        assert mc.prefer_cross_layer_blocks is False
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index c05e5485a835..0829336f0d50 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -38,7 +38,7 @@
 import enum
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Iterable
-from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import torch
 
@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC):  # noqa: B024
 class KVConnectorBase_V1(ABC):
     """
     Base class for KV connectors.
-
-    Attributes:
-        prefer_cross_layer_blocks (bool): Indicates whether this connector
-            prefers KV blocks that hold KV data for all layers (for speeding
-            up KV data transfers).
-            Defaults to False.
     """
 
-    prefer_cross_layer_blocks: ClassVar[bool] = False
+    @property
+    def prefer_cross_layer_blocks(self) -> bool:
+        """
+        Indicates whether this connector prefers KV blocks that hold KV data for all
+        layers, which can speed up KV data transfers. Defaults to False.
+        """
+        return False
 
     def __init__(
         self,
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 682574537495..3fa1cdc1e100 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -7,7 +7,7 @@
 
 import torch
 
-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.config.kv_transfer import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -138,6 +138,12 @@ def __init__(
         # Propagated from scheduler to worker side via the connector metadata.
         self._extra_async_saves: dict[str, int] = {}
 
+    @property
+    def prefer_cross_layer_blocks(self) -> bool:
+        if not self._connectors:
+            return False
+        return all(c.prefer_cross_layer_blocks for c in self._connectors)
+
     @classmethod
     def _get_connector_classes_and_configs(
         cls, vllm_config: "VllmConfig"
@@ -164,6 +170,13 @@ def _get_connector_classes_and_configs(
             )
         return ret
 
+    def register_cross_layers_kv_cache(
+        self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
+    ):
+        # Register on all connectors
+        for c in self._connectors:
+            c.register_cross_layers_kv_cache(kv_cache, attn_backend)
+
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         for c in self._connectors:
             c.register_kv_caches(kv_caches)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 99f6f9157b36..7f03e0d88b9d 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -4,7 +4,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from itertools import islice
-from typing import Any, ClassVar
+from typing import Any
 
 import torch
 
@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
 
 
 class OffloadingConnector(KVConnectorBase_V1):
-    prefer_cross_layer_blocks: ClassVar[bool] = True
+    @property
+    def prefer_cross_layer_blocks(self) -> bool:
+        return True
 
     def __init__(
         self,