diff --git a/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py b/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py index 26bd01b138f2..43d1fb94e709 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py +++ b/tests/v1/kv_connector/unit/offloading_connector/test_scheduler.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable +from unittest.mock import MagicMock import pytest @@ -10,8 +11,16 @@ ) from tests.v1.kv_connector.unit.utils import EOS_TOKEN_ID from vllm.distributed.kv_events import BlockRemoved, BlockStored +from vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler import ( + OffloadingConnectorScheduler, +) from vllm.v1.core.kv_cache_utils import BlockHash -from vllm.v1.kv_offload.abstract import OffloadingEvent +from vllm.v1.kv_offload.abstract import ( + OffloadingEvent, + OffloadingManager, + ReqContext, + get_offload_block_hash, +) from vllm.v1.request import RequestStatus @@ -105,8 +114,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): lambda keys, req_context: generate_store_output([]) ) runner.run(decoded_tokens=[EOS_TOKEN_ID]) - runner.manager.lookup.assert_called() - assert len(list(runner.manager.lookup.call_args.args[0])) == 1 + runner.manager.lookup.assert_called_once() # single block lookup with a hit runner.scheduler.reset_prefix_cache() @@ -114,7 +122,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): runner.manager.prepare_store.side_effect = ( lambda keys, req_context: generate_store_output([]) ) - runner.manager.lookup.return_value = 1 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 runner.run( decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(0, 1, 2) ) @@ -126,7 +134,7 @@ def test_offloading_connector(request_runner, async_scheduling: bool): 
runner.manager.prepare_store.side_effect = ( lambda keys, req_context: generate_store_output([]) ) - runner.manager.lookup.return_value = 1 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 runner.run( decoded_tokens=[EOS_TOKEN_ID], expected_loaded_gpu_block_indexes=(3, 4, 5) ) @@ -210,7 +218,7 @@ def test_request_preemption(request_runner, async_scheduling: bool): # request should now return from preemption # re-load [0, ..., 8] from the CPU and store [9, 10, 11] - runner.manager.lookup.return_value = 3 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 3 runner.manager.prepare_store.side_effect = ( lambda keys, req_context: generate_store_output(keys) ) @@ -251,7 +259,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: # start a request to load the first block, but don't complete runner.scheduler.reset_prefix_cache() runner.new_request(token_ids=[0] * offloaded_block_size) - runner.manager.lookup.return_value = 1 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 runner.run( decoded_tokens=[], complete_transfers=False, @@ -263,7 +271,7 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner, async_scheduling: # start a new request to load the same first block runner.new_request(token_ids=[0] * offloaded_block_size) - runner.manager.lookup.return_value = 1 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 runner.run( decoded_tokens=[], complete_transfers=False, @@ -311,7 +319,7 @@ def test_abort_loading_requests(request_runner, async_scheduling: bool): # start a request to load the first block, but don't complete runner.scheduler.reset_prefix_cache() runner.new_request(token_ids=[0] * offloaded_block_size) - runner.manager.lookup.return_value = 1 + runner.connector_scheduler._maximal_prefix_lookup = lambda key, req_context: 1 runner.run( decoded_tokens=[], complete_transfers=False, @@ -336,3 +344,137 @@ 
def test_abort_loading_requests(request_runner, async_scheduling: bool): # assert request is deleted assert req_id not in runner.scheduler.requests + + +# --------------------------------------------------------------------------- +# Unit tests for _maximal_prefix_lookup / _sliding_window_lookup +# --------------------------------------------------------------------------- + + +def _make_scheduler_with_lookup( + lookup_results: dict[int, bool | None], +) -> OffloadingConnectorScheduler: + """Create an OffloadingConnectorScheduler with a mocked manager.lookup.""" + manager = MagicMock(spec=OffloadingManager) + manager.lookup.side_effect = lambda key, req_context: lookup_results.get( + int(get_offload_block_hash(key).decode()), False + ) + + scheduler = object.__new__(OffloadingConnectorScheduler) + scheduler.manager = manager + return scheduler + + +_EMPTY_REQ_CTX = ReqContext() + + +class TestMaximalPrefixLookup: + def test_all_hit(self): + sched = _make_scheduler_with_lookup({1: True, 2: True}) + assert sched._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2 + + def test_all_miss(self): + sched = _make_scheduler_with_lookup({}) + assert sched._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 + + def test_partial_prefix(self): + sched = _make_scheduler_with_lookup({1: True, 2: True}) + assert sched._maximal_prefix_lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2 + + def test_miss_then_hit(self): + sched = _make_scheduler_with_lookup({2: True}) + assert sched._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 + + def test_single_hit(self): + sched = _make_scheduler_with_lookup({1: True}) + assert sched._maximal_prefix_lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1 + + def test_empty(self): + sched = _make_scheduler_with_lookup({}) + assert sched._maximal_prefix_lookup([], _EMPTY_REQ_CTX) == 0 + + def test_none_defers(self): + sched = _make_scheduler_with_lookup({1: None, 2: True}) + assert sched._maximal_prefix_lookup(to_keys([1, 2]), 
_EMPTY_REQ_CTX) is None + + def test_none_after_hit_defers(self): + sched = _make_scheduler_with_lookup({1: True, 2: None}) + assert sched._maximal_prefix_lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) is None + + def test_none_stops_at_miss(self): + """None is treated as hit for iteration, but miss stops the scan.""" + sched = _make_scheduler_with_lookup({1: None, 2: False, 3: True}) + assert sched._maximal_prefix_lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) is None + # lookup should have been called for blocks 1 and 2 (stops at miss) + assert sched.manager.lookup.call_count == 2 + + +class TestSlidingWindowLookup: + def test_all_hit_exact_window(self): + sched = _make_scheduler_with_lookup({1: True, 2: True}) + assert sched._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX) == 2 + + def test_all_miss(self): + sched = _make_scheduler_with_lookup({}) + assert sched._sliding_window_lookup(to_keys([1, 2, 3]), 1, _EMPTY_REQ_CTX) == 0 + + def test_window_at_end(self): + sched = _make_scheduler_with_lookup({2: True, 3: True}) + assert sched._sliding_window_lookup(to_keys([1, 2, 3]), 2, _EMPTY_REQ_CTX) == 3 + + def test_window_in_middle(self): + sched = _make_scheduler_with_lookup({2: True, 3: True}) + assert ( + sched._sliding_window_lookup(to_keys([1, 2, 3, 4]), 2, _EMPTY_REQ_CTX) == 3 + ) + + def test_no_full_window_falls_back_to_prefix(self): + sched = _make_scheduler_with_lookup({1: True, 2: True}) + assert sched._sliding_window_lookup(to_keys([1, 2, 3]), 3, _EMPTY_REQ_CTX) == 2 + + def test_single_block_window(self): + sched = _make_scheduler_with_lookup({2: True, 3: True}) + assert sched._sliding_window_lookup(to_keys([1, 2, 3]), 1, _EMPTY_REQ_CTX) == 3 + + def test_gap_resets_consecutive(self): + sched = _make_scheduler_with_lookup({2: True, 3: True, 4: True}) + # [1, 2, 3, 0, 4] — gap at 0 resets, window of 2 found at [2,3] + assert ( + sched._sliding_window_lookup(to_keys([1, 2, 3, 0, 4]), 2, _EMPTY_REQ_CTX) + == 3 + ) + + def 
test_window_prefers_rightmost(self): + sched = _make_scheduler_with_lookup({1: True, 2: True, 4: True, 5: True}) + # two valid windows: [1,2] at positions 0-1 and [4,5] at positions 3-4 + # scans right-to-left, finds [4,5] first + assert ( + sched._sliding_window_lookup(to_keys([1, 2, 3, 4, 5]), 2, _EMPTY_REQ_CTX) + == 5 + ) + + def test_prefix_fallback_with_gap(self): + sched = _make_scheduler_with_lookup({2: True, 3: True, 4: True, 5: True}) + # window of 4 not found contiguously (gap at 1) + assert ( + sched._sliding_window_lookup(to_keys([2, 1, 3, 4, 5]), 4, _EMPTY_REQ_CTX) + == 1 + ) + + def test_empty(self): + sched = _make_scheduler_with_lookup({}) + assert sched._sliding_window_lookup([], 1, _EMPTY_REQ_CTX) == 0 + + def test_none_defers(self): + sched = _make_scheduler_with_lookup({1: True, 2: None}) + assert sched._sliding_window_lookup(to_keys([1, 2]), 2, _EMPTY_REQ_CTX) is None + + def test_none_with_full_window_still_defers(self): + """Even if a real window is found after a None, result is deferred.""" + # Scan right-to-left: 4(True), 3(None) resets, 2(True), 1(True) = window + # but block 3 was None so defer_lookup is set + sched = _make_scheduler_with_lookup({1: True, 2: True, 3: None, 4: True}) + assert ( + sched._sliding_window_lookup(to_keys([1, 2, 3, 4]), 2, _EMPTY_REQ_CTX) + is None + ) diff --git a/tests/v1/kv_connector/unit/offloading_connector/utils.py b/tests/v1/kv_connector/unit/offloading_connector/utils.py index aaf9152a43cf..0888c0615367 100644 --- a/tests/v1/kv_connector/unit/offloading_connector/utils.py +++ b/tests/v1/kv_connector/unit/offloading_connector/utils.py @@ -56,8 +56,12 @@ from vllm.v1.structured_output import StructuredOutputManager -def to_keys(int_ids: list[int]) -> list[OffloadKey]: - return [make_offload_key(str(i).encode(), 0) for i in int_ids] +def to_key(int_hash: int) -> OffloadKey: + return make_offload_key(str(int_hash).encode(), 0) + + +def to_keys(int_hashes: list[int]) -> list[OffloadKey]: + return [to_key(i) 
for i in int_hashes] class MockLoadStoreSpec(LoadStoreSpec): @@ -116,6 +120,6 @@ def __init__(self, vllm_config: VllmConfig, kv_cache_config: KVCacheConfig): self.manager = MagicMock(spec=OffloadingManager) - self.manager.lookup.return_value = 0 + self.manager.lookup.return_value = False self.manager.prepare_load = lambda keys, req_context: MockLoadStoreSpec(keys) self.handler = MockOffloadingHandler() def get_manager(self) -> OffloadingManager: @@ -228,14 +233,14 @@ def __init__( self.scheduler_connector: OffloadingConnector = scheduler_connector # extract mocked OffloadingManager of scheduler connector - connector_scheduler = scheduler_connector.connector_scheduler - assert connector_scheduler is not None - manager = connector_scheduler.manager + self.connector_scheduler = scheduler_connector.connector_scheduler + assert self.connector_scheduler is not None + manager = self.connector_scheduler.manager assert isinstance(manager, MagicMock) self.manager: MagicMock = manager - assert len(connector_scheduler.config.kv_group_configs) == 1 - kv_group_config = connector_scheduler.config.kv_group_configs[0] + assert len(self.connector_scheduler.config.kv_group_configs) == 1 + kv_group_config = self.connector_scheduler.config.kv_group_configs[0] assert kv_group_config.gpu_block_size == gpu_block_size assert kv_group_config.offloaded_block_size == offloaded_block_size diff --git a/tests/v1/kv_offload/test_cpu_manager.py b/tests/v1/kv_offload/test_cpu_manager.py index 651ea091ae61..733f9bf519e5 100644 --- a/tests/v1/kv_offload/test_cpu_manager.py +++ b/tests/v1/kv_offload/test_cpu_manager.py @@ -35,8 +35,12 @@ class ExpectedPrepareStoreOutput: evicted_keys: list[int] -def to_keys(int_ids: list[int]) -> list[OffloadKey]: - return [make_offload_key(str(i).encode(), 0) for i in int_ids] +def to_key(int_hash: int) -> OffloadKey: + return make_offload_key(str(int_hash).encode(), 0) + + +def to_keys(int_hashes: list[int]) -> list[OffloadKey]: + return [to_key(i) for i in
int_hashes] def verify_store_output( @@ -136,7 +140,7 @@ def test_already_stored_block_not_evicted_during_prepare_store(eviction_policy): manager.complete_store(to_keys([2, 3, 4, 5])) # block 2 must still be present in the cache - assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1 + assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True def test_cpu_manager(): @@ -160,7 +164,8 @@ def test_cpu_manager(): ) # lookup [1, 2] -> not ready - assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 + assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False # no events so far assert list(cpu_manager.take_events()) == [] @@ -170,9 +175,9 @@ def test_cpu_manager(): verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) # lookup [1, 2] - assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1 - assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2 - assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2 + assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False # prepare store [2, 3, 4, 5] -> evicts [1] prepare_store_output = cpu_manager.prepare_store( @@ -196,6 +201,14 @@ def test_cpu_manager(): # complete store [2, 3, 4, 5] cpu_manager.complete_store(to_keys([2, 3, 4, 5])) + # lookup (now that we have [2, 3, 4, 5]) + assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(4), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(5), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(0), _EMPTY_REQ_CTX) is False + # prepare load [2, 3] prepare_load_output = cpu_manager.prepare_load(to_keys([2, 3]), _EMPTY_REQ_CTX) 
verify_load_output(prepare_load_output, [1, 2]) @@ -238,8 +251,8 @@ def test_cpu_manager(): cpu_manager.complete_store(to_keys([7, 9]), success=False) # assert [7] is still stored, but [9] is not - assert cpu_manager.lookup(to_keys([7]), _EMPTY_REQ_CTX) == 1 - assert cpu_manager.lookup(to_keys([9]), _EMPTY_REQ_CTX) == 0 + assert cpu_manager.lookup(to_key(7), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(9), _EMPTY_REQ_CTX) is False verify_events( cpu_manager.take_events(), @@ -284,7 +297,8 @@ def test_basic(self): ) # lookup [1, 2] -> not ready - assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 + assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False # no events so far assert list(cpu_manager.take_events()) == [] @@ -294,9 +308,9 @@ def test_basic(self): verify_events(cpu_manager.take_events(), expected_stores=({1, 2},)) # lookup [1, 2] - assert cpu_manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 1 - assert cpu_manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 2 - assert cpu_manager.lookup(to_keys([1, 2, 3]), _EMPTY_REQ_CTX) == 2 + assert cpu_manager.lookup(to_key(1), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False # blocks should be in T1 (recent) assert len(arc_policy.t1) == 2 @@ -500,7 +514,7 @@ def test_failed_store(self): cpu_manager.complete_store(to_keys([5]), success=False) # block 5 should not be in cache - assert cpu_manager.lookup(to_keys([5]), _EMPTY_REQ_CTX) == 0 + assert cpu_manager.lookup(to_key(5), _EMPTY_REQ_CTX) is False # block 5 should not be in T1 or T2 assert to_keys([5])[0] not in arc_policy.t1 assert to_keys([5])[0] not in arc_policy.t2 @@ -541,8 +555,8 @@ def test_full_scenario(self): cpu_manager.complete_store(to_keys([6])) # verify blocks 2, 3 (in T2) are still present - assert cpu_manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 1 - assert 
cpu_manager.lookup(to_keys([3]), _EMPTY_REQ_CTX) == 1 + assert cpu_manager.lookup(to_key(2), _EMPTY_REQ_CTX) is True + assert cpu_manager.lookup(to_key(3), _EMPTY_REQ_CTX) is True # verify events events = list(cpu_manager.take_events()) @@ -562,7 +576,8 @@ def test_filter_reused_manager(): ) # Lookup [1, 2] -> 1st time, added to tracker but not eligible for store yet - assert manager.lookup(to_keys([1, 2]), _EMPTY_REQ_CTX) == 0 + assert manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False + assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False # prepare store [1, 2] -> should be filtered prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX) @@ -570,7 +585,7 @@ def test_filter_reused_manager(): assert prepare_store_output.keys_to_store == [] # Lookup [1] -> 2nd time, eligible now - assert manager.lookup(to_keys([1]), _EMPTY_REQ_CTX) == 0 + assert manager.lookup(to_key(1), _EMPTY_REQ_CTX) is False # prepare store [1, 2] -> [1] should be eligible, [2] should be filtered prepare_store_output = manager.prepare_store(to_keys([1, 2]), _EMPTY_REQ_CTX) @@ -579,12 +594,13 @@ def test_filter_reused_manager(): # Lookup [3, 4] -> 1st time # (evicts [2] from tracker since max_size is 3 and tracker has [1]) - assert manager.lookup(to_keys([3, 4]), _EMPTY_REQ_CTX) == 0 + assert manager.lookup(to_key(3), _EMPTY_REQ_CTX) is False + assert manager.lookup(to_key(4), _EMPTY_REQ_CTX) is False # Verify [2] was evicted from the tracker (tracker now has: [1], [3], [4]) assert to_keys([2])[0] not in manager.counts # Lookup [2] again -> (this adds [2] back to the tracker as 1st time) - assert manager.lookup(to_keys([2]), _EMPTY_REQ_CTX) == 0 + assert manager.lookup(to_key(2), _EMPTY_REQ_CTX) is False # Verify [2] was re-added with count=1 (not eligible yet) assert manager.counts.get(to_keys([2])[0]) == 1 diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index 
c5272ea2778e..cd5a4f113dc2 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import defaultdict -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from itertools import islice from typing import Any, NamedTuple @@ -132,6 +132,49 @@ def __init__(self, spec: OffloadingSpec): self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set) self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set) + def _maximal_prefix_lookup( + self, keys: Iterable[OffloadKey], req_context: ReqContext + ) -> int | None: + """Find the length of the maximal prefix of offloaded blocks.""" + hit_count = 0 + defer_lookup = False + for key in keys: + result = self.manager.lookup(key, req_context) + if result is None: + defer_lookup = True + # continue lookup to allow manager to kick-off async lookups + # for all blocks (until a miss is detected) + result = True + if not result: + break + hit_count += 1 + return hit_count if not defer_lookup else None + + def _sliding_window_lookup( + self, + keys: Sequence[OffloadKey], + sliding_window_size: int, + req_context: ReqContext, + ) -> int | None: + """Find the maximal ending position of consecutive offloaded blocks + within a sliding window.""" + defer_lookup = False + consecutive_hits = 0 + for idx in range(len(keys) - 1, -1, -1): + result = self.manager.lookup(keys[idx], req_context) + if result is None: + defer_lookup = True + # continue lookup to allow manager to kick-off async lookups + # for all blocks (until a hit is detected) + result = False + if not result: + consecutive_hits = 0 + else: + consecutive_hits += 1 + if consecutive_hits == sliding_window_size: + return idx + sliding_window_size if not 
defer_lookup else None + return consecutive_hits if not defer_lookup else None + def get_num_new_matched_tokens( self, request: Request, num_computed_tokens: int ) -> tuple[int | None, bool]: @@ -184,9 +227,10 @@ def get_num_new_matched_tokens( return 0, False start_block_idx = num_computed_tokens // group_config.offloaded_block_size - hits = self.manager.lookup( - offload_keys[start_block_idx:], - req_status.req_context, + # Full attention relies on all previous KV cache blocks. + # Thus, we search for a maximal prefix of KV cache which are all cached. + hits = self._maximal_prefix_lookup( + offload_keys[start_block_idx:], req_status.req_context ) if hits is None: # indicates a lookup that should be tried later diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py index 39ee360e8f68..8f809ceaa08a 100644 --- a/vllm/v1/kv_offload/abstract.py +++ b/vllm/v1/kv_offload/abstract.py @@ -7,8 +7,7 @@ and their address. The class provides the following primitives: - lookup() - find the length of the maximal series of blocks, - starting from the first one, that are all offloaded. + lookup() - check whether a single block is offloaded and ready. prepare_load() - prepare given blocks to be read. The given blocks will be protected from eviction. This function returns a LoadSpec which encapsulates @@ -91,23 +90,18 @@ class OffloadingManager(ABC): @abstractmethod - def lookup( - self, - keys: Iterable[OffloadKey], - req_context: ReqContext, - ) -> int | None: + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: """ - Finds the length of the maximal series of blocks, starting from the - first one, that are all offloaded. + Checks whether a single block is offloaded and ready to be read. Args: - keys: the keys identifying the blocks to lookup. + key: the key identifying the block to lookup. req_context: per-request context (e.g. kv_transfer_params).
Returns: - An integer representing the maximal number of blocks that - are currently offloaded, or None if the lookup should be retried - later. Returning None will delay the request handling by the vLLM + True if the block is offloaded and ready, False if not, + or None if the lookup should be retried later. + Returning None will delay the request handling by the vLLM scheduler. """ pass diff --git a/vllm/v1/kv_offload/cpu/manager.py b/vllm/v1/kv_offload/cpu/manager.py index 5ae7454430f3..fcfaa919a3b3 100644 --- a/vllm/v1/kv_offload/cpu/manager.py +++ b/vllm/v1/kv_offload/cpu/manager.py @@ -84,18 +84,9 @@ def _get_load_store_spec( # --- OffloadingManager interface --- - def lookup( - self, - keys: Iterable[OffloadKey], - req_context: ReqContext, - ) -> int | None: - hit_count = 0 - for key in keys: - block = self._policy.get(key) - if block is None or not block.is_ready: - break - hit_count += 1 - return hit_count + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: + block = self._policy.get(key) + return block is not None and block.is_ready def prepare_load( self, diff --git a/vllm/v1/kv_offload/reuse_manager.py b/vllm/v1/kv_offload/reuse_manager.py index a9650e38c51b..96b8f969e758 100644 --- a/vllm/v1/kv_offload/reuse_manager.py +++ b/vllm/v1/kv_offload/reuse_manager.py @@ -27,8 +27,9 @@ class FilterReusedOffloadingManager(OffloadingManager): All methods are delegated to the *backing* manager. Two methods are intercepted: - * ``lookup`` — records each visited key in an internal LRU counter. + * ``lookup`` — records the visited key in an internal LRU + counter, then delegates to the backing manager. * ``prepare_store`` — filters out keys that have not yet crossed the threshold *before* calling the backing ``prepare_store``.
@@ -66,18 +67,16 @@ def __init__( # Intercepted methods # ------------------------------------------------------------------ - def lookup(self, keys: Iterable[OffloadKey], req_context: ReqContext) -> int | None: - """Record each key, then delegate lookup to backing manager.""" - keys = list(keys) - for key in keys: - if key in self.counts: - self.counts.move_to_end(key) - self.counts[key] += 1 - else: - if len(self.counts) >= self.max_tracker_size: - self.counts.popitem(last=False) # evict LRU - self.counts[key] = 1 - return self._backing.lookup(keys, req_context) + def lookup(self, key: OffloadKey, req_context: ReqContext) -> bool | None: + """Record the key, then delegate lookup to backing manager.""" + if key in self.counts: + self.counts.move_to_end(key) + self.counts[key] += 1 + else: + if len(self.counts) >= self.max_tracker_size: + self.counts.popitem(last=False) # evict LRU + self.counts[key] = 1 + return self._backing.lookup(key, req_context) def prepare_store( self, keys: Iterable[OffloadKey], req_context: ReqContext