Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,6 @@ steps:
# this test fails consistently.
# TODO: investigate and fix
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
- pytest -v -s models/multimodal/generation/test_maverick.py

Expand Down
120 changes: 0 additions & 120 deletions tests/kv_transfer/test_disagg.py

This file was deleted.

35 changes: 25 additions & 10 deletions tests/v1/kv_connector/unit/test_output_aggreagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from concurrent.futures import Future
from typing import Optional

from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.v1.outputs import ModelRunnerOutput

Expand All @@ -12,8 +13,22 @@ class DummyModelRunnerOutput(ModelRunnerOutput):
def __init__(self,
finished_sending: Optional[set[str]] = None,
finished_recving: Optional[set[str]] = None):
self.finished_sending = finished_sending
self.finished_recving = finished_recving
self.kv_connector_finish_output = (KVConnectorOutput(
finished_sending=finished_sending or set(),
finished_recving=finished_recving or set(),
finished_loading_num_tokens={}))

@property
def finished_sending(self) -> set[str]:
if self.kv_connector_finish_output is None:
return set()
return self.kv_connector_finish_output.finished_sending

@property
def finished_recving(self) -> set[str]:
if self.kv_connector_finish_output is None:
return set()
return self.kv_connector_finish_output.finished_recving


def test_aggregate_workers_output():
Expand All @@ -27,8 +42,8 @@ def test_aggregate_workers_output():
aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert not aggregated.finished_sending
assert not aggregated.finished_recving

output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
Expand All @@ -39,7 +54,7 @@ def test_aggregate_workers_output():

assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
assert not aggregated.finished_recving

output1 = DummyModelRunnerOutput(finished_sending=None,
finished_recving=None)
Expand All @@ -49,7 +64,7 @@ def test_aggregate_workers_output():
aggregated = aggregator.aggregate([output1, output2])

assert aggregated is output1
assert aggregated.finished_sending is None
assert not aggregated.finished_sending
assert aggregated.finished_recving == {'req2'}


Expand All @@ -70,8 +85,8 @@ def test_async_aggregate_workers_output():
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending is None
assert aggregated.finished_recving is None
assert not aggregated.finished_sending
assert not aggregated.finished_recving

future1 = Future()
future2 = Future()
Expand All @@ -88,7 +103,7 @@ def test_async_aggregate_workers_output():
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending == {'req1'}
assert aggregated.finished_recving is None
assert not aggregated.finished_recving

future1 = Future()
future2 = Future()
Expand All @@ -104,5 +119,5 @@ def test_async_aggregate_workers_output():
assert result_future.done()
aggregated = result_future.result()
assert aggregated is output1
assert aggregated.finished_sending is None
assert not aggregated.finished_sending
assert aggregated.finished_recving == {'req2'}
6 changes: 4 additions & 2 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorOutput
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
Expand Down Expand Up @@ -188,8 +189,9 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
finished_sending=finished_sending,
finished_recving=finished_recving,
kv_connector_finish_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving),
)


Expand Down
143 changes: 6 additions & 137 deletions vllm/distributed/kv_transfer/kv_connector/base.py
Original file line number Diff line number Diff line change
@@ -1,142 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KVConnectorBase Class for Distributed KV Cache & Hidden State communication
"""Defines the base type for KV cache connectors."""

The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorOutput)

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
KVConnectorBase = KVConnectorBase_V1
KVConnectorBaseType = KVConnectorBase_V1

import torch

from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata


class KVConnectorBase(ABC):
"""
Abstract base class for a KV connector.

The class provides two primary abstract methods:
1. send_kv_caches_and_hidden_states(): Send KV caches and hidden states
2. recv_kv_caches_and_hidden_states(): Recv KV caches and hidden states
"""

@abstractmethod
def __init__(
self,
rank: int,
local_rank: int,
config: "VllmConfig",
):
raise NotImplementedError

@abstractmethod
def close(self) -> None:
"""Close the buffer and release resources.

This method is responsible for cleaning up resources related to the
connector when it is no longer needed.

Raises:
NotImplementedError: This method must be implemented in subclasses.
"""
raise NotImplementedError

@abstractmethod
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
"""
Send KV caches and hidden states to the connector.

This method processes the input tokens, KV caches, and
hidden/intermediate states for a given model and sends the data to the
decode instance.

Args:
model_executable (torch.nn.Module): The model executable containing
start and end layer information.
model_input (ModelInputForGPUWithSamplingMetadata): The input
metadata from vLLM.
kv_caches (list[torch.Tensor]): List of KV caches (keys and values)
for each layer.
hidden_or_intermediate_states (Union[torch.Tensor,
IntermediateTensors]):
The hidden or intermediate states associated with the tokens.

Returns:
None

"""

raise NotImplementedError

@abstractmethod
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: list[torch.Tensor]
) -> tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
"""
Receive KV caches and hidden states from the connector.

This method attempts to retrieve KV caches and hidden states for input
tokens. If all required KV caches and hidden states are received, it
will bypass model input, else it will fall back to normal vLLM model
forwarding.

Args:
model_executable (torch.nn.Module):
The model executable from vLLM modelrunner.
model_input (ModelInputForGPUWithSamplingMetadata):
The model input from vLLM modelrunner.
kv_caches (list[torch.Tensor]):
List of KV caches for each layer.

Returns:
- hidden_or_intermediate_states (torch.Tensor or
IntermediateTensors):
Concatenated hidden states if all required data is retrieved,
otherwise `None`.
- bypass_model_exec (bool):
Indicates whether the model execution can be skipped (True) or
needs to be redone (False).
- model_input (ModelInputForGPUWithSamplingMetadata):
Optionally adjusted input metadata for re-execution when
`bypass_model_exec=False`.

"""

raise NotImplementedError

@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.

Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
return None


KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
__all__ = ["KVConnectorBase", "KVConnectorBaseType", "KVConnectorOutput"]
Loading
Loading