From fdc30ae552e4afd75a577730d41947efa343f2eb Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Fri, 31 Oct 2025 09:58:35 -0400 Subject: [PATCH] [KV Connector] Make KVCacheConfig an explicit constructor argument Follow on from #25712 `VllmConfig` is explicitly designed as a dataclass containing user-provided configuration and model metadata. It is a global configuration object that lives throughout the entire engine lifetime and is meant to be immutable after `__post_init__()`. `KVCacheConfig` is worker-specific, runtime-computed state. It has limited lifetime, and its purpose is limited to initializing the KV Cache in the model runner. Even if we add KV cache hints to model config.json in future, this would be parsed into `ModelConfig`, used as input to the `get_kv_cache_configs()` computation, and the resulting `KVCacheConfig` would still be runtime state. We are currently creating per-worker copies of VllmConfig in order to attach the runtime `KVCacheConfig` state. But instead we should just explicitly pass `KVCacheConfig` to the connector. Make sure to handle backwards compatibility for external connector implementations (loaded via module path) that have the old style constructor signature. 
Signed-off-by: Mark McLoughlin --- .../unit/test_backwards_compatibility.py | 275 ++++++++++++++++++ tests/v1/kv_connector/unit/utils.py | 2 +- .../kv_transfer/kv_connector/factory.py | 41 ++- .../kv_transfer/kv_connector/v1/base.py | 16 +- .../kv_connector/v1/decode_bench_connector.py | 12 +- .../kv_connector/v1/lmcache_connector.py | 12 +- .../kv_connector/v1/multi_connector.py | 14 +- .../kv_connector/v1/nixl_connector.py | 12 +- .../kv_connector/v1/offloading_connector.py | 10 +- .../kv_connector/v1/p2p/p2p_nccl_connector.py | 16 +- .../v1/shared_storage_connector.py | 16 +- .../kv_transfer/kv_transfer_state.py | 11 +- vllm/v1/core/sched/scheduler.py | 12 +- vllm/v1/worker/gpu_worker.py | 4 +- 14 files changed, 410 insertions(+), 43 deletions(-) create mode 100644 tests/v1/kv_connector/unit/test_backwards_compatibility.py diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py new file mode 100644 index 000000000000..f51001a6ec12 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for backwards compatibility with external KV connector implementations. 
+ +This test ensures that external connectors (loaded via kv_connector_module_path) +implemented with the old signature continue to work: +- Old signature: __init__(self, vllm_config, role) +- New signature: __init__(self, vllm_config, role, kv_cache_config) +""" + +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.v1.core.sched.output import SchedulerOutput + +from .utils import create_scheduler, create_vllm_config + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + + +class OldStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the old signature with 2 required arguments. + This simulates external connectors that haven't been updated yet. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + # Old-style call to super().__init__ with only 2 arguments + super().__init__(vllm_config=vllm_config, role=role) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +class NewStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the new signature with 3 required arguments. 
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + # New-style call to super().__init__ with all 3 arguments + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_old_signature_factory_instantiation(role): + """ + Test that external connectors with old signature (2 required args) loaded + via kv_connector_module_path are correctly instantiated with backwards + compatibility support. 
+ """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, OldStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is None + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_new_signature_factory_instantiation(role): + """ + Test that external connectors with new signature (3 required args) loaded + via kv_connector_module_path are correctly instantiated. + """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, NewStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_old_signature_super_init(role): + """ + Test that old-style connectors can call super().__init__() without + kv_cache_config parameter. 
+ """ + vllm_config = create_vllm_config() + + connector = OldStyleTestConnector(vllm_config, role) + + assert connector is not None + assert connector.role == role + assert connector._kv_cache_config is None + + +def test_old_signature_super_init_with_kwargs(): + """ + Test that old-style connectors can call super().__init__() with keyword + arguments in different orders. + """ + vllm_config = create_vllm_config() + + # Test with vllm_config= and role= kwargs + connector1 = OldStyleTestConnector( + vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER + ) + assert connector1 is not None + assert connector1._kv_cache_config is None + + # Test with role= and vllm_config= in reversed order + connector2 = OldStyleTestConnector( + role=KVConnectorRole.WORKER, vllm_config=vllm_config + ) + assert connector2 is not None + assert connector2._kv_cache_config is None + + +def test_internal_connector_uses_new_signature(): + """ + Test that internal connectors (registered in factory) always use the new + signature and get kv_cache_config. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + SharedStorageConnector, + ) + + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + + assert connector is not None + assert isinstance(connector, SharedStorageConnector) + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +def test_signature_detection_with_mocking(): + """ + Test that the factory correctly applies compat_sig flag returned from + _get_connector_class_with_compat. 
+ """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + # Mock _get_connector_class_with_compat to return old-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(OldStyleTestConnector, True), + ): + old_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert old_connector is not None + assert isinstance(old_connector, OldStyleTestConnector) + assert old_connector._kv_cache_config is None + + # Mock _get_connector_class_with_compat to return new-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(NewStyleTestConnector, False), + ): + new_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert new_connector is not None + assert isinstance(new_connector, NewStyleTestConnector) + assert new_connector._kv_cache_config is not None + assert new_connector._kv_cache_config == kv_cache_config diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 46ea46e53084..c1c0e13f7753 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -254,7 +254,7 @@ def create_model_runner_output( class TestSharedStorageConnector(SharedStorageConnector): - def __init__(self, config: VllmConfig, role): + def __init__(self, config: VllmConfig, role, kv_cache_config): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self._connector = SharedStorageConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index c64996f13cd5..8d14200c5240 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,10 +3,9 @@ import importlib from collections.abc import Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import vllm.envs as envs -from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, @@ -16,9 +15,12 @@ supports_hma, ) from vllm.logger import init_logger +from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: + from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig + from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -41,8 +43,9 @@ def loader() -> type[KVConnectorBase]: @classmethod def create_connector( cls, - config: VllmConfig, + config: "VllmConfig", role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: raise ValueError( @@ -53,7 +56,9 @@ def create_connector( kv_transfer_config = config.kv_transfer_config if kv_transfer_config is None: raise ValueError("kv_transfer_config must be set to create a connector") - connector_cls = cls.get_connector_class(kv_transfer_config) + connector_cls, compat_sig = cls._get_connector_class_with_compat( + kv_transfer_config + ) # check if the connector supports HMA hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager @@ -76,7 +81,12 @@ def create_connector( # - Co-locate with worker process # - Should only be used inside the forward context & attention layer # We build separately to enforce strict separation - return connector_cls(config, role) + if compat_sig: + # Old signature: __init__(self, vllm_config, role) + return connector_cls(config, role) + else: + # New signature: __init__(self, vllm_config, role, kv_cache_config) + return connector_cls(config, role, kv_cache_config) @classmethod def get_connector_class_by_name( @@ -97,13 +107,13 @@ def 
get_connector_class_by_name( return cls._registry[connector_name]() @classmethod - def get_connector_class( + def _get_connector_class_with_compat( cls, kv_transfer_config: "KVTransferConfig" - ) -> type[KVConnectorBaseType]: - """Get the connector class by name.""" + ) -> tuple[type[KVConnectorBaseType], bool]: connector_name = kv_transfer_config.kv_connector if connector_name is None: raise ValueError("Connector name is not set in KVTransferConfig") + compat_sig = False if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() else: @@ -118,6 +128,21 @@ def get_connector_class( f"Class {connector_name} not found in {connector_module_path}" ) from e connector_cls = cast(type[KVConnectorBaseType], connector_cls) + if not supports_kw(connector_cls, "kv_cache_config"): + compat_sig = True + logger.warning( + "Connector %s uses deprecated signature with 2 required arguments. " + "Please update to include kv_cache_config as the third argument.", + connector_cls.__name__, + ) + return connector_cls, compat_sig + + @classmethod + def get_connector_class( + cls, kv_transfer_config: "KVTransferConfig" + ) -> type[KVConnectorBaseType]: + """Get the connector class by name.""" + connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config) return connector_cls diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2ed0fe592e37..f48ac3f1cebb 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -58,6 +58,7 @@ ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request # s_tensor_list, d_tensor_list, s_indices, d_indices, direction @@ -132,7 +133,12 @@ class KVConnectorMetadata(ABC): # noqa: B024 class KVConnectorBase_V1(ABC): - def __init__(self, 
vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design." @@ -143,6 +149,14 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self._kv_transfer_config = vllm_config.kv_transfer_config else: raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1") + self._kv_cache_config = kv_cache_config + if self._kv_cache_config is None: + logger.warning( + "KVConnectorBase_V1 initialized without kv_cache_config. " + "This is deprecated - please update your connector to accept " + "kv_cache_config as the third constructor argument and pass it " + "to super().__init__()." + ) self._role = role @property diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py index ca251cd0c6eb..9cd7d93c92fa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py @@ -32,7 +32,7 @@ """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch @@ -50,6 +50,7 @@ from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1): testing of the decoder with larger input sequence lengths. 
""" - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) self.connector_scheduler: DecodeBenchConnectorScheduler | None = None self.connector_worker: DecodeBenchConnectorWorker | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 7232d947030c..575ab468be56 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -20,14 +20,22 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) class LMCacheConnectorV1(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) assert vllm_config.kv_transfer_config is not None use_native = vllm_config.kv_transfer_config.get_from_extra_config( "use_native", False diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d56f30bd11e5..d7bbf02c8367 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -31,6 +31,7 @@ from vllm.distributed.kv_events import KVCacheEvent from 
vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1): - Save to all connectors. """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] for connector_cls, temp_config in self._get_connector_classes_and_configs( vllm_config ): - self._connectors.append(connector_cls(temp_config, role)) + self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index d5712bdd9feb..2f541710065f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import msgspec import numpy as np @@ -51,6 +51,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request Transfer = tuple[int, float] # (xfer_handle, start_time) @@ 
-152,7 +153,14 @@ def add_new_req( class NixlConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 19344e5784c2..e5ea80e6ea45 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -21,6 +21,7 @@ from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.abstract import OffloadingManager from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.mediums import GPULoadStoreSpec @@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) spec = OffloadingSpecFactory.create_spec(vllm_config) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 780dd12fccda..a124a0d519db 100644 --- 
a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import regex as re import torch @@ -25,6 +25,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -71,8 +72,17 @@ def add_request( class P2pNcclConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Any] = {} self.is_producer = self._kv_transfer_config.is_kv_producer diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 9c230d7d0d2f..016d1d45b359 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch @@ -22,6 +22,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from 
vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1): # It does extra work which will overwrite the existing prefix-cache in GPU # - to remove the overhead, need to add some "mask" in the ReqMeta class - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Request] = {} self._storage_path = self._kv_transfer_config.get_from_extra_config( diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index cabfc10e7f94..7501f0b373d4 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType @@ -12,6 +12,7 @@ if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.v1.kv_cache_interface import KVCacheConfig _KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None @@ -48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo return isinstance(connector, KVConnectorBase_V1) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: +def ensure_kv_transfer_initialized( + vllm_config: "VllmConfig", kv_cache_config: 
Optional["KVCacheConfig"] = None +) -> None: """ Initialize KV cache transfer parallel group. """ @@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, role=KVConnectorRole.WORKER + config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, ) else: raise ValueError("V0 is no longer supported") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 98c8f08b0aae..3a85aff0b98a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy import itertools import time from collections import defaultdict @@ -91,15 +90,10 @@ def __init__( assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" ) - - connector_vllm_config = copy.copy(self.vllm_config) - - # We're dynamically inserting a kv_cache_config variable into the - # connector_vllm_config. This is distinct from the cache_config - # that is already in there. 
- connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined] self.connector = KVConnectorFactory.create_connector( - config=connector_vllm_config, role=KVConnectorRole.SCHEDULER + config=self.vllm_config, + role=KVConnectorRole.SCHEDULER, + kv_cache_config=self.kv_cache_config, ) if self.log_stats: self.connector_prefix_cache_stats = PrefixCacheStats() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 54c5f81fc7e8..a27c760ab0e1 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -359,9 +359,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, # because `initialize_kv_cache` will inject kv cache groups not # related to kv cache connector (e.g. kv cache sharing layers). - connector_vllm_config = copy.copy(self.vllm_config) - connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) - ensure_kv_transfer_initialized(connector_vllm_config) + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator