diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 855c3411713..51024fb9217 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -8,13 +8,18 @@ from unittest.mock import MagicMock import pytest +import torch from tests.v1.kv_connector.unit.utils import create_vllm_config from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + SupportsHMA, + supports_hma, +) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiConnector, @@ -83,8 +88,43 @@ def update_state_after_alloc(self, request, blocks, num_tokens) -> None: pass -# Register the mock connector +class MockHMAConnector(KVConnectorBase_V1, SupportsHMA): + """Mock connector that supports HMA for testing.""" + + def __new__(cls, *args, **kwargs): + mock = MagicMock(spec_set=cls) + return mock + + def start_load_kv(self, forward_context, **kwargs): + pass + + def wait_for_layer_load(self, layer_name): + pass + + def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs): + pass + + def wait_for_save(self): + pass + + def build_connector_meta(self, scheduler_output): + return None + + def get_num_new_matched_tokens(self, request, num_computed_tokens): + return (0, False) + + def update_state_after_alloc(self, request, blocks, num_tokens) -> None: + pass + + def request_finished_all_groups(self, request, block_ids): + return (False, None) + + +# Register mock connectors KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) +KVConnectorFactory.register_connector( + "MockHMAConnector", __name__, MockHMAConnector.__name__ +) @pytest.fixture @@ -920,3 +960,133 @@ def assert_update_connector_output_called(mc: MultiConnector): mc.update_connector_output(kv_connector_output) assert_update_connector_output_called(mc) assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b + + +def _make_multi_connector(connector_names: list[str]) -> MultiConnector: + """Build a MultiConnector wrapping the given registered connectors.""" + vllm_config = create_vllm_config() + connectors = [ + { + "kv_connector": name, + "kv_role": "kv_both", + "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector", # noqa: E501 + } + for name in connector_names + ] + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={"connectors": connectors}, + ) + kv_cache_config = KVCacheConfig( + num_blocks=0, + kv_cache_tensors=[], + kv_cache_groups=[], + ) + return MultiConnector( + vllm_config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, + ) + + +def test_multi_connector_hma_opt_in(): + """ + MultiConnector currently assumes HMA is opt-in: it needs + --no-disable-hybrid-kv-cache-manager to be enabled. + + At runtime, _all_support_hma is True only when every sub-connector + implements SupportsHMA. Test all combinations of HMA / non-HMA + sub-connectors. + """ + + assert supports_hma(MultiConnector) + + # -- All non-HMA connectors => _all_support_hma is False -- + mc_none = _make_multi_connector(["MockConnector", "MockConnector"]) + assert not supports_hma(mc_none._connectors[0]) + assert not supports_hma(mc_none._connectors[1]) + assert mc_none._all_support_hma is False + + # -- All HMA connectors => _all_support_hma is True -- + mc_all = _make_multi_connector(["MockHMAConnector", "MockHMAConnector"]) + assert supports_hma(mc_all._connectors[0]) + assert supports_hma(mc_all._connectors[1]) + assert mc_all._all_support_hma is True + + # -- Mixed: first HMA, second non-HMA => _all_support_hma is False -- + mc_mixed1 = _make_multi_connector(["MockHMAConnector", "MockConnector"]) + assert supports_hma(mc_mixed1._connectors[0]) + assert not supports_hma(mc_mixed1._connectors[1]) + assert mc_mixed1._all_support_hma is False + + # -- Mixed: first non-HMA, second HMA => _all_support_hma is False -- + mc_mixed2 = _make_multi_connector(["MockConnector", "MockHMAConnector"]) + assert not supports_hma(mc_mixed2._connectors[0]) + assert supports_hma(mc_mixed2._connectors[1]) + assert mc_mixed2._all_support_hma is False + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Requires GPU to instantiate LLM" +) +def test_multi_connector_mixed_hma_disables_hybrid_kv_cache(monkeypatch): + """ + When MultiConnector wraps a mix of HMA (NixlConnector) and non-HMA + (MockConnector) sub-connectors, verify that: + 1. The scheduler's MultiConnector has _all_support_hma == False. + 2. vLLM auto-disables the hybrid KV cache manager (no preference expressed by user) + """ + from unittest.mock import patch + + from tests.v1.kv_connector.unit.test_nixl_connector import FakeNixlWrapper + + monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") + + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [ + { + "kv_connector": "NixlConnector", + "kv_role": "kv_both", + }, + { + "kv_connector": "MockConnector", + "kv_role": "kv_both", + "kv_connector_module_path": ( + "tests.v1.kv_connector.unit.test_multi_connector" + ), + }, + ], + }, + ) + + with patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl.worker.NixlWrapper", + FakeNixlWrapper, + ): + llm = LLM( + model="Qwen/Qwen3-0.6B", + enforce_eager=True, + gpu_memory_utilization=0.3, + max_model_len=128, + max_num_seqs=1, + max_num_batched_tokens=128, + kv_transfer_config=kv_transfer_config, + ) + try: + # HMA should be auto-disabled when user has not expressed a preference. + assert ( + llm.llm_engine.vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + is True + ) + # The scheduler-side MultiConnector should detect the mixed + # HMA support among its sub-connectors. + scheduler = llm.llm_engine.engine_core.engine_core.scheduler + mc = scheduler.connector + assert isinstance(mc, MultiConnector) + assert mc._all_support_hma is False + finally: + llm.llm_engine.engine_core.shutdown() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 4ef8f0ac9c9..a340f313e0a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch @@ -18,6 +18,8 @@ KVConnectorMetadata, KVConnectorRole, KVConnectorWorkerMetadata, + SupportsHMA, + supports_hma, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -123,7 +125,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) -class MultiConnector(KVConnectorBase_V1): +class MultiConnector(KVConnectorBase_V1, SupportsHMA): """ A wrapper for using multiple KVConnectors at the same time. @@ -166,6 +168,12 @@ def __init__( self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) + self._all_support_hma = all(supports_hma(c) for c in self._connectors) + assert ( + vllm_config.scheduler_config.disable_hybrid_kv_cache_manager + or self._all_support_hma + ), "HMA should not be enabled unless all sub-connectors support it" + # A mapping from request id to the index of the connector chosen to # load the request from (if any). self._requests_to_connector: dict[str, int] = {} @@ -436,15 +444,17 @@ def set_xfer_handshake_metadata( for c in self._connectors: c.set_xfer_handshake_metadata(metadata) - def request_finished( + def _aggregate_request_finished( self, request: "Request", - blocks: list[int], + per_connector_fn: Callable[ + [KVConnectorBase_V1], tuple[bool, dict[str, Any] | None] + ], ) -> tuple[bool, dict[str, Any] | None]: async_saves = 0 kv_txfer_params = None for c in self._connectors: - async_save, txfer_params = c.request_finished(request, blocks) + async_save, txfer_params = per_connector_fn(c) if async_save: async_saves += 1 if txfer_params is not None: @@ -458,11 +468,39 @@ def request_finished( if async_saves > 1: self._extra_async_saves[request.request_id] = async_saves - 1 - # Clean up other state for this request. self._requests_to_connector.pop(request.request_id, None) return async_saves > 0, kv_txfer_params + def request_finished( + self, + request: "Request", + blocks: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + return self._aggregate_request_finished( + request, + lambda c: c.request_finished(request, blocks), + ) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + if not self._all_support_hma: + assert len(block_ids) == 1, ( + "HMA with multiple kv_cache_groups requires all " + "sub-connectors to support HMA" + ) + return self.request_finished(request, block_ids[0]) + + return self._aggregate_request_finished( + request, + lambda c: cast(SupportsHMA, c).request_finished_all_groups( + request, block_ids + ), + ) + def take_events(self) -> Iterable["KVCacheEvent"]: for c in self._connectors: yield from c.take_events()