diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 2bf0b6fd9a16..a2a6612dc8ba 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -729,7 +729,6 @@ steps: # this test fails consistently. # TODO: investigate and fix - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - pytest -v -s models/multimodal/generation/test_maverick.py diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py deleted file mode 100644 index 9f2229cc41df..000000000000 --- a/tests/kv_transfer/test_disagg.py +++ /dev/null @@ -1,120 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os -import subprocess -import sys -import time -from subprocess import Popen - -import pytest -import requests -import torch - - -# Fixture to set up environment variables and teardown servers after tests -@pytest.fixture(scope="module", autouse=True) -def setup_servers(): - if torch.cuda.device_count() < 2: - pytest.skip("Skipping test: fewer than 2 GPUs available") - - # Set up environment variables - VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", - shell=True).decode().strip() - os.environ["VLLM_HOST_IP"] = VLLM_HOST_IP - - # Start prefill instance - prefill_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8100", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer",'\ - '"kv_rank":0,"kv_parallel_size":2}', - ] - prefill_env = os.environ.copy() - prefill_env["CUDA_VISIBLE_DEVICES"] = "0" - prefill_proc = Popen(prefill_cmd, env=prefill_env) - - # Start decode instance - decode_cmd = [ - sys.executable, - "-m", - "vllm.entrypoints.openai.api_server", - "--model", - "meta-llama/Llama-3.2-1B-Instruct", - "--port", - "8200", - "--gpu-memory-utilization", - "0.5", - "--max-model-len", - "1000", - "--kv-transfer-config", - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer",'\ - '"kv_rank":1,"kv_parallel_size":2}', - ] - decode_env = os.environ.copy() - decode_env["CUDA_VISIBLE_DEVICES"] = "1" - decode_proc = Popen(decode_cmd, env=decode_env) - - # Wait for servers to be ready - assert wait_for_server(8100), "Prefill server did not start in time" - assert wait_for_server(8200), "Decode server did not start in time" - - # Yield to the test function and handle teardown after tests - yield - - # Cleanup: kill the processes - prefill_proc.terminate() - decode_proc.terminate() - - # Additional cleanup if needed - prefill_proc.wait() - decode_proc.wait() - - -# Helper function to wait for server -def wait_for_server(port, timeout=240): - start_time = time.time() - while time.time() - start_time < timeout: - try: - response = requests.get(f"http://localhost:{port}/v1/completions") - if response.status_code in [200, 405]: - return True - except requests.ConnectionError: - time.sleep(1) - return False - - -# Test function to send curl requests and validate responses -@pytest.mark.parametrize("prompt", ["San Francisco is a", "Santa Clara is a"]) -def test_disaggregated_prefilling(prompt): - # Send to prefill - response = requests.post("http://localhost:8100/v1/completions", - headers={"Content-Type": 
"application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 1, - "temperature": 0 - }) - assert response.status_code == 200 - - # Send to decode - response = requests.post("http://localhost:8200/v1/completions", - headers={"Content-Type": "application/json"}, - json={ - "model": "meta-llama/Llama-3.2-1B-Instruct", - "prompt": prompt, - "max_tokens": 10, - "temperature": 0 - }) - assert response.status_code == 200 diff --git a/tests/v1/kv_connector/unit/test_output_aggreagator.py b/tests/v1/kv_connector/unit/test_output_aggreagator.py index cad73f68e9f1..56b86087783b 100644 --- a/tests/v1/kv_connector/unit/test_output_aggreagator.py +++ b/tests/v1/kv_connector/unit/test_output_aggreagator.py @@ -3,6 +3,7 @@ from concurrent.futures import Future from typing import Optional +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator from vllm.v1.outputs import ModelRunnerOutput @@ -12,8 +13,22 @@ class DummyModelRunnerOutput(ModelRunnerOutput): def __init__(self, finished_sending: Optional[set[str]] = None, finished_recving: Optional[set[str]] = None): - self.finished_sending = finished_sending - self.finished_recving = finished_recving + self.kv_connector_finish_output = (KVConnectorOutput( + finished_sending=finished_sending or set(), + finished_recving=finished_recving or set(), + finished_loading_num_tokens={})) + + @property + def finished_sending(self) -> set[str]: + if self.kv_connector_finish_output is None: + return set() + return self.kv_connector_finish_output.finished_sending + + @property + def finished_recving(self) -> set[str]: + if self.kv_connector_finish_output is None: + return set() + return self.kv_connector_finish_output.finished_recving def test_aggregate_workers_output(): @@ -27,8 +42,8 @@ def test_aggregate_workers_output(): aggregated = aggregator.aggregate([output1, output2]) assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving is None + assert not aggregated.finished_sending + assert not aggregated.finished_recving output1 = DummyModelRunnerOutput(finished_sending=None, finished_recving=None) @@ -39,7 +54,7 @@ def test_aggregate_workers_output(): assert aggregated is output1 assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None + assert not aggregated.finished_recving output1 = DummyModelRunnerOutput(finished_sending=None, finished_recving=None) @@ -49,7 +64,7 @@ def test_aggregate_workers_output(): aggregated = aggregator.aggregate([output1, output2]) assert aggregated is output1 - assert aggregated.finished_sending is None + assert not aggregated.finished_sending assert aggregated.finished_recving == {'req2'} @@ -70,8 +85,8 @@ def test_async_aggregate_workers_output(): assert result_future.done() aggregated = result_future.result() assert aggregated is output1 - assert aggregated.finished_sending is None - assert aggregated.finished_recving is None + assert not aggregated.finished_sending + assert not aggregated.finished_recving future1 = Future() future2 = Future() @@ -88,7 +103,7 @@ def test_async_aggregate_workers_output(): aggregated = result_future.result() assert aggregated is output1 assert aggregated.finished_sending == {'req1'} - assert aggregated.finished_recving is None + assert not aggregated.finished_recving future1 = Future() future2 = Future() @@ -104,5 +119,5 @@ def test_async_aggregate_workers_output(): 
     assert result_future.done()
     aggregated = result_future.result()
     assert aggregated is output1
-    assert aggregated.finished_sending is None
+    assert not aggregated.finished_sending
     assert aggregated.finished_recving == {'req2'}
diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py
index 480a7074cdf4..5ead7b31f315 100644
--- a/tests/v1/kv_connector/unit/utils.py
+++ b/tests/v1/kv_connector/unit/utils.py
@@ -9,6 +9,7 @@
 from vllm import SamplingParams
 from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                          ModelConfig, SchedulerConfig, VllmConfig)
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (  # noqa
@@ -188,8 +189,9 @@ def create_model_runner_output(
         logprobs=None,
         prompt_logprobs_dict={},
         pooler_output=None,
-        finished_sending=finished_sending,
-        finished_recving=finished_recving,
+        kv_connector_finish_output=KVConnectorOutput(
+            finished_sending=finished_sending or set(),
+            finished_recving=finished_recving or set()),
     )
diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py
index 868b227fc899..0e1a47d88183 100644
--- a/vllm/distributed/kv_transfer/kv_connector/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/base.py
@@ -1,142 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-KVConnectorBase Class for Distributed KV Cache & Hidden State communication
+"""Defines the base type for KV cache connectors."""
 
-The class provides two primary abstract methods:
-1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
-2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
-"""
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorOutput)
 
-from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Optional, Union
+KVConnectorBase = KVConnectorBase_V1
+KVConnectorBaseType = KVConnectorBase_V1
 
-import torch
-
-from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
-from vllm.sequence import IntermediateTensors
-
-if TYPE_CHECKING:
-    from vllm.config import VllmConfig
-    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
-
-
-class KVConnectorBase(ABC):
-    """
-    Abstract base class for a KV connector.
-
-    The class provides two primary abstract methods:
-    1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
-    2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
-    """
-
-    @abstractmethod
-    def __init__(
-        self,
-        rank: int,
-        local_rank: int,
-        config: "VllmConfig",
-    ):
-        raise NotImplementedError
-
-    @abstractmethod
-    def close(self) -> None:
-        """Close the buffer and release resources.
-
-        This method is responsible for cleaning up resources related to the
-        connector when it is no longer needed.
-
-        Raises:
-            NotImplementedError: This method must be implemented in subclasses.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def send_kv_caches_and_hidden_states(
-        self,
-        model_executable: torch.nn.Module,
-        model_input: "ModelInputForGPUWithSamplingMetadata",
-        kv_caches: list[torch.Tensor],
-        hidden_or_intermediate_states: Union[torch.Tensor,
-                                             IntermediateTensors],
-    ) -> None:
-        """
-        Send KV caches and hidden states to the connector.
- - This method processes the input tokens, KV caches, and - hidden/intermediate states for a given model and sends the data to the - decode instance. - - Args: - model_executable (torch.nn.Module): The model executable containing - start and end layer information. - model_input (ModelInputForGPUWithSamplingMetadata): The input - metadata from vLLM. - kv_caches (list[torch.Tensor]): List of KV caches (keys and values) - for each layer. - hidden_or_intermediate_states (Union[torch.Tensor, - IntermediateTensors]): - The hidden or intermediate states associated with the tokens. - - Returns: - None - - """ - - raise NotImplementedError - - @abstractmethod - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - """ - Receive KV caches and hidden states from the connector. - - This method attempts to retrieve KV caches and hidden states for input - tokens. If all required KV caches and hidden states are received, it - will bypass model input, else it will fall back to normal vLLM model - forwarding. - - Args: - model_executable (torch.nn.Module): - The model executable from vLLM modelrunner. - model_input (ModelInputForGPUWithSamplingMetadata): - The model input from vLLM modelrunner. - kv_caches (list[torch.Tensor]): - List of KV caches for each layer. - - Returns: - - hidden_or_intermediate_states (torch.Tensor or - IntermediateTensors): - Concatenated hidden states if all required data is retrieved, - otherwise `None`. - - bypass_model_exec (bool): - Indicates whether the model execution can be skipped (True) or - needs to be redone (False). - - model_input (ModelInputForGPUWithSamplingMetadata): - Optionally adjusted input metadata for re-execution when - `bypass_model_exec=False`. - - """ - - raise NotImplementedError - - @classmethod - def get_required_kvcache_layout( - cls, vllm_config: "VllmConfig") -> Optional[str]: - """ - Get the required KV cache layout for this connector. - Args: - vllm_config (VllmConfig): the vllm config. - - Returns: - str: the required KV cache layout. e.g. HND, or NHD. - None if the connector does not require a specific layout. 
- """ - return None - - -KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1] +__all__ = ["KVConnectorBase", "KVConnectorBaseType", "KVConnectorOutput"] diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index cf7cde2c4377..8bb5fa764172 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -5,14 +5,10 @@ from typing import TYPE_CHECKING, Callable import vllm.envs as envs -from vllm.config import KVTransferConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType -from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, - KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.base import (KVConnectorBase, + KVConnectorRole) from vllm.logger import init_logger -from .base import KVConnectorBase - if TYPE_CHECKING: from vllm.config import VllmConfig @@ -20,7 +16,7 @@ class KVConnectorFactory: - _registry: dict[str, Callable[[], type[KVConnectorBaseType]]] = {} + _registry: dict[str, Callable[[], type[KVConnectorBase]]] = {} @classmethod def register_connector(cls, name: str, module_path: str, @@ -29,28 +25,23 @@ def register_connector(cls, name: str, module_path: str, if name in cls._registry: raise ValueError(f"Connector '{name}' is already registered.") - def loader() -> type[KVConnectorBaseType]: + def loader() -> type[KVConnectorBase]: module = importlib.import_module(module_path) return getattr(module, class_name) cls._registry[name] = loader @classmethod - def create_connector_v0(cls, rank: int, local_rank: int, - config: "VllmConfig") -> KVConnectorBase: - if envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V0 Connector, " + def create_connector( + cls, + config: "VllmConfig", + role: KVConnectorRole, + ) -> KVConnectorBase: + if not envs.VLLM_USE_V1: + raise ValueError("Attempting to initialize a V1 Connector, " f"but found {envs.VLLM_USE_V1=}") - connector_cls = cls.get_connector_class(config.kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase) - return connector_cls(rank, local_rank, config) - - @classmethod - def get_connector_class( - cls, kv_transfer_config: "KVTransferConfig" - ) -> type[KVConnectorBaseType]: - """Get the connector class by name.""" + kv_transfer_config = config.kv_transfer_config connector_name = kv_transfer_config.kv_connector if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() @@ -61,21 +52,7 @@ def get_connector_class( f"Unsupported connector type: {connector_name}") connector_module = importlib.import_module(connector_module_path) connector_cls = getattr(connector_module, connector_name) - return connector_cls - - @classmethod - def create_connector_v1( - cls, - config: "VllmConfig", - role: KVConnectorRole, - ) -> KVConnectorBase_V1: - if not envs.VLLM_USE_V1: - raise ValueError("Attempting to initialize a V1 Connector, " - f"but found {envs.VLLM_USE_V1=}") - - kv_transfer_config = config.kv_transfer_config - connector_cls = cls.get_connector_class(kv_transfer_config) - assert issubclass(connector_cls, KVConnectorBase_V1) + assert issubclass(connector_cls, KVConnectorBase) logger.info("Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -92,25 +69,6 @@ def create_connector_v1( # Register various connectors here. 
# The registration should not be done in each individual file, as we want to # only load the files corresponding to the current connector. -KVConnectorFactory.register_connector( - "PyNcclConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "MooncakeConnector", - "vllm.distributed.kv_transfer.kv_connector.simple_connector", - "SimpleConnector") - -KVConnectorFactory.register_connector( - "LMCacheConnector", - "vllm.distributed.kv_transfer.kv_connector.lmcache_connector", - "LMCacheConnector") - -KVConnectorFactory.register_connector( - "MooncakeStoreConnector", - "vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector", - "MooncakeStoreConnector") KVConnectorFactory.register_connector( "SharedStorageConnector", diff --git a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py deleted file mode 100644 index 78bf3095613a..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +++ /dev/null @@ -1,99 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -LMCache KV Cache Connector for Distributed Machine Learning Inference - -The LMCacheConnector can (1) transfer KV caches between prefill vLLM worker -(KV cache producer) and decode vLLM worker (KV cache consumer) using LMCache; -(2) offload and share KV caches. -""" - -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class LMCacheConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.transfer_config = config.kv_transfer_config - self.vllm_config = config - - from lmcache.experimental.cache_engine import LMCacheEngineBuilder - from lmcache.integration.vllm.utils import ENGINE_NAME - from lmcache.integration.vllm.vllm_adapter import ( - RetrieveStatus, StoreStatus, init_lmcache_engine, - lmcache_retrieve_kv, lmcache_should_retrieve, lmcache_should_store, - lmcache_store_kv) - logger.info("Initializing LMCacheConfig under kv_transfer_config %s", - self.transfer_config) - - # TODO (Jiayi): Find model_config, parallel_config, and cache_config - self.engine = init_lmcache_engine(config.model_config, - config.parallel_config, - config.cache_config) - self.lmcache_engine_name = ENGINE_NAME - self.lmcache_engine_builder = LMCacheEngineBuilder - - self.model_config = config.model_config - self.parallel_config = config.parallel_config - self.cache_config = config.cache_config - self.lmcache_retrieve_kv = lmcache_retrieve_kv - self.lmcache_store_kv = lmcache_store_kv - self.lmcache_should_retrieve = lmcache_should_retrieve - self.lmcache_should_store = lmcache_should_store - self.store_status = StoreStatus - self.retrieve_status = RetrieveStatus - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - retrieve_status = self.lmcache_should_retrieve(model_input) - 
model_input, bypass_model_exec, hidden_or_intermediate_states =\ - self.lmcache_retrieve_kv( - model_executable, model_input, self.cache_config, kv_caches, - retrieve_status) - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - store_status = self.lmcache_should_store(model_input) - self.lmcache_store_kv( - self.model_config, - self.parallel_config, - self.cache_config, - model_executable, - model_input, - kv_caches, - store_status, - ) - - def close(self): - self.lmcache_engine_builder.destroy(self.lmcache_engine_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py b/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py deleted file mode 100644 index 94a7ce91acf1..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +++ /dev/null @@ -1,203 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -MooncakeStore Connector for Distributed Machine Learning Inference -The MooncakeStoreConnector transfers KV caches between prefill vLLM workers -(KV cache producer) and decode vLLM workers (KV cache consumer) using a -database-style KVStore. -""" -import hashlib -from typing import TYPE_CHECKING, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class MooncakeStoreConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - self.kv_transfer_config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - self.local_tp_rank = local_rank - - # Init kv_store - if self.kv_transfer_config.kv_connector == "MooncakeStoreConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_store = os.getenv('MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_store: - raise ValueError( - "To use MooncakeStoreConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_lookup_buffer.mooncake_store import ( # noqa: E501 - MooncakeStore) - logger.info( - "Initializing KVStoreConnector under kv_transfer_config %s", - self.kv_transfer_config) - self.kv_store = MooncakeStore(config) - else: - logger.error("Can not find %s", - self.kv_transfer_config.kv_connector) - - assert self.kv_store is not None - - def close(self) -> None: - """Close the buffer and release resources. - This method is responsible for cleaning up resources related to the - connector when it is no longer needed. - Raises: - NotImplementedError: This method must be implemented in subclasses. 
- """ - self.kv_store.close() - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - current_tokens = input_tokens_tensor[start_pos:end_pos] - store_key_prefix = self.tensor_hash(current_tokens) - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - kvcache_to_sent = torch.stack((keys, values), dim=0) - store_kvcache_key = f"{store_key_prefix}_{self.local_tp_rank}" - self.kv_store.put(store_kvcache_key, kvcache_to_sent) - - hidden_key = f"{store_key_prefix}_hidden_{self.local_tp_rank}" - self.kv_store.put(hidden_key, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - bypass_model_exec = True - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - hidden_or_intermediate_states_for_one_req = [] - - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - # get roi for current seq - load_key_prefix = self.tensor_hash(current_tokens) - load_kvcache_key = f"{load_key_prefix}_{self.local_tp_rank}" - remote_kv = self.kv_store.get(load_kvcache_key) - hidden_key = f"{load_key_prefix}_hidden_{self.local_tp_rank}" - hidden = self.kv_store.get(hidden_key) - - if remote_kv is None or hidden is None: - # didn't find any match. 
- bypass_model_exec = False - continue - - num_computed_tokens = current_tokens.shape[0] - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # call self.kv_store to get kv layer by layer - for layer_id in range(start_layer, end_layer): - layer = model_executable.model.layers[layer_id] - # get kvcache object - kv_cache = kv_caches[layer_id - start_layer] - - # get remote kvcache - remote_k, remote_v = remote_kv[0][layer_id], remote_kv[1][ - layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - @staticmethod - def tensor_hash(tensor: torch.Tensor) -> int: - """Calculate the hash value of the tensor.""" - tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes() - hash_object = hashlib.blake2b(tensor_bytes) - hash_hex = hash_object.hexdigest() - return int(hash_hex[:16], 16) diff --git a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py b/vllm/distributed/kv_transfer/kv_connector/simple_connector.py deleted file mode 100644 index e7c079e1f115..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector/simple_connector.py +++ /dev/null @@ -1,329 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Simple KV Cache Connector for Distributed Machine Learning Inference - -The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache -producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or -MooncakePipe. - -But the logic can be extended to support other pipe and lookup buffer. 
-""" -from typing import TYPE_CHECKING, Optional, Union - -import torch - -from vllm.config import VllmConfig -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase -from vllm.distributed.kv_transfer.kv_connector.utils import ( - model_aware_kv_ops_helper as kv_helper) -from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import ( - SimpleBuffer) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - -logger = init_logger(__name__) - - -class SimpleConnector(KVConnectorBase): - - def __init__( - self, - rank: int, - local_rank: int, - config: VllmConfig, - ): - - self.config = config.kv_transfer_config - self.kv_helper = kv_helper(config) - - if self.config.kv_connector == "PyNcclConnector": - from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import ( - PyNcclPipe) - logger.info( - "Initializing PyNcclConfig under kv_transfer_config %s", - self.config) - elif self.config.kv_connector == "MooncakeConnector": - # Check if MOONCAKE_CONFIG_PATH is set - import os - use_mooncake_distributed_pipe = os.getenv( - 'MOONCAKE_CONFIG_PATH') is not None - - if not use_mooncake_distributed_pipe: - raise ValueError( - "To use MooncakeConnector, you need to pass the ENV: " - "'MOONCAKE_CONFIG_PATH=/path/to/mooncake_config.json'.") - else: - from vllm.distributed.kv_transfer.kv_pipe.mooncake_pipe import ( # noqa: E501 - MooncakePipe) - logger.info( - "Initializing MooncakeConfig under kv_transfer_config %s", - self.config) - - self.lookup_buffer_size = self.config.kv_buffer_size - - self.producer_buffer: Optional[SimpleBuffer] = None - self.consumer_buffer: Optional[SimpleBuffer] = None - - self.producer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_data_pipe: Union[PyNcclPipe, MooncakePipe] - self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe] - - # 2 pipes for every rank in the world - port_offset_base = 2 * rank - - # In disaggregated prefill, the prefill vLLM only uses send pipe - # and the decode vLLM only uses recv pipe - if self.config.is_kv_producer: - - if self.config.kv_connector == "PyNcclConnector": - self.producer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.producer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.producer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - # We only need to initialize MooncakePipe once - self.producer_signal_pipe = self.producer_data_pipe - - self.producer_buffer = SimpleBuffer(self.producer_signal_pipe, - self.producer_data_pipe, - self.config.kv_buffer_size) - - else: - - # the current vLLM instance is KV consumer, so it needs to connect - # its recv pipe to the send pipe of KV producer - if self.config.kv_connector == "PyNcclConnector": - self.consumer_data_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base, - ) - self.consumer_signal_pipe = PyNcclPipe( - local_rank=local_rank, - config=self.config, - port_offset=port_offset_base + 1, - device="cpu", - ) - elif self.config.kv_connector == "MooncakeConnector": - self.consumer_data_pipe = MooncakePipe( - local_rank=local_rank, - config=self.config, - ) - self.consumer_signal_pipe = 
self.consumer_data_pipe - - self.consumer_buffer = SimpleBuffer( - self.consumer_signal_pipe, - self.consumer_data_pipe, - self.config.kv_buffer_size, - ) - - def select(self, input_tokens: Optional[torch.Tensor], - roi: Optional[torch.Tensor]) -> list[Optional[torch.Tensor]]: - - assert self.consumer_buffer is not None, "Please initialize the "\ - "consumer buffer before calling select." - return self.consumer_buffer.drop_select(input_tokens, roi) - - def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor, - key: torch.Tensor, value: torch.Tensor, - hidden: torch.Tensor) -> None: - - assert self.producer_buffer is not None, "Please initialize the "\ - "producer buffer before calling insert." - - self.producer_buffer.insert(input_tokens, roi, key, value, hidden) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten() - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - num_heads, head_size = self.kv_helper.get_model_args(model_executable) - - # query_lens contains new KV caches that are added to vLLM. - # so we will send them to decode instance - # FIXME(Kuntai): This assume that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You have some decode requests while using " - "SimpleConnector. Their KVCache won't be sent.") - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - - keys, values = [], [] - - for layer_id in range(start_layer, end_layer): - kv_cache = kv_caches[layer_id - start_layer] - key_cache, value_cache = self.kv_helper.get_kv_from_cache( - kv_cache, num_heads, head_size) - - current_slot_mapping = slot_mapping_flat[start_pos:end_pos] - - keys.append(key_cache[current_slot_mapping].unsqueeze(0)) - values.append(value_cache[current_slot_mapping].unsqueeze(0)) - - keys = torch.cat(keys, dim=0) - values = torch.cat(values, dim=0) - - self.insert(current_tokens, - torch.ones_like(current_tokens, - dtype=bool), keys, values, - hidden_or_intermediate_states[start_pos:end_pos]) - - logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank()) - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - # When bypass_model_exec is set to False, it means that at least for one - # request its corresponding KV cache or hidden state is missing. - # In this case we need to do prefilling to recompute missing KV cache - # and hidden states. 
- bypass_model_exec = True - - input_tokens_tensor = model_input.input_tokens - seq_lens = model_input.attn_metadata.seq_lens - num_prefill_tokens = model_input.attn_metadata.num_prefill_tokens - slot_mapping = model_input.attn_metadata.slot_mapping.flatten() - start_layer = model_executable.model.start_layer - end_layer = model_executable.model.end_layer - - hidden_or_intermediate_states_for_one_req = [] - - input_tokens_list = [] - num_computed_tokens_list = [] - start_pos_list = [] - - # enumerate different requests - # FIXME(Kuntai): This impl assumes that all requests are prefill. - for idx, slen in enumerate(seq_lens): - start_pos = sum(seq_lens[:idx]) - end_pos = start_pos + slen - - if start_pos >= num_prefill_tokens: - # This can happen during inflight batching. See: - # vllm/worker/model_runner.py::_prepare_model_input_tensors: - # - input_tokens[:num_prefill_tokens] contains prefill tokens. - # - input_tokens[num_prefill_tokens:] contains decode tokens. - logger.warning("You should set --enable_chunked_prefill=False " - "and --max_num_batched_tokens " - "should be equal to --max_seq_len_to_capture") - bypass_model_exec = False - assert start_pos == num_prefill_tokens - break - - current_tokens = input_tokens_tensor[start_pos:end_pos] - num_tokens = slen - - # collecting data for rebuilding the input - input_tokens_list.append(current_tokens) - start_pos_list.append(start_pos) - - ret = self.select(current_tokens, - torch.ones_like(current_tokens, dtype=bool)) - if ret[0] is None: - # didn't find any match. - bypass_model_exec = False - num_computed_tokens_list.append(0) - continue - - roi: torch.Tensor = ret[1] - keys: torch.Tensor = ret[2] - values: torch.Tensor = ret[3] - hidden: torch.Tensor = ret[4] - - num_computed_tokens = roi.shape[0] - num_computed_tokens_list.append(num_computed_tokens) - - # check if both KV cache and the hidden states are received - # If not, need to redo the forwarding to compute missing states - if not all([(num_computed_tokens == num_tokens), hidden is not None - ]): - bypass_model_exec = False - - # update the end position based on how many tokens are cached. - end_pos = start_pos + num_computed_tokens - - # put received KV caches into paged memory - for cur_layer in range(start_layer, end_layer): - - layer_id = cur_layer - start_layer - kv_cache = kv_caches[layer_id] - layer = model_executable.model.layers[cur_layer] - - # get remote kvcache - remote_k, remote_v = keys[layer_id], values[layer_id] - - self.kv_helper.put_kv_to_cache(model_executable, remote_k, - remote_v, layer, kv_cache, - slot_mapping, start_pos, - end_pos) - - hidden_or_intermediate_states_for_one_req.append(hidden) - - if not bypass_model_exec: - # Some of the KV cache is not retrieved - # Here we will fall back to normal model forwarding - # But optionally you can adjust model_input so that you only do - # prefilling on those tokens that are missing KV caches. 
- logger.warning( - "[rank%d]: Failed to receive all KVs and hidden " - "states, redo model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = None - - else: - logger.debug( - "[rank%d]: Successfully received all KVs and hidden " - "states, skip model forwarding.", torch.distributed.get_rank()) - hidden_or_intermediate_states = torch.cat( - hidden_or_intermediate_states_for_one_req, dim=0) - - return hidden_or_intermediate_states, bypass_model_exec, model_input - - def close(self): - self.producer_data_pipe.close() - self.consumer_data_pipe.close() - if self.config.kv_connector == "PyNcclConnector": - self.producer_signal_pipe.close() - self.consumer_signal_pipe.close() - elif self.config.kv_connector == "MooncakeConnector": - # MooncakePipe reuses data_pipe for signal_pipe, so we only have to - # close the data_pipe. - pass diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 559c233947ce..c07ad50fb148 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -13,6 +13,7 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.logger import init_logger @@ -131,22 +132,29 @@ def aggregate(self, output_rank: int = 0) -> ModelRunnerOutput: # aggregate finished_sending, finished_recving from all workers - def update_finished_set(req_ids: Optional[set[str]], + def update_finished_set(req_ids: set[str], remaining_count_dict: dict[str, int], finished_set: set[str]) -> None: - for req_id in req_ids or (): + for req_id in req_ids: remaining_count_dict[req_id] -= 1 if remaining_count_dict[req_id] == 0: finished_set.add(req_id) del remaining_count_dict[req_id] - finished_sending = set[str]() - finished_recving = set[str]() + final_kv_connector_finish_output = (KVConnectorOutput( + finished_sending=set(), finished_recving=set())) for output in outputs: - update_finished_set(output.finished_sending, - self._send_remaining_count, finished_sending) - update_finished_set(output.finished_recving, - self._recv_remaining_count, finished_recving) + kv_connector_finish_output = output.kv_connector_finish_output + if kv_connector_finish_output is None: + continue + update_finished_set( + kv_connector_finish_output.finished_sending, + self._send_remaining_count, + final_kv_connector_finish_output.finished_sending) + update_finished_set( + kv_connector_finish_output.finished_recving, + self._recv_remaining_count, + final_kv_connector_finish_output.finished_recving) # select output of the worker specified by output_rank output = outputs[output_rank] @@ -154,11 +162,9 @@ def update_finished_set(req_ids: Optional[set[str]], # set the aggregated finished_sending / finished_recving # if output.finished_sending/recving is not empty, but the other ranks # still have unfinished send/recv, we want to set the aggregated - # finished_sending/recving to None until all ranks have finished + # finished_sending/recving to empty set until all ranks have finished # send/recv - output.finished_sending = finished_sending if finished_sending else None - output.finished_recving = finished_recving if finished_recving else None - + output.kv_connector_finish_output = final_kv_connector_finish_output return output def async_aggregate(self, 
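
The aggregation rule in the `utils.py` hunks above is easy to lose in the diff noise: a request id only counts as finished once every worker has reported it, with per-request countdowns (`_send_remaining_count` / `_recv_remaining_count`) carried across steps. A minimal runnable sketch of that rule, collapsed into a single call for illustration (`aggregate_finished` is an illustrative name, not vLLM's API; the real aggregator keeps the countdowns across calls):

```python
def aggregate_finished(outputs: list[set[str]], world_size: int) -> set[str]:
    """A request id is finished only once every worker has reported it."""
    remaining: dict[str, int] = {}
    finished: set[str] = set()
    for worker_ids in outputs:
        for req_id in worker_ids:
            # Each worker decrements the request's countdown exactly once.
            remaining.setdefault(req_id, world_size)
            remaining[req_id] -= 1
            if remaining[req_id] == 0:
                finished.add(req_id)
                del remaining[req_id]
    return finished


# Worker 0 reported req1, worker 1 has not: not aggregated-finished yet.
assert aggregate_finished([{"req1"}, set()], world_size=2) == set()
# Once both workers report req1, it becomes finished cluster-wide.
assert aggregate_finished([{"req1"}, {"req1"}], world_size=2) == {"req1"}
```
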
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7a2ccb58656f..a4e4a1ab5dc7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -63,6 +63,23 @@ class KVConnectorRole(enum.Enum): WORKER = 1 +class KVConnectorOutput: + """Output of get_finished() method. + + - finished_sending: request ids that have finished sending KV + - finished_recving: request ids that have finished receiving KV + """ + + def __init__( + self, + *, + finished_sending: set[str], + finished_recving: set[str], + ): + self.finished_sending = finished_sending + self.finished_recving = finished_recving + + class KVConnectorMetadata(ABC): # noqa: B024 """ Abstract Metadata used to communicate between the @@ -201,9 +218,7 @@ def wait_for_save(self): """ pass - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, finished_req_ids: set[str]) -> KVConnectorOutput: """ Notifies worker-side connector ids of requests that have finished generating tokens on the worker. @@ -217,7 +232,8 @@ def get_finished( The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). """ - return None, None + return KVConnectorOutput(finished_sending=set(), + finished_recving=set()) # ============================== # Scheduler-side methods diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index e838ac2499c0..2d529037e927 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -7,7 +7,8 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorOutput, + KVConnectorRole) from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput @@ -87,9 +88,7 @@ def wait_for_save(self): """ self._lmcache_engine.wait_for_save() - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, finished_req_ids: set[str]) -> KVConnectorOutput: """ Notifies worker-side connector ids of requests that have finished generating tokens. @@ -101,7 +100,10 @@ def get_finished( The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). 
""" - return self._lmcache_engine.get_finished(finished_req_ids) + finished_sending, finished_recving = self._lmcache_engine.get_finished( + finished_req_ids) + return KVConnectorOutput(finished_sending=finished_sending, + finished_recving=finished_recving) # ============================== # Scheduler-side methods diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 934a03a12ee5..afbdd5718f9e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -10,7 +10,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import ( KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorOutput, + KVConnectorRole) from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput @@ -52,7 +53,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): temp_config.kv_transfer_config = KVTransferConfig( **ktc, engine_id=engine_id) self._connectors.append( - KVConnectorFactory.create_connector_v1(temp_config, role)) + KVConnectorFactory.create_connector(temp_config, role)) # A mapping from request id to the index of the connector chosen to # load the request from (if any). @@ -105,13 +106,13 @@ def wait_for_save(self): for c in self._connectors: c.wait_for_save() - def get_finished( - self, finished_req_ids: set[str] - ) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, finished_req_ids: set[str]) -> KVConnectorOutput: finished_sending: set[str] = set() finished_recving: set[str] = set() for c in self._connectors: - sending, recving = c.get_finished(finished_req_ids) + output = c.get_finished(finished_req_ids) + sending = output.finished_sending + recving = output.finished_recving if not recving and not sending: continue # Aggregate finished recving request ids. 
@@ -130,7 +131,8 @@ def get_finished( else: self._extra_async_saves[req_id] = extra_pending - 1 - return finished_sending or None, finished_recving or None + return KVConnectorOutput(finished_sending=finished_sending, + finished_recving=finished_recving) # ============================== # Scheduler-side methods diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index e7fc2b118145..b488314df2be 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -21,7 +21,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorOutput, + KVConnectorRole) from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) @@ -196,8 +197,7 @@ def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp): assert self.connector_worker is not None self.connector_worker.set_host_xfer_buffer_ops(copy_operation) - def get_finished(self, - finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + def get_finished(self, finished_req_ids: set[str]) -> KVConnectorOutput: """Get the finished recving and sending requests.""" assert self.connector_worker is not None return self.connector_worker.get_finished() @@ -1017,7 +1017,7 @@ def save_kv_to_host(self, metadata: NixlConnectorMetadata): self.copy_blocks(self.device_kv_caches, self.host_xfer_buffers, meta.local_block_ids, meta.local_block_ids, "d2h") - def get_finished(self) -> tuple[set[str], set[str]]: + def get_finished(self) -> KVConnectorOutput: """ Get requests that are done sending or recving on this specific worker. The scheduler process (via the MultiprocExecutor) will use this output @@ -1052,7 +1052,8 @@ def get_finished(self) -> tuple[set[str], set[str]]: del self._reqs_to_send[req_id] done_sending.add(req_id) - return done_sending, done_recving + return KVConnectorOutput(finished_sending=done_sending, + finished_recving=done_recving) def _get_new_notifs(self) -> set[str]: """ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 32d0e43d71af..2fdb98edbab9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -9,7 +9,8 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorOutput, + KVConnectorRole) from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import ( P2pNcclEngine) from vllm.distributed.parallel_state import get_world_group @@ -261,9 +262,8 @@ def wait_for_save(self): assert self.p2p_nccl_engine is not None self.p2p_nccl_engine.wait_for_sent() - def get_finished( - self, finished_req_ids: set[str], - **kwargs) -> tuple[Optional[set[str]], Optional[set[str]]]: + def get_finished(self, finished_req_ids: set[str], + **kwargs) -> KVConnectorOutput: """ Notifies worker-side connector ids of requests that have finished generating tokens. 
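
The p2p hunk below wraps a legacy engine call that may still return `None` for either set; the `or set()` normalization it applies is the same pattern used in the test and utils changes earlier in this diff. A standalone sketch of that adapter (`wrap_legacy_result` is an illustrative name; the `KVConnectorOutput` stand-in mirrors `v1/base.py`):

```python
from typing import Optional


class KVConnectorOutput:  # local stand-in mirroring v1/base.py in this diff
    def __init__(self, *, finished_sending: set[str],
                 finished_recving: set[str]):
        self.finished_sending = finished_sending
        self.finished_recving = finished_recving


def wrap_legacy_result(
        finished_sending: Optional[set[str]],
        finished_recving: Optional[set[str]]) -> KVConnectorOutput:
    # Normalize None (the legacy "nothing finished" value) to empty sets
    # so downstream consumers can iterate unconditionally.
    return KVConnectorOutput(finished_sending=finished_sending or set(),
                             finished_recving=finished_recving or set())


assert wrap_legacy_result(None, {"req2"}).finished_sending == set()
assert wrap_legacy_result(None, {"req2"}).finished_recving == {"req2"}
```
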
@@ -279,8 +279,12 @@ def get_finished( no_compile_layers = ( self._vllm_config.compilation_config.static_forward_context) - return self.p2p_nccl_engine.get_finished(finished_req_ids, - no_compile_layers) + finished_sending, finished_recving = self.p2p_nccl_engine.get_finished( + finished_req_ids, no_compile_layers) + return KVConnectorOutput( + finished_sending=finished_sending or set(), + finished_recving=finished_recving or set(), + ) # ============================== # Scheduler-side methods diff --git a/vllm/distributed/kv_transfer/kv_connector_agent.py b/vllm/distributed/kv_transfer/kv_connector_agent.py deleted file mode 100644 index 8633fdaf59f8..000000000000 --- a/vllm/distributed/kv_transfer/kv_connector_agent.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A centralized entrypoint to perform distributed KV cache transfer. - -This implementation is a shim wrapper on two APIs exposed by `kv_connector`: -1. `send_kv_caches_and_hidden_states` -2. `recv_kv_caches_and_hidden_states -""" -from typing import TYPE_CHECKING, Union - -if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata - from vllm.config import VllmConfig - -import torch - -from vllm.distributed.kv_transfer.kv_connector.factory import ( - KVConnectorFactory) -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors - -logger = init_logger(__name__) - - -class KVTransferAgent: - """ - A class designated for distributed KV transfer - - Target use cases: - 1. Disaggregated prefill - 2. Remote KV cache storage - """ - - def __init__( - self, - rank: int, - local_rank: int, - config: "VllmConfig", - ): - - self.config = config - - if config.kv_transfer_config is None: - raise ValueError("KVTransferConfig is not set in the VllmConfig," - " cannot initialize KVConnector.") - - assert self.config.kv_transfer_config.is_kv_transfer_instance, "KV"\ - "TransferAgent should only be used when kv_connector is set." 
- - self.connector = KVConnectorFactory.create_connector_v0( - rank, local_rank, config) - - def send_kv_caches_and_hidden_states( - self, - model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor], - hidden_or_intermediate_states: Union[torch.Tensor, - IntermediateTensors], - ) -> None: - - self.connector.send_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches, - hidden_or_intermediate_states) - - def close(self) -> None: - self.connector.close() - - def recv_kv_caches_and_hidden_states( - self, model_executable: torch.nn.Module, - model_input: "ModelInputForGPUWithSamplingMetadata", - kv_caches: list[torch.Tensor] - ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool, - "ModelInputForGPUWithSamplingMetadata"]: - - return self.connector.recv_kv_caches_and_hidden_states( - model_executable, model_input, kv_caches) diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index 60f1d5d8bca7..5e0f64fca220 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -8,7 +8,6 @@ KVConnectorFactory) from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1, KVConnectorRole) -from vllm.distributed.parallel_state import get_world_group if TYPE_CHECKING: from vllm.config import VllmConfig @@ -61,11 +60,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: if (vllm_config.kv_transfer_config.is_kv_transfer_instance and _KV_CONNECTOR_AGENT is None): if envs.VLLM_USE_V1: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v1( + _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( config=vllm_config, role=KVConnectorRole.WORKER) else: - _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector_v0( - rank=get_world_group().rank, - local_rank=get_world_group().local_rank, - config=vllm_config, - ) + raise ValueError("V0 is no longer supported") diff --git a/vllm/sequence.py b/vllm/sequence.py index fe87b52f9df1..017b6be9b5fb 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -15,6 +15,7 @@ import msgspec import torch +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput from vllm.inputs import SingletonInputs from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict @@ -1164,9 +1165,8 @@ class IntermediateTensors: """ tensors: dict[str, torch.Tensor] - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None + + kv_connector_finish_output: Optional[KVConnectorOutput] = None def __init__(self, tensors): # manually define this function, so that diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 446f98034cb8..6ee7e2703ddc 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -83,7 +83,7 @@ def __init__( assert len(self.kv_cache_config.kv_cache_groups) == 1, ( "Multiple KV cache groups are not currently supported " "with KV connectors") - self.connector = KVConnectorFactory.create_connector_v1( + self.connector = KVConnectorFactory.create_connector( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) self.kv_event_publisher = EventPublisherFactory.create( @@ -1139,9 +1139,15 @@ def _update_from_kv_xfer_finished(self, scheduler the request during the next step. """ # KV Connector:: update recv and send status from last step. 
- for req_id in (model_runner_output.finished_recving or ()): + kv_connector_finish_output = ( + model_runner_output.kv_connector_finish_output) + finished_recving = (kv_connector_finish_output.finished_recving + if kv_connector_finish_output else set()) + finished_sending = (kv_connector_finish_output.finished_sending + if kv_connector_finish_output else set()) + for req_id in finished_recving: logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (model_runner_output.finished_sending or ()): + for req_id in finished_sending: logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index f78623f571b2..5dd745469832 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -6,6 +6,8 @@ import torch +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput + class LogprobsLists(NamedTuple): @@ -104,9 +106,7 @@ class ModelRunnerOutput: # [num_reqs, hidden_size] pooler_output: list[Optional[torch.Tensor]] - # [req_ids] - finished_sending: Optional[set[str]] = None - finished_recving: Optional[set[str]] = None + kv_connector_finish_output: Optional[KVConnectorOutput] = None # req_id -> num_nans_in_logits num_nans_in_logits: Optional[dict[str, int]] = None @@ -119,6 +119,5 @@ class ModelRunnerOutput: logprobs=None, prompt_logprobs_dict={}, pooler_output=[], - finished_sending=None, - finished_recving=None, + kv_connector_finish_output=None, num_nans_in_logits=None) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 987ef22a1b7f..553ada8f3057 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -23,6 +23,7 @@ from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) +from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput from vllm.distributed.parallel_state import ( get_pp_group, get_tp_group, graph_capture, is_global_first_rank, prepare_communication_buffer_for_model) @@ -1427,8 +1428,7 @@ def _pool( hidden_states: torch.Tensor, num_scheduled_tokens: int, num_scheduled_tokens_np: np.ndarray, - finished_sending: Optional[set[str]], - finished_recving: Optional[set[str]], + kv_connector_finish_output: Optional[KVConnectorOutput], ) -> ModelRunnerOutput: assert self.input_batch.num_reqs ==\ len(self.input_batch.pooling_params), \ @@ -1463,8 +1463,7 @@ def _pool( logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, - finished_sending=finished_sending, - finished_recving=finished_recving, + kv_connector_finish_output=kv_connector_finish_output, ) @torch.inference_mode() @@ -1583,7 +1582,7 @@ def execute_model( ) self.maybe_wait_for_kv_save() - finished_sending, finished_recving = ( + kv_connector_finish_output = ( self.get_finished_kv_transfers(scheduler_output)) if self.use_aux_hidden_state_outputs: @@ -1602,9 +1601,9 @@ def execute_model( if not get_pp_group().is_last_rank: # For mid-pipeline stages, return the hidden states. 
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 0f46ed223ab8..64e971730265 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -370,10 +370,10 @@ def execute_model(
             # In case of PP with kv transfer, we need to pass through the
-            # finished_sending and finished_recving buffers.
+            # kv_connector_finish_output buffer.
             new_output = EMPTY_MODEL_RUNNER_OUTPUT
-            if output.finished_sending or output.finished_recving:
+            if output.kv_connector_finish_output:
                 new_output = copy.copy(new_output)
-                new_output.finished_sending = output.finished_sending
-                new_output.finished_recving = output.finished_recving
+                new_output.kv_connector_finish_output = (
+                    output.kv_connector_finish_output)
             output = new_output

         assert isinstance(output, ModelRunnerOutput)
diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py
index 5a3186058fcf..0b77b44b99c6 100644
--- a/vllm/v1/worker/kv_connector_model_runner_mixin.py
+++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py
@@ -9,7 +9,8 @@
 from vllm.config import VllmConfig
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
-from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.distributed.kv_transfer.kv_connector.base import (KVConnectorBase,
+                                                            KVConnectorOutput)
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.logger import init_logger
 from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
@@ -28,7 +29,7 @@ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
             kv_connector = get_kv_transfer_group()
-            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert isinstance(kv_connector, KVConnectorBase)
             assert scheduler_output.kv_connector_metadata is not None
             kv_connector.bind_connector_metadata(
                 scheduler_output.kv_connector_metadata)
@@ -46,25 +47,23 @@ def maybe_wait_for_kv_save() -> None:

     @staticmethod
     def get_finished_kv_transfers(
-        scheduler_output: "SchedulerOutput",
-    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+            scheduler_output: "SchedulerOutput", ) -> Optional[KVConnectorOutput]:
         if has_kv_transfer_group():
             return get_kv_transfer_group().get_finished(
                 scheduler_output.finished_req_ids)
-        return None, None
+        return None

     def kv_connector_no_forward(self, scheduler_output: "SchedulerOutput",
                                 vllm_config: VllmConfig) -> ModelRunnerOutput:
         # KV send/recv even if no work to do.
         with set_forward_context(None, vllm_config):
             self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
+            kv_connector_finish_output = (
                 self.get_finished_kv_transfers(scheduler_output))

-        if not finished_sending and not finished_recving:
+        if not kv_connector_finish_output:
             return EMPTY_MODEL_RUNNER_OUTPUT

         output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
+        output.kv_connector_finish_output = kv_connector_finish_output
         return output
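The mixin change also redefines the contract of the worker-side connector: get_finished() now returns a single optional KVConnectorOutput instead of the old (finished_sending, finished_recving) tuple. A hedged sketch of a connector method conforming to the new call site; the signature is inferred from get_finished_kv_transfers above, and the _drain_* helpers are hypothetical:

    # Sketch of the new get_finished() contract, inferred from its call site in
    # get_finished_kv_transfers above; the _drain_* helpers are hypothetical
    # internal bookkeeping, not vLLM API.
    def get_finished(self, finished_req_ids: set[str]) -> "KVConnectorOutput":
        return KVConnectorOutput(
            finished_sending=self._drain_finished_sends(finished_req_ids),
            finished_recving=self._drain_finished_recvs(finished_req_ids))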
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 59cbb0150570..05555414073a 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1078,7 +1078,7 @@ def execute_model(
             # should be called right after each single forward pass,
             # instead of the forwards of the entire input batch.
             self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
+            kv_connector_finish_output = (
                 self.get_finished_kv_transfers(scheduler_output))

         selected_token_ids = torch.cat(combined_selected_tokens, dim=0)
@@ -1175,8 +1175,7 @@ def concat_lists(input_lists):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
+            kv_connector_finish_output=kv_connector_finish_output,
         )

         # Check there are no new graphs compiled - all the graphs should be
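Taken together, the gpu_worker.py hunk and kv_connector_no_forward above share one pattern worth spelling out: the module-level EMPTY_MODEL_RUNNER_OUTPUT constant is shared, so it must be shallow-copied before any per-step connector state is attached. A small illustration, assuming this revision of vLLM and the KVConnectorOutput sketch earlier:

    # Illustration of the copy-before-mutate pattern used in gpu_worker.py and
    # kv_connector_no_forward above; never mutate the shared EMPTY constant.
    import copy

    from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput
    from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT

    kv_out = KVConnectorOutput(finished_sending={"req-0"}, finished_recving=set())
    output = EMPTY_MODEL_RUNNER_OUTPUT
    if kv_out:  # falsy when nothing finished (assumed __bool__, see sketch above)
        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
        output.kv_connector_finish_output = kv_out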