Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,33 @@ def build_kv_connector_stats(
) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None

def start_load_kv(self, forward_context, **kwargs):
pass

def wait_for_layer_load(self, layer_name):
pass

def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
pass

def wait_for_save(self):
pass

def build_connector_meta(self, scheduler_output):
return None

def get_num_new_matched_tokens(self, request, num_computed_tokens):
return (0, False)

def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
pass


class MockCrossLayerConnector(MockConnector):
@property
def prefer_cross_layer_blocks(self) -> bool:
return True


# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
Expand Down Expand Up @@ -601,3 +628,21 @@ def test_is_empty_with_multiple_connectors(self):
# One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty()


class TestMultiConnectorPreferCrossLayerBlocks:
def test_all_connectors_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
]
assert mc.prefer_cross_layer_blocks is True

def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False
16 changes: 8 additions & 8 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional

import torch

Expand Down Expand Up @@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.

Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""

prefer_cross_layer_blocks: ClassVar[bool] = False
@property
def prefer_cross_layer_blocks(self) -> bool:
"""
Indicates whether this connector prefers KV blocks that hold KV data for all
layers, which can speed up KV data transfers. Defaults to False.
"""
return False

def __init__(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
Expand Down Expand Up @@ -138,6 +138,12 @@ def __init__(
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}

@property
def prefer_cross_layer_blocks(self) -> bool:
if not self._connectors:
return False
return all(c.prefer_cross_layer_blocks for c in self._connectors)

@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
Expand All @@ -164,6 +170,13 @@ def _get_connector_classes_and_configs(
)
return ret

def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
# Register on all connectors
for c in self._connectors:
c.register_cross_layers_kv_cache(kv_cache, attn_backend)

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
for c in self._connectors:
c.register_kv_caches(kv_caches)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any, ClassVar
from typing import Any

import torch

Expand Down Expand Up @@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):


class OffloadingConnector(KVConnectorBase_V1):
prefer_cross_layer_blocks: ClassVar[bool] = True
@property
def prefer_cross_layer_blocks(self) -> bool:
return True

def __init__(
self,
Expand Down