class KVConnectorBase(ABC):
    """
    Base class for v0 KV cache transfer connectors.

    A connector moves KV caches (and optionally hidden states) between a
    prefill ("producer") instance and a decode ("consumer") instance.
    Subclasses must implement ``close``, ``send_kv_caches_and_hidden_states``
    and ``recv_kv_caches_and_hidden_states``.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: "VllmConfig",
    ):
        """
        Initialize the KV connector.

        Args:
            rank (int): The global rank of the current process.
            local_rank (int): The local rank of the current process.
            config (VllmConfig): The configuration object for vLLM.
        """
        self.config = config
        self.rank = rank
        self.local_rank = local_rank

    @abstractmethod
    def close(self) -> None:
        """Close the connector and release any held resources."""
        raise NotImplementedError

    @abstractmethod
    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             "IntermediateTensors"],
    ) -> None:
        """Send KV caches and hidden states to the decode instance."""
        raise NotImplementedError

    def send_kv_caches_and_hidden_states_with_ori_input(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        model_input_before_recv: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             "IntermediateTensors"],
    ) -> None:
        """
        Send KV caches and hidden states, also exposing the original input.

        This variant additionally receives the model input as it was BEFORE
        any ``recv_kv_caches_and_hidden_states`` call mutated it. The default
        implementation deliberately ignores ``model_input_before_recv`` and
        forwards to :meth:`send_kv_caches_and_hidden_states`; connectors that
        need the pre-recv input (e.g. MultiConnectorV0) override this method.

        Args:
            model_executable (torch.nn.Module): The model executable
                containing start and end layer information.
            model_input (ModelInputForGPUWithSamplingMetadata): The input
                metadata from vLLM (possibly rewritten by a recv).
            model_input_before_recv (ModelInputForGPUWithSamplingMetadata):
                The original input metadata from vLLM before receiving data.
            kv_caches (list[torch.Tensor]): List of KV caches (keys and
                values) for each layer.
            hidden_or_intermediate_states (Union[torch.Tensor,
                IntermediateTensors]): The hidden or intermediate states
                associated with the tokens.

        Returns:
            None
        """
        self.send_kv_caches_and_hidden_states(
            model_executable=model_executable,
            model_input=model_input,
            kv_caches=kv_caches,
            hidden_or_intermediate_states=hidden_or_intermediate_states)

    @abstractmethod
    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, "IntermediateTensors"], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Receive KV caches / hidden states from the prefill instance."""
        raise NotImplementedError

    def get_config(self) -> "VllmConfig":
        """Return the VllmConfig this connector was initialized with."""
        return self.config
# SPDX-License-Identifier: Apache-2.0
"""
MultiConnectorV0 - v0 implementation combining multiple KV connectors.

Sub-connectors are declared under
``kv_transfer_config.kv_connector_extra_config["connectors"]`` as a list of
KVTransferConfig kwargs; each entry is instantiated through
KVConnectorFactory with a config whose ``kv_transfer_config`` is rebound to
that entry.
"""
import copy
from typing import TYPE_CHECKING, Union

import torch

from vllm.config import KVTransferConfig, VllmConfig, logger
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MultiConnectorV0(KVConnectorBase):
    """Multiplexes several v0 KV connectors behind a single interface.

    Receives are routed to consumer sub-connectors; sends fan out to
    producer sub-connectors.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        # Initialize base-class state (self.config/rank/local_rank) so the
        # inherited get_config() works even when this class is constructed
        # directly, not only via the factory's external re-initialization.
        super().__init__(rank, local_rank, config)
        self._connectors: list[KVConnectorBase] = []
        ktcs = config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors", [])
        assert ktcs is not None
        for ktc in ktcs:
            # Shallow copy is sufficient: we only rebind kv_transfer_config,
            # never mutate shared sub-objects.
            temp_config = copy.copy(config)
            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
            self._connectors.append(
                KVConnectorFactory.create_connector_v0(rank, local_rank,
                                                       temp_config))

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Try each consumer sub-connector in order.

        Returns the first result that either bypasses model execution or
        rewrites the model input (or any result when there is only one
        sub-connector). If no consumer produces such a result, falls back to
        ``(None, False, model_input)`` — i.e. the caller runs the model.
        """
        for connector in self._connectors:
            if connector.get_config().kv_transfer_config.is_kv_consumer:
                hidden, bypass, new_input = (
                    connector.recv_kv_caches_and_hidden_states(
                        model_executable, model_input, kv_caches))
                # If bypass or model_input changed, return immediately.
                if len(self._connectors
                       ) == 1 or bypass or new_input is not model_input:
                    return hidden, bypass, new_input
        return None, False, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        # Callers must use send_kv_caches_and_hidden_states_with_ori_input,
        # which knows which model input each sub-connector needs.
        raise RuntimeError("Should not call this method in multi connector")

    def send_kv_caches_and_hidden_states_with_ori_input(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        model_input_before_recv: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Fan the send out to every producer sub-connector.

        Sub-connectors flagged with ``kv_connector_extra_config["transfer"]``
        receive the ORIGINAL (pre-recv) model input; the others receive the
        (possibly rewritten) current model input.
        """
        for connector in self._connectors:
            kv_transfer_config = connector.get_config().kv_transfer_config
            if kv_transfer_config.is_kv_producer:
                is_transfer = kv_transfer_config.kv_connector_extra_config.get(
                    "transfer", False)
                if is_transfer:
                    # NOTE: rebinding the local name also affects later
                    # iterations — preserved from the original logic.
                    model_input = model_input_before_recv
                connector.send_kv_caches_and_hidden_states(
                    model_executable, model_input, kv_caches,
                    hidden_or_intermediate_states)
                logger.debug(
                    "sent to connector %s with mode transfer=%s",
                    connector.get_config().kv_transfer_config.kv_connector,
                    is_transfer)
            else:
                logger.debug(
                    "not sending to connector %s",
                    connector.get_config().kv_transfer_config.kv_connector)

    def close(self) -> None:
        """Close every sub-connector."""
        for connector in self._connectors:
            connector.close()
).send_kv_caches_and_hidden_states_with_ori_input( # model_executable is used to know which layer the current # worker is working on, so that we can send KV for only those # layers. model_executable, model_input, + model_input_before_recv, kv_caches, hidden_or_intermediate_states, )