Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,17 @@ def __init__(
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError
"""
Initialize the KV connector.

Args:
rank (int): The global rank of the current process.
local_rank (int): The local rank of the current process.
config (VllmConfig): The configuration object for vLLM.
"""
self.config = config
self.rank = rank
self.local_rank = local_rank

@abstractmethod
def close(self) -> None:
Expand Down Expand Up @@ -84,6 +94,46 @@ def send_kv_caches_and_hidden_states(

raise NotImplementedError

def send_kv_caches_and_hidden_states_with_ori_input(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
model_input_before_recv: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.

This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.

Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
model_input_before_recv (ModelInputForGPUWithSamplingMetadata): The
original input metadata from vLLM before receiving data.
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.

Returns:
None

"""

self.send_kv_caches_and_hidden_states(
model_executable=model_executable,
model_input=model_input,
kv_caches=kv_caches,
hidden_or_intermediate_states=hidden_or_intermediate_states)

@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
Expand Down Expand Up @@ -123,5 +173,9 @@ def recv_kv_caches_and_hidden_states(

raise NotImplementedError

def get_config(self) -> "VllmConfig":
"""Get the vllmConfig."""
return self.config


KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
9 changes: 8 additions & 1 deletion vllm/distributed/kv_transfer/kv_connector/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def create_connector_v0(cls, rank: int, local_rank: int,

connector_cls = cls._registry[connector_name]()
assert issubclass(connector_cls, KVConnectorBase)
return connector_cls(rank, local_rank, config)
connector = connector_cls(rank, local_rank, config)
KVConnectorBase.__init__(connector, rank, local_rank, config)
return connector

@classmethod
def create_connector_v1(
Expand Down Expand Up @@ -105,6 +107,11 @@ def create_connector_v1(
"vllm.distributed.kv_transfer.kv_connector.mooncake_store_connector",
"MooncakeStoreConnector")

KVConnectorFactory.register_connector(
"MultiConnectorV0",
"vllm.distributed.kv_transfer.kv_connector.multi_connector",
"MultiConnectorV0")

KVConnectorFactory.register_connector(
"SharedStorageConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector",
Expand Down
96 changes: 96 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/multi_connector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
"""
MultiConnectorV0 - v0 implementation combining multiple KV connectors
"""
import copy
from typing import TYPE_CHECKING, Union

import torch

from vllm.config import KVTransferConfig, VllmConfig, logger
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class MultiConnectorV0(KVConnectorBase):
    """A v0 KV connector that fans out to multiple child connectors.

    Child connectors are described by the ``connectors`` entry of the parent
    ``kv_transfer_config``'s ``kv_connector_extra_config``; each entry is a
    dict of ``KVTransferConfig`` keyword arguments. Receives are delegated to
    consumer children, sends to producer children.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        """
        Build one child connector per entry in
        ``kv_connector_extra_config["connectors"]``.

        Args:
            rank (int): The global rank of the current process.
            local_rank (int): The local rank of the current process.
            config (VllmConfig): The configuration object for vLLM.
        """
        # Set self.rank / self.local_rank / self.config on the base class
        # ourselves, so get_config() works even if the factory does not
        # patch KVConnectorBase.__init__ onto this instance.
        super().__init__(rank, local_rank, config)
        self._connectors: list[KVConnectorBase] = []
        ktcs = config.kv_transfer_config.kv_connector_extra_config.get(
            "connectors", [])
        assert ktcs is not None
        for ktc in ktcs:
            # Shallow copy is enough: only kv_transfer_config is replaced;
            # every other config field is intentionally shared.
            temp_config = copy.copy(config)
            temp_config.kv_transfer_config = KVTransferConfig(**ktc)
            self._connectors.append(
                KVConnectorFactory.create_connector_v0(rank, local_rank,
                                                       temp_config))

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor]
    ) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """
        Try each consumer child in registration order.

        Returns the first child result that either bypasses model execution
        or rewrote ``model_input``; with a single child its result is
        returned unconditionally. Falls through to
        ``(None, False, model_input)`` when no child produced a usable recv.
        """
        for connector in self._connectors:
            if not connector.get_config().kv_transfer_config.is_kv_consumer:
                continue
            hidden, bypass, new_input = (
                connector.recv_kv_caches_and_hidden_states(
                    model_executable, model_input, kv_caches))
            # A lone child is authoritative; otherwise accept the result
            # only when it bypassed execution or changed the model input.
            if (len(self._connectors) == 1 or bypass
                    or new_input is not model_input):
                return hidden, bypass, new_input
        return None, False, model_input

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        # Callers must go through
        # send_kv_caches_and_hidden_states_with_ori_input so each child can
        # pick which model_input it needs.
        raise RuntimeError("Should not call this method in multi connector")

    def send_kv_caches_and_hidden_states_with_ori_input(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        model_input_before_recv: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: list[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """
        Send KV caches and hidden states through every producer child.

        A child whose extra config has ``"transfer": True`` receives the
        original (pre-recv) ``model_input_before_recv``; all other producers
        receive the possibly-rewritten ``model_input``.
        """
        for connector in self._connectors:
            kv_transfer_config = connector.get_config().kv_transfer_config
            if not kv_transfer_config.is_kv_producer:
                logger.debug("not sending to connector %s",
                             kv_transfer_config.kv_connector)
                continue
            is_transfer = kv_transfer_config.kv_connector_extra_config.get(
                "transfer", False)
            # Bug fix: pick the input per-connector via a local instead of
            # rebinding ``model_input``, which previously leaked the
            # pre-recv input to every subsequent producer once one
            # transfer-mode connector had been seen.
            input_to_send = (model_input_before_recv
                             if is_transfer else model_input)
            connector.send_kv_caches_and_hidden_states(
                model_executable, input_to_send, kv_caches,
                hidden_or_intermediate_states)
            logger.debug("sent to connector %s with mode transfer=%s",
                         kv_transfer_config.kv_connector, is_transfer)

    def close(self) -> None:
        """Close every child connector."""
        for connector in self._connectors:
            connector.close()
5 changes: 4 additions & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1815,6 +1815,7 @@ def execute_model(
# we can skip prefilling on tokens that successfully received KV caches
# NOTE: The receive operation is blocking
bypass_model_exec = False
model_input_before_recv = model_input
if self.need_recv_kv(model_input, kv_caches):
hidden_or_intermediate_states, bypass_model_exec, model_input = \
get_kv_transfer_group().recv_kv_caches_and_hidden_states(
Expand Down Expand Up @@ -1861,12 +1862,14 @@ def execute_model(
# Sending KV cache in distributed KV cache transfer setting
# NOTE: the send operation is non-blocking
if self.need_send_kv(model_input, kv_caches):
get_kv_transfer_group().send_kv_caches_and_hidden_states(
get_kv_transfer_group(
).send_kv_caches_and_hidden_states_with_ori_input(
# model_executable is used to know which layer the current
# worker is working on, so that we can send KV for only those
# layers.
model_executable,
model_input,
model_input_before_recv,
kv_caches,
hidden_or_intermediate_states,
)
Expand Down