Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import torch

from tests.v1.kv_connector.unit.utils import MockKVConnector
from vllm.config import (
CacheConfig,
ECTransferConfig,
Expand All @@ -15,6 +16,7 @@
SpeculativeConfig,
VllmConfig,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorSchedulerOutput
from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalKwargsItem,
Expand All @@ -31,7 +33,12 @@
KVCacheConfig,
KVCacheGroupSpec,
)
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
DraftTokenIds,
KVConnectorOutput,
ModelRunnerOutput,
)
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

Expand Down Expand Up @@ -1415,6 +1422,110 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role):
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1


def test_kv_connector_lock_blocks():
"""
Test a KV connector locking (holding back from eviction) GPU blocks.
"""
block_size = 16
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=mock_kv(matched_tokens=0, is_async=False),
block_size=block_size,
)
connector = scheduler.connector
assert isinstance(connector, MockKVConnector)

kv_cache_manager = scheduler.kv_cache_manager
block_pool = kv_cache_manager.block_pool
free_block_queue = block_pool.free_block_queue
num_blocks = free_block_queue.num_free_blocks

# single request with 3 blocks + 4 decoded tokens
request = create_requests(
num_requests=1,
num_tokens=3 * block_size,
max_tokens=4,
block_size=block_size,
)[0]
scheduler.add_request(request)

# decoded token #1, no blocks locked/unlocked
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.num_tokens == 3 * block_size + 1

# extract request block IDs
req_block_id_groups = kv_cache_manager.get_block_ids(request.request_id)
assert len(req_block_id_groups) == 1
req_block_ids = req_block_id_groups[0]

# assert that all request blocks have ref_cnt == 1
req_blocks = [block_pool.blocks[block_id] for block_id in req_block_ids]
assert [block.ref_cnt for block in req_blocks] == [1, 1, 1]

# decoded token #2, block #0 locked once, block #2 locked twice
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_lock=[req_block_ids[2], req_block_ids[0], req_block_ids[2]]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.num_tokens == 3 * block_size + 2
assert [block.ref_cnt for block in req_blocks] == [2, 1, 3]

# decoded token #3, block #1 locked three times, block #1 unlocked once
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_lock=[req_block_ids[1], req_block_ids[1], req_block_ids[1]],
block_ids_to_unlock=[req_block_ids[1]],
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert request.num_tokens == 3 * block_size + 3
assert [block.ref_cnt for block in req_blocks] == [2, 3, 3]

# decoded token #4 (last), block #2 unlocked twice, request is freed
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_unlock=[req_block_ids[1], req_block_ids[1]]
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert not scheduler.running
assert not scheduler.waiting
assert request.num_tokens == 3 * block_size + 4
assert [block.ref_cnt for block in req_blocks] == [1, 0, 2]
assert scheduler.has_work()

# step with no KV connector output
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = None
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert [block.ref_cnt for block in req_blocks] == [1, 0, 2]
assert free_block_queue.num_free_blocks == num_blocks - 2
assert not scheduler.has_finished_requests()
assert not scheduler.has_unfinished_requests()
assert scheduler.has_work()

# block #0 unlocked once, block #2 unlocked once
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_unlock=[req_block_ids[0], req_block_ids[2]]
)
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert [block.ref_cnt for block in req_blocks] == [0, 0, 1]
assert free_block_queue.num_free_blocks == num_blocks - 1
assert scheduler.has_work()

# block #2 unlocked once
scheduler_output = scheduler.schedule()
connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_unlock=[req_block_ids[2]]
)
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert [block.ref_cnt for block in req_blocks] == [0, 0, 0]
assert free_block_queue.num_free_blocks == num_blocks
assert not scheduler.has_work()


def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
Expand Down
6 changes: 3 additions & 3 deletions tests/v1/engine/test_engine_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,11 @@ def test_engine_core():
req0.request_id = req1.request_id = "test"
engine_core.add_request(*engine_core.preprocess_add_request(req0))

while engine_core.scheduler.has_requests():
while engine_core.scheduler.has_work():
engine_core.step_fn()

engine_core.add_request(*engine_core.preprocess_add_request(req1))
while engine_core.scheduler.has_requests():
while engine_core.scheduler.has_work():
engine_core.step_fn()

assert len(engine_core.scheduler.waiting) == 0
Expand Down Expand Up @@ -222,7 +222,7 @@ def _check_engine_state():
assert len(engine_core.scheduler.waiting) == 1
assert len(engine_core.scheduler.running) == 0
# Loop through until they are all done.
while engine_core.scheduler.has_requests():
while engine_core.scheduler.has_work():
engine_core.step_fn()
assert len(engine_core.scheduler.waiting) == 0
assert len(engine_core.scheduler.running) == 0
Expand Down
126 changes: 102 additions & 24 deletions tests/v1/kv_connector/unit/test_multi_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock

import pytest

from tests.v1.kv_connector.unit.utils import create_vllm_config
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorSchedulerOutput,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
MultiConnector,
Expand All @@ -21,6 +27,7 @@
NixlKVConnectorStats,
)
from vllm.platforms import current_platform
from vllm.v1.kv_cache_interface import KVCacheConfig

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"

