Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 176 additions & 25 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,27 @@
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import pytest

from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
MultiKVConnectorStats,
MultiKVConnectorWorkerMetadata,
)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
NixlKVConnectorStats,
)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import KVConnectorOutput, KVConnectorWorkerMetadata

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

Expand All @@ -40,7 +46,14 @@ class MockConnectorStats(KVConnectorStats):


class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing."""
"""Mock connector for testing."""

def __new__(cls, *args, **kwargs):
# mock all KVConnectorBase_V1 functions
mock = MagicMock(spec_set=KVConnectorBase_V1)
# Override just build_kv_connector_stats
mock.build_kv_connector_stats = cls.build_kv_connector_stats
return mock

@classmethod
def build_kv_connector_stats(
Expand Down Expand Up @@ -70,16 +83,42 @@ def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
pass


class MockCrossLayerConnector(MockConnector):
@property
def prefer_cross_layer_blocks(self) -> bool:
return True


# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)


@pytest.fixture
def mc() -> MultiConnector:
    """Provide a worker-role MultiConnector configured with two MockConnector entries."""
    # A single config dict is listed twice below, so both child entries
    # are configured identically.
    child_config = {
        "kv_connector": "MockConnector",
        "kv_role": "kv_both",
        "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
    }

    config = create_vllm_config()
    config.kv_transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": [child_config] * 2},
    )

    # KV cache config with zero blocks and no tensors or groups.
    empty_cache_config = KVCacheConfig(
        num_blocks=0,
        kv_cache_tensors=[],
        kv_cache_groups=[],
    )

    return MultiConnector(
        vllm_config=config,
        role=KVConnectorRole.WORKER,
        kv_cache_config=empty_cache_config,
    )


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
Expand Down Expand Up @@ -715,24 +754,6 @@ def test_is_empty_with_multiple_connectors(self):
assert not stats.is_empty()


class TestMultiConnectorPreferCrossLayerBlocks:
def test_all_connectors_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
]
assert mc.prefer_cross_layer_blocks is True

def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False


def test_multi_connector_overrides_all_base_methods():
"""
Ensure MultiConnector overrides all public methods from KVConnectorBase_V1.
Expand Down Expand Up @@ -767,3 +788,133 @@ def test_multi_connector_overrides_all_base_methods():
1. Add delegation in MultiConnector (preferred)
2. Add to INHERITED_OK if the base implementation works correctly
""")


def test_multi_connector_prefer_cross_layer_blocks(mc):
    """Check prefer_cross_layer_blocks for mixed and unanimous child preferences."""
    first = mc._connectors[0]
    second = mc._connectors[1]

    # ((first child's preference, second child's preference), expected result)
    scenarios = (
        ((False, True), False),
        ((True, True), True),
    )
    for (first_pref, second_pref), expected in scenarios:
        first.prefer_cross_layer_blocks = first_pref
        second.prefer_cross_layer_blocks = second_pref
        assert mc.prefer_cross_layer_blocks is expected


def test_multi_connector_worker_metadata(mc):
class MockConnectorWorkerMetadata(KVConnectorWorkerMetadata):
def __init__(self, data: set[str]):
self.data = data

class MockConnectorWorkerMetadata0(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata0(data=self.data | other.data)

class MockConnectorWorkerMetadata1(MockConnectorWorkerMetadata):
def aggregate(
self, other: KVConnectorWorkerMetadata
) -> KVConnectorWorkerMetadata:
assert isinstance(other, MockConnectorWorkerMetadata)
return MockConnectorWorkerMetadata1(data=self.data | other.data)

# -------------------- test build_connector_worker_meta -------------------

# both connectors return None
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = None
assert mc.build_connector_worker_meta() is None

# only first connector returns None
worker_meta1a = MockConnectorWorkerMetadata1({"1a"})
mc._connectors[0].build_connector_worker_meta.return_value = None
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1a
mc_worker_meta_none_1a = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_none_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_none_1a.metadata == (None, worker_meta1a)

# only second connector returns None
worker_meta0a = MockConnectorWorkerMetadata0({"0a"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0a
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0a_none = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0a_none, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_none.metadata == (worker_meta0a, None)

# both connectors do not return None
worker_meta0b = MockConnectorWorkerMetadata0({"0b"})
worker_meta1b = MockConnectorWorkerMetadata1({"1b"})
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = worker_meta1b
mc_worker_meta_0b_1b = mc.build_connector_worker_meta()
assert isinstance(mc_worker_meta_0b_1b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0b_1b.metadata == (worker_meta0b, worker_meta1b)

# ----------------------------- test aggregate ----------------------------

# aggregate ({"0a"}, None) and (None, {"1a"}) -> ({"0a"}, {"1a"})
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing case

aggregate (None, {"1a"}) and (None, {"1b"}) -> (None, {"1a", "1b"})

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've now added:

# aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)

mc_worker_meta_0a_1a = mc_worker_meta_0a_none.aggregate(mc_worker_meta_none_1a)
assert isinstance(mc_worker_meta_0a_1a, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_1a.metadata == (worker_meta0a, worker_meta1a)

# aggregate ({"0a"}, None) and ({"0b"}, None) -> ({"0a", "0b"}, None)
mc._connectors[0].build_connector_worker_meta.return_value = worker_meta0b
mc._connectors[1].build_connector_worker_meta.return_value = None
mc_worker_meta_0b_none = mc.build_connector_worker_meta()
mc_worker_meta_0a_0b = mc_worker_meta_0a_none.aggregate(mc_worker_meta_0b_none)
assert isinstance(mc_worker_meta_0a_0b, MultiKVConnectorWorkerMetadata)
assert mc_worker_meta_0a_0b.metadata[1] is None
connector0_md = mc_worker_meta_0a_0b.metadata[0]
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert connector0_md.data == {"0a", "0b"}

# aggregate ({"0a"}, {"1a"}) and ({"0b"}, {"1b"}) -> ({"0a", "0b"}, {"1a", "1b"})
mc_worker_meta_01a_01b = mc_worker_meta_0a_1a.aggregate(mc_worker_meta_0b_1b)
assert isinstance(mc_worker_meta_01a_01b, MultiKVConnectorWorkerMetadata)
metadata = mc_worker_meta_01a_01b.metadata
assert len(metadata) == 2
connector0_md, connector1_md = metadata
assert isinstance(connector0_md, MockConnectorWorkerMetadata0)
assert isinstance(connector1_md, MockConnectorWorkerMetadata1)
assert connector0_md.data == {"0a", "0b"}
assert connector1_md.data == {"1a", "1b"}

# ---------------------- test update_connector_output ---------------------

def verify_worker_metadata(expected_metadata: MockConnectorWorkerMetadata | None):
def _verify_worker_metadata(connector_output: KVConnectorOutput):
worker_meta = connector_output.kv_connector_worker_meta
if expected_metadata is None:
assert worker_meta is None
return

assert isinstance(worker_meta, MockConnectorWorkerMetadata)
assert type(worker_meta) is type(expected_metadata)
assert expected_metadata.data == worker_meta.data

return _verify_worker_metadata

def assert_update_connector_output_called(mc: MultiConnector):
for c in mc._connectors:
c.update_connector_output.assert_called_once()
c.update_connector_output.reset_mock()

# no worker meta
kv_connector_output = KVConnectorOutput()
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(None)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(None)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)

# multi worker meta
kv_connector_output.kv_connector_worker_meta = mc_worker_meta_01a_01b
mc._connectors[0].update_connector_output.side_effect = verify_worker_metadata(
connector0_md
)
mc._connectors[1].update_connector_output.side_effect = verify_worker_metadata(
connector1_md
)
mc.update_connector_output(kv_connector_output)
assert_update_connector_output_called(mc)
assert kv_connector_output.kv_connector_worker_meta == mc_worker_meta_01a_01b
13 changes: 13 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def update_finished_set(
finished_sending = set[str]()
finished_recving = set[str]()
aggregated_kv_connector_stats = None
aggregated_kv_connector_worker_meta = None
combined_kv_cache_events = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
Expand Down Expand Up @@ -127,6 +128,17 @@ def update_finished_set(
aggregated_kv_connector_stats.aggregate(kv_connector_stats)
)

# Aggregate kv_connector_worker_meta from all workers.
if aggregated_kv_connector_worker_meta is None:
# Use the first worker's kv_connector_worker_meta as accumulator.
aggregated_kv_connector_worker_meta = kv_output.kv_connector_worker_meta
elif kv_connector_worker_meta := kv_output.kv_connector_worker_meta:
aggregated_kv_connector_worker_meta = (
aggregated_kv_connector_worker_meta.aggregate(
kv_connector_worker_meta
)
)

# Combine kv_cache_events from all workers.
if combined_kv_cache_events is None:
# Use the first worker's kv_cache events as start event list.
Expand All @@ -151,6 +163,7 @@ def update_finished_set(
finished_recving=finished_recving or None,
kv_connector_stats=aggregated_kv_connector_stats or None,
kv_cache_events=combined_kv_cache_events or None,
kv_connector_worker_meta=aggregated_kv_connector_worker_meta or None,
invalid_block_ids=invalid_block_ids,
expected_finished_count=self._expected_finished_count,
)
Expand Down
37 changes: 35 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

get_finished() - called with ids of finished requests, returns
ids of requests that have completed async sending/recving.
build_connector_worker_meta() - builds metadata to be sent
back to the scheduler-side connector
"""

import enum
Expand Down Expand Up @@ -137,13 +139,34 @@ class KVConnectorHandshakeMetadata(ABC): # noqa: B024

class KVConnectorMetadata(ABC): # noqa: B024
"""
Abstract Metadata used to communicate between the
Scheduler KVConnector and Worker KVConnector.
Abstract Metadata used to communicate
Scheduler KVConnector -> Worker KVConnector.
"""

pass


class KVConnectorWorkerMetadata(ABC):
    """
    Abstract metadata travelling in the reverse direction:
    Worker KVConnector -> Scheduler KVConnector.

    Every worker may emit its own metadata object.
    Within a single engine step, the objects returned by the workers
    are combined through the `aggregate` method below before they are
    handed to the Scheduler KVConnector.
    """

    @abstractmethod
    def aggregate(
        self, other: "KVConnectorWorkerMetadata"
    ) -> "KVConnectorWorkerMetadata":
        """
        Combine this metadata with `other`, which is another
        `KVConnectorWorkerMetadata` object.
        """
        ...


class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.
Expand Down Expand Up @@ -409,6 +432,16 @@ def get_handshake_metadata(self) -> KVConnectorHandshakeMetadata | None:
"""
return None

def build_connector_worker_meta(self) -> KVConnectorWorkerMetadata | None:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this read better as get_connector_worker_meta(self)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose this name to follow the convention already used for the other direction (build_connector_meta).

"""
Build the KVConnector worker metadata for this engine step.

Returns:
KVConnectorWorkerMetadata: the worker metadata.
None if no worker metadata is available.
"""
return None

# ==============================
# Scheduler-side methods
# ==============================
Expand Down
Loading
Loading