diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py new file mode 100644 index 000000000000..64da0d79bf33 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 +import filecmp +import shutil +import tempfile +from collections import defaultdict +from pathlib import Path + +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig, VllmConfig +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) +from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa + SharedStorageConnector) + +MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" + +PROMPT_CONTEXT = "Hi " * 100 +PROMPTS = [ + PROMPT_CONTEXT + "Hello, my name is", + PROMPT_CONTEXT + "The capital of France is", +] + +SAMPLING_PARAMS = SamplingParams(temperature=0, max_tokens=20) + + +class TestSharedStorageConnector(SharedStorageConnector): + + def __init__(self, config: VllmConfig, role): + self.name = config.kv_transfer_config.kv_connector_extra_config["name"] + self._connector = SharedStorageConnector(config, role) + self.call_record: dict[str, int] = defaultdict(int) + # Use a unique temp file per connector + self._event_file = tempfile.gettempdir( + ) + f"/connector_{self.name}_events.log" + # Start with an empty file + with open(self._event_file, "w") as _: + pass + + def __getattribute__(self, name): + if name in ("_connector", "call_record", "name", "_event_file", + "__class__", "__dict__", "__getattribute__", + "__init__"): # avoid recursion + return object.__getattribute__(self, name) + if not hasattr(self._connector, name): + return object.__getattribute__(self, name) + attr = getattr(self._connector, name) + + # Intercept calls to the connector interface and write an event + # for each one to a file, which can be read back in the main test proc. + if callable(attr): + + def wrapper(*args, **kwargs): + self.call_record[name] += 1 + # Log the event as a line to the file + try: + with open(self._event_file, "a") as f: + f.write(name + "\n") + except Exception as e: + print(f"[ERROR] Could not log event {name} " + f"for {self.name}: {e}") + return attr(*args, **kwargs) + + return wrapper + return attr + + +KVConnectorFactory.register_connector("TestSharedStorageConnector", + TestSharedStorageConnector.__module__, + TestSharedStorageConnector.__name__) + + +# Helper function to compare directories recursively +def _compare_directories(dir1: Path, dir2: Path) -> bool: + """Compares two directories recursively for identical content.""" + dcmp = filecmp.dircmp(dir1, dir2) + if dcmp.left_only or dcmp.right_only or dcmp.diff_files: + print(f"Differences found between {dir1} and {dir2}:") + print(f" Left only: {dcmp.left_only}") + print(f" Right only: {dcmp.right_only}") + print(f" Different files: {dcmp.diff_files}") + return False + for sub_dir in dcmp.common_dirs: + if not _compare_directories(dir1 / sub_dir, dir2 / sub_dir): + return False + return True + + +def test_multi_shared_storage_connector_consistency(): + """ + Tests that MultiConnector with two SharedStorageConnectors saves + identical KV cache data to separate storage locations. 
+ """ + storage_1_path = Path("storage_1/") + storage_2_path = Path("storage_2/") + shutil.rmtree(storage_1_path, ignore_errors=True) + shutil.rmtree(storage_2_path, ignore_errors=True) + storage_1_path.mkdir() + storage_2_path.mkdir() + + # Configure MultiConnector with two SharedStorageConnectors + kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [{ + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_1_path), + "name": "storage1", + } + }, { + "kv_connector": "TestSharedStorageConnector", + "kv_role": "kv_both", + "kv_connector_extra_config": { + "shared_storage_path": str(storage_2_path), + "name": "storage2", + } + }] + }, + ) + + llm = LLM( + model=MODEL_NAME, + enforce_eager=True, + gpu_memory_utilization=0.5, + kv_transfer_config=kv_transfer_config, + ) + # Run generation - this should trigger saving KV cache + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + # --- Verification --- + + # Check that both storage directories were populated + local_subdirs = list(storage_1_path.iterdir()) + external_subdirs = list(storage_2_path.iterdir()) + + assert len( + local_subdirs + ) > 0, f"Local storage path {storage_1_path} is empty after generation." + assert len(external_subdirs) > 0, ( + f"External storage path {storage_2_path} is empty after generation.") + assert len(local_subdirs) == len(external_subdirs), ( + f"Mismatch in number of cache entries: " + f"Local={len(local_subdirs)}, External={len(external_subdirs)}") + + # The subdirectories should correspond to the prompt hashes + # Since prompts are the same, the hash directories should be the same name + local_subdir_names = sorted([d.name for d in local_subdirs]) + external_subdir_names = sorted([d.name for d in external_subdirs]) + assert local_subdir_names == external_subdir_names, ( + "Cache directory names do not match between local and external storage" + ) + + # Compare the contents of each corresponding cache directory + for subdir_name in local_subdir_names: + print(f"Comparing contents of cache directory: {subdir_name}") + assert _compare_directories(storage_1_path / subdir_name, + storage_2_path / subdir_name), \ + (f"Contents differ for cache directory '{subdir_name}' between " + f"{storage_1_path} and {storage_2_path}") + + events = get_connector_events() + # get_num_new_matched_tokens will be called on each connector in turn. + # neither of them have hits so update_state_after_alloc won't be called. + assert events["storage1"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + assert events["storage2"][:3] == [ + 'get_num_new_matched_tokens', 'build_connector_meta', + 'bind_connector_metadata' + ] + + # Reset prefix cache or else we'll just get the tokens back from there. + llm.reset_prefix_cache() + + # Run generation again - this should trigger loading from the first + # connector. + _ = llm.generate(PROMPTS, SAMPLING_PARAMS) + + events = get_connector_events() + # get_num_new_matched_tokens will return new tokens from the first + # connector so update_state_after_alloc will be called once blocks + # are allocated for the first connector. + # get_num_new_matched_tokens *won't* be called on the second connector + # in this case. 
+    assert events["storage1"][:4] == [
+        'get_num_new_matched_tokens', 'update_state_after_alloc',
+        'build_connector_meta', 'bind_connector_metadata'
+    ]
+    assert events["storage2"][:2] == [
+        'build_connector_meta', 'bind_connector_metadata'
+    ]
+
+    # Delete storage1 connector state
+    shutil.rmtree(storage_1_path)
+
+    # Reset prefix cache or else we'll just get the tokens back from there.
+    llm.reset_prefix_cache()
+
+    # Run generation again - this should trigger loading from the second
+    # connector.
+    _ = llm.generate(PROMPTS, SAMPLING_PARAMS)
+
+    events = get_connector_events()
+    # get_num_new_matched_tokens will be called for the first connector but it
+    # won't have a hit so update_state_after_alloc won't be called.
+    # get_num_new_matched_tokens will also be called on the second connector,
+    # but it should have a hit so update_state_after_alloc will be called.
+    assert events["storage1"][:3] == [
+        'get_num_new_matched_tokens', 'build_connector_meta',
+        'bind_connector_metadata'
+    ]
+    assert events["storage2"][:4] == [
+        'get_num_new_matched_tokens', 'update_state_after_alloc',
+        'build_connector_meta', 'bind_connector_metadata'
+    ]
+
+    # Clean up
+    shutil.rmtree(storage_1_path)
+    shutil.rmtree(storage_2_path)
+
+
+def get_connector_events() -> dict[str, list[str]]:
+    # Read in connector events and reset the files.
+    import glob
+    event_files = glob.glob(tempfile.gettempdir() + "/connector_*_events.log")
+    connector_events = {}
+    for fname in event_files:
+        name = fname.split("connector_")[1].split("_events.log")[0]
+        try:
+            with open(fname, "r+") as f:
+                connector_events[name] = [
+                    line.strip() for line in f if line.strip()
+                ]
+                f.truncate(0)
+        except Exception as e:
+            print(f"[ERROR] Could not read connector events for {name}: {e}")
+
+    return connector_events
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 6766d5a24542..f998f5dd7b15 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -110,3 +110,8 @@ def create_connector_v1(
     "NixlConnector",
     "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector",
     "NixlConnector")
+
+KVConnectorFactory.register_connector(
+    "MultiConnector",
+    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
+    "MultiConnector")
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 03c99f20e775..9fdb5340f0e2 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -22,7 +22,6 @@
 import enum
 from abc import ABC, abstractmethod
-from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
 import torch
@@ -48,7 +47,6 @@ class KVConnectorRole(enum.Enum):
     WORKER = 1
 
 
-@dataclass
 class KVConnectorMetadata:
     """
     Abstract Metadata used to communicate between the
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
new file mode 100644
index 000000000000..cc4a7fbadf5c
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -0,0 +1,178 @@
+# SPDX-License-Identifier: Apache-2.0
+import copy
+from typing import TYPE_CHECKING, Any, Optional
+
+import torch
+
+from vllm.config import KVTransferConfig, VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.logger import init_logger
+from vllm.v1.core.sched.output import SchedulerOutput
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+class MultiKVConnectorMetadata(tuple[KVConnectorMetadata, ...],
+                               KVConnectorMetadata):
+    pass
+
+
+class MultiConnector(KVConnectorBase_V1):
+    """
+    A wrapper for using multiple KVConnectors at the same time.
+
+    The current logic is:
+    - Load KV from the first connector that advertises available tokens from
+      get_num_new_matched_tokens(), based on the order in the config.
+    - Save to all connectors.
+    """
+
+    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+        super().__init__(vllm_config=vllm_config, role=role)
+        self._connectors = []
+        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "connectors")
+        assert ktcs is not None
+        for ktc in ktcs:
+            temp_config = copy.copy(vllm_config)
+            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
+            self._connectors.append(
+                KVConnectorFactory.create_connector_v1(temp_config, role))
+
+        # A mapping from request id to the connector that is assigned to it.
+        self._requests_to_connector: dict[str, KVConnectorBase_V1] = {}
+
+        # Keeps track of *additional* remaining async saves (beyond 1) to be
+        # finished per request. Not needed for async loads since we only allow
+        # a single connector to load.
+        self._extra_async_saves: dict[str, int] = {}
+
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        for c in self._connectors:
+            c.register_kv_caches(kv_caches)
+
+    # We must override the base class method here because we need to bind
+    # the metadata to each connector in the order of the connectors in the
+    # MultiKVConnectorMetadata.
+    def bind_connector_metadata(
+            self, connector_metadata: KVConnectorMetadata) -> None:
+        assert isinstance(connector_metadata, MultiKVConnectorMetadata)
+        for c, cm in zip(self._connectors, connector_metadata):
+            c.bind_connector_metadata(cm)
+
+    def clear_connector_metadata(self) -> None:
+        for c in self._connectors:
+            c.clear_connector_metadata()
+
+    # ==============================
+    # Worker-side methods
+    # ==============================
+    def start_load_kv(self, forward_context: "ForwardContext",
+                      **kwargs) -> None:
+        for c in self._connectors:
+            c.start_load_kv(forward_context, **kwargs)
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        for c in self._connectors:
+            c.wait_for_layer_load(layer_name)
+
+    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
+                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+        for c in self._connectors:
+            c.save_kv_layer(layer_name, kv_layer, attn_metadata, **kwargs)
+
+    def wait_for_save(self):
+        for c in self._connectors:
+            c.wait_for_save()
+
+    def get_finished(
+        self, finished_req_ids: set[str]
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        finished_recving: set[str] = set()
+        finished_sending: set[str] = set()
+        for c in self._connectors:
+            recving, sending = c.get_finished(finished_req_ids)
+            if not recving and not sending:
+                continue
+            # Aggregate finished recving request ids.
+ finished_recving.update(recving or ()) + # Aggregate finished sending request ids - only include + # once we've drained the "extra" count (for cases where + # more than one connector is async-saving the same request). + for req_id in sending or (): + extra_pending = self._extra_async_saves.get(req_id) + if extra_pending is None: + finished_sending.add(req_id) + continue + assert extra_pending > 0 + if extra_pending == 1: + del self._extra_async_saves[req_id] + else: + self._extra_async_saves[req_id] = extra_pending - 1 + + return finished_recving or None, finished_sending or None + + # ============================== + # Scheduler-side methods + # ============================== + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + for c in self._connectors: + toks, load_async = c.get_num_new_matched_tokens( + request, num_computed_tokens) + # The first connector that has new matched tokens will be assigned + # to this request. + if toks > 0: + self._requests_to_connector[request.request_id] = c + return toks, load_async + return 0, False + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + # If the request is not assigned to any connector, we do nothing. + if request.request_id not in self._requests_to_connector: + return + # We assume that the request is assigned to only one connector. + c = self._requests_to_connector.pop(request.request_id) + c.update_state_after_alloc(request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata: + return MultiKVConnectorMetadata( + c.build_connector_meta(scheduler_output) for c in self._connectors) + + def request_finished( + self, + request: "Request", + blocks: "KVCacheBlocks", + ) -> tuple[bool, Optional[dict[str, Any]]]: + async_saves = 0 + kv_txfer_params = None + for c in self._connectors: + async_save, txfer_params = c.request_finished(request, blocks) + if async_save: + async_saves += 1 + if txfer_params is not None: + if kv_txfer_params is not None: + #TODO we can probably change this to merge the dicts here, + # checking for key clashes. + raise RuntimeError( + "Only one connector can produce KV transfer params") + kv_txfer_params = txfer_params + if async_saves > 1: + self._extra_async_saves[request.request_id] = async_saves - 1 + return async_saves > 0, kv_txfer_params
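
For reference, a minimal usage sketch of the new MultiConnector, mirroring the test configuration above but using the plain SharedStorageConnector; the model name and storage paths here are illustrative, not taken from the diff:

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# MultiConnector saves KV to every listed connector and loads from the
# first one that reports matched tokens, in config order.
kv_transfer_config = KVTransferConfig(
    kv_connector="MultiConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "connectors": [{
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {
                "shared_storage_path": "/tmp/kv_storage_1",
            },
        }, {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
            "kv_connector_extra_config": {
                "shared_storage_path": "/tmp/kv_storage_2",
            },
        }],
    },
)

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    enforce_eager=True,
    kv_transfer_config=kv_transfer_config,
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0, max_tokens=20))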