Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 151 additions & 9 deletions tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from unittest.mock import MagicMock

import pytest

Expand All @@ -10,8 +11,16 @@
)
from tests.v1.kv_connector.unit.utils import EOS_TOKEN_ID
from vllm.distributed.kv_events import BlockRemoved, BlockStored
from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import (
OffloadingConnectorScheduler,
)
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.kv_offload.abstract import OffloadingEvent
from vllm.v1.kv_offload.abstract import (
OffloadingEvent,
OffloadingManager,
ReqContext,
get_offload_block_hash,
)
from vllm.v1.request import RequestStatus


Expand Down Expand Up @@ -105,16 +114,15 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
lambda keys, req_context: generate_store_output([])
)
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.manager.lookup.assert_called()
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
runner.manager.lookup.assert_called_once()

# single block lookup with a hit
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2)
)
Expand All @@ -126,7 +134,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool):
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output([])
)
runner.manager.lookup.return_value = 1
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run(
decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5)
)
Expand Down Expand Up @@ -210,7 +218,7 @@ def test_request_preemption(request_runner, async_scheduling: bool):

# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 3
runner.manager.prepare_store.side_effect = (
lambda keys, req_context: generate_store_output(keys)
)
Expand Down Expand Up @@ -251,7 +259,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:
# start a request to load the first block, but don't complete
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run(
decoded_tokens=[],
complete_transfers=False,
Expand All @@ -263,7 +271,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling:

# start a new request to load the same first block
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run(
decoded_tokens=[],
complete_transfers=False,
Expand Down Expand Up @@ -311,7 +319,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):
# start a request to load the first block, but don't complete
runner.scheduler.reset_prefix_cache()
runner.new_request(token_ids=[0] * offloaded_block_size)
runner.manager.lookup.return_value = 1
runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1
runner.run(
decoded_tokens=[],
complete_transfers=False,
Expand All @@ -336,3 +344,137 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool):

# assert request is deleted
assert req_id not in runner.scheduler.requests


# ---------------------------------------------------------------------------
# Unit tests for _maximal_prefix_lookup / _sliding_window_lookup
# ---------------------------------------------------------------------------


def _make_scheduler_with_lookup(
    lookup_results: dict[int, bool | None],
) -> OffloadingConnectorScheduler:
    """Build a bare OffloadingConnectorScheduler with a mocked manager.lookup.

    ``lookup_results`` maps an integer block hash to the value the mocked
    ``manager.lookup`` should return for that block; hashes not present in
    the mapping return False (miss).
    """

    def _fake_lookup(key, req_context):
        block_hash = int(get_offload_block_hash(key).decode())
        return lookup_results.get(block_hash, False)

    mock_manager = MagicMock(spec=OffloadingManager)
    mock_manager.lookup.side_effect = _fake_lookup

    # Bypass __init__ — the lookup helpers under test only read `manager`.
    scheduler = object.__new__(OffloadingConnectorScheduler)
    scheduler.manager = mock_manager
    return scheduler


_EMPTY_REQ_CTX = ReqContext()


class TestMaximalPrefixLookup:
    """Tests for OffloadingConnectorScheduler._maximal_prefix_lookup."""

    def test_all_hit(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        result = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert result == 2

    def test_all_miss(self):
        scheduler = _make_scheduler_with_lookup({})
        result = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert result == 0

    def test_partial_prefix(self):
        # Only the leading run of hits counts; block 3 misses.
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        keys = to_keys([1, 2, 3])
        assert scheduler._maximal_prefix_lookup(keys, _EMPTY_REQ_CTX) == 2

    def test_miss_then_hit(self):
        # A hit after the first miss does not extend the prefix.
        scheduler = _make_scheduler_with_lookup({2: True})
        result = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert result == 0

    def test_single_hit(self):
        scheduler = _make_scheduler_with_lookup({1: True})
        assert scheduler._maximal_prefix_lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1

    def test_empty(self):
        scheduler = _make_scheduler_with_lookup({})
        assert scheduler._maximal_prefix_lookup([], _EMPTY_REQ_CTX) == 0

    def test_none_defers(self):
        scheduler = _make_scheduler_with_lookup({1: None, 2: True})
        result = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert result is None

    def test_none_after_hit_defers(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: None})
        result = scheduler._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX)
        assert result is None

    def test_none_stops_at_miss(self):
        """None is treated as hit for iteration, but miss stops the scan."""
        scheduler = _make_scheduler_with_lookup({1: None, 2: False, 3: True})
        keys = to_keys([1, 2, 3])
        assert scheduler._maximal_prefix_lookup(keys, _EMPTY_REQ_CTX) is None
        # lookup should have been called for blocks 1 and 2 (stops at miss)
        assert scheduler.manager.lookup.call_count == 2