Expand All @@ -41,7 +48,14 @@ class MockConnectorStats(KVConnectorStats):


class MockConnector(KVConnectorBase_V1):
"""Mock connector that implements build_kv_connector_stats for testing."""
"""Mock connector for testing."""

def __new__(cls, *args, **kwargs):
# mock all KVConnectorBase_V1 functions
mock = MagicMock(spec_set=KVConnectorBase_V1)
# Override just build_kv_connector_stats
mock.build_kv_connector_stats = cls.build_kv_connector_stats
return mock

@classmethod
def build_kv_connector_stats(
Expand Down Expand Up @@ -71,16 +85,42 @@ def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
pass


class MockCrossLayerConnector(MockConnector):
@property
def prefer_cross_layer_blocks(self) -> bool:
return True


# Register the mock connector
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)


@pytest.fixture
def mc() -> MultiConnector:
"""MultiConnector using two mocked connectors"""
vllm_config = create_vllm_config()

mock_connector_config = {
"kv_connector": "MockConnector",
"kv_role": "kv_both",
"kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector",
}

vllm_config.kv_transfer_config = KVTransferConfig(
kv_connector="MultiConnector",
kv_role="kv_both",
kv_connector_extra_config={
"connectors": [mock_connector_config, mock_connector_config],
},
)

kv_cache_config = KVCacheConfig(
num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[]
)

mc = MultiConnector(
vllm_config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)

return mc


# Helper function to compare directories recursively
def _compare_directories(dir1: Path, dir2: Path) -> bool:
"""Compares two directories recursively for identical content."""
Expand Down Expand Up @@ -630,19 +670,57 @@ def test_is_empty_with_multiple_connectors(self):
assert not stats.is_empty()


class TestMultiConnectorPreferCrossLayerBlocks:
def test_all_connectors_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
]
assert mc.prefer_cross_layer_blocks is True

def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
mc = MultiConnector.__new__(MultiConnector)
mc._connectors = [
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
MockConnector.__new__(MockConnector), # default False
]
assert mc.prefer_cross_layer_blocks is False
def test_multi_connector_prefer_cross_layer_blocks(mc):
mc._connectors[0].prefer_cross_layer_blocks = False
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is False

mc._connectors[0].prefer_cross_layer_blocks = True
mc._connectors[1].prefer_cross_layer_blocks = True
assert mc.prefer_cross_layer_blocks is True


def test_multi_connector_report_to_scheduler(mc):
# both return None
mc._connectors[0].report_to_scheduler.return_value = None
mc._connectors[1].report_to_scheduler.return_value = None
output = mc.report_to_scheduler()
assert output is None

# only first is None
kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_lock=[1, 2, 3],
block_ids_to_unlock=[4, 5, 6],
)
mc._connectors[0].report_to_scheduler.return_value = None
mc._connectors[1].report_to_scheduler.return_value = kv_connector_scheduler_output

output = mc.report_to_scheduler()
assert output is not None
assert output.block_ids_to_lock == [1, 2, 3]
assert output.block_ids_to_unlock == [4, 5, 6]

# only second is None
kv_connector_scheduler_output = KVConnectorSchedulerOutput(
block_ids_to_lock=[1, 2, 3],
block_ids_to_unlock=[4, 5, 6],
)
mc._connectors[0].report_to_scheduler.return_value = kv_connector_scheduler_output
mc._connectors[1].report_to_scheduler.return_value = None
output = mc.report_to_scheduler()
assert output is not None
assert output.block_ids_to_lock == [1, 2, 3]
assert output.block_ids_to_unlock == [4, 5, 6]

# two outputs
kv_connector_scheduler_output2 = KVConnectorSchedulerOutput(
block_ids_to_lock=[7, 1, 8],
block_ids_to_unlock=[9, 2, 10],
)
mc._connectors[0].report_to_scheduler.return_value = kv_connector_scheduler_output
mc._connectors[1].report_to_scheduler.return_value = kv_connector_scheduler_output2

output = mc.report_to_scheduler()
assert output is not None
assert output.block_ids_to_lock == [1, 2, 3, 7, 1, 8]
assert output.block_ids_to_unlock == [4, 5, 6, 9, 2, 10]
5 changes: 5 additions & 0 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
KVConnectorBase_V1,
KVConnectorMetadata,
KVConnectorRole,
KVConnectorSchedulerOutput,
)
from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa
ExampleConnector,
Expand Down Expand Up @@ -358,6 +359,7 @@ def __init__(
matched_tokens=extra_config["matched_tokens"],
is_async=extra_config["is_async"],
)
self.kv_connector_scheduler_output: KVConnectorSchedulerOutput | None = None

def get_num_new_matched_tokens(
self,
Expand Down Expand Up @@ -402,6 +404,9 @@ def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
def wait_for_save(self):
pass

def report_to_scheduler(self) -> KVConnectorSchedulerOutput | None:
return self.kv_connector_scheduler_output


KVConnectorFactory.register_connector(
"TestExampleConnector", __name__, TestExampleConnector.__name__
Expand Down
Loading
Loading