class TestSlidingWindowLookup:
    """Tests for OffloadingConnectorScheduler._sliding_window_lookup."""

    def test_all_hit_exact_window(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        result = scheduler._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX)
        assert result == 2

    def test_all_miss(self):
        scheduler = _make_scheduler_with_lookup({})
        keys = to_keys([1, 2, 3])
        assert scheduler._sliding_window_lookup(keys, 1, _EMPTY_REQ_CTX) == 0

    def test_window_at_end(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        keys = to_keys([1, 2, 3])
        assert scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX) == 3

    def test_window_in_middle(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        keys = to_keys([1, 2, 3, 4])
        assert scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX) == 3

    def test_no_full_window_falls_back_to_prefix(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True})
        keys = to_keys([1, 2, 3])
        assert scheduler._sliding_window_lookup(keys, 3, _EMPTY_REQ_CTX) == 2

    def test_single_block_window(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True})
        keys = to_keys([1, 2, 3])
        assert scheduler._sliding_window_lookup(keys, 1, _EMPTY_REQ_CTX) == 3

    def test_gap_resets_consecutive(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True, 4: True})
        # [1, 2, 3, 0, 4] — gap at 0 resets, window of 2 found at [2,3]
        keys = to_keys([1, 2, 3, 0, 4])
        assert scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX) == 3

    def test_window_prefers_rightmost(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: True, 4: True, 5: True})
        # two valid windows: [1,2] at positions 0-1 and [4,5] at positions 3-4
        # scans right-to-left, finds [4,5] first
        keys = to_keys([1, 2, 3, 4, 5])
        assert scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX) == 5

    def test_prefix_fallback_with_gap(self):
        scheduler = _make_scheduler_with_lookup({2: True, 3: True, 4: True, 5: True})
        # window of 4 not found contiguously (gap at 1)
        keys = to_keys([2, 1, 3, 4, 5])
        assert scheduler._sliding_window_lookup(keys, 4, _EMPTY_REQ_CTX) == 1

    def test_empty(self):
        scheduler = _make_scheduler_with_lookup({})
        assert scheduler._sliding_window_lookup([], 1, _EMPTY_REQ_CTX) == 0

    def test_none_defers(self):
        scheduler = _make_scheduler_with_lookup({1: True, 2: None})
        result = scheduler._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX)
        assert result is None

    def test_none_with_full_window_still_defers(self):
        """Even if a real window is found after a None, result is deferred."""
        # Scan right-to-left: 4(True), 3(None) resets, 2(True), 1(True) = window
        # but block 3 was None so defer_lookup is set
        scheduler = _make_scheduler_with_lookup({1: True, 2: True, 3: None, 4: True})
        keys = to_keys([1, 2, 3, 4])
        assert scheduler._sliding_window_lookup(keys, 2, _EMPTY_REQ_CTX) is None
19 changes: 12 additions & 7 deletions tests/v1/kv_connector/unit/offloading_connector/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@
from vllm.v1.structured_output import StructuredOutputManager


def to_keys(int_ids: list[int]) -> list[OffloadKey]:
return [make_offload_key(str(i).encode(), 0) for i in int_ids]
def to_key(int_hash: int) -> OffloadKey:
    """Build an OffloadKey whose block hash is *int_hash* encoded as ASCII digits."""
    encoded_hash = str(int_hash).encode()
    return make_offload_key(encoded_hash, 0)


def to_keys(int_hashes: list[int]) -> list[OffloadKey]:
    """Convert each integer hash into an OffloadKey via to_key."""
    return list(map(to_key, int_hashes))


class MockLoadStoreSpec(LoadStoreSpec):
Expand Down Expand Up @@ -116,6 +120,7 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig):
self.manager = MagicMock(spec=OffloadingManager)
self.manager.lookup.return_value = 0
self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys)
self.manager.lookup.return_value = False
self.handler = MockOffloadingHandler()

def get_manager(self) -> OffloadingManager:
Expand Down Expand Up @@ -228,14 +233,14 @@ def __init__(
self.scheduler_connector: OffloadingConnector = scheduler_connector

# extract mocked OffloadingManager of scheduler connector
connector_scheduler = scheduler_connector.connector_scheduler
assert connector_scheduler is not None
manager = connector_scheduler.manager
self.connector_scheduler = scheduler_connector.connector_scheduler
assert self.connector_scheduler is not None
manager = self.connector_scheduler.manager
assert isinstance(manager, MagicMock)
self.manager: MagicMock = manager

assert len(connector_scheduler.config.kv_group_configs) == 1
kv_group_config = connector_scheduler.config.kv_group_configs[0]
assert len(self.connector_scheduler.config.kv_group_configs) == 1
kv_group_config = self.connector_scheduler.config.kv_group_configs[0]
assert kv_group_config.gpu_block_size == gpu_block_size
assert kv_group_config.offloaded_block_size == offloaded_block_size

Expand Down
Loading
Loading