diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index acac3753d712..d39e1d5bedd1 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -6,6 +6,7 @@ import pytest import torch +from tests.v1.kv_connector.unit.utils import MockKVConnector from vllm.config import ( CacheConfig, ECTransferConfig, @@ -15,6 +16,7 @@ SpeculativeConfig, VllmConfig, ) +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorSchedulerOutput from vllm.multimodal.inputs import ( MultiModalFeatureSpec, MultiModalKwargsItem, @@ -31,7 +33,12 @@ KVCacheConfig, KVCacheGroupSpec, ) -from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import ( + EMPTY_MODEL_RUNNER_OUTPUT, + DraftTokenIds, + KVConnectorOutput, + ModelRunnerOutput, +) from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager @@ -1415,6 +1422,110 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 +def test_kv_connector_lock_blocks(): + """ + Test a KV connector locking (holding back from eviction) GPU blocks. 
+ """ + block_size = 16 + scheduler = create_scheduler( + enable_prefix_caching=True, + use_kv_connector=mock_kv(matched_tokens=0, is_async=False), + block_size=block_size, + ) + connector = scheduler.connector + assert isinstance(connector, MockKVConnector) + + kv_cache_manager = scheduler.kv_cache_manager + block_pool = kv_cache_manager.block_pool + free_block_queue = block_pool.free_block_queue + num_blocks = free_block_queue.num_free_blocks + + # single request with 3 blocks + 4 decoded tokens + request = create_requests( + num_requests=1, + num_tokens=3 * block_size, + max_tokens=4, + block_size=block_size, + )[0] + scheduler.add_request(request) + + # decoded token #1, no blocks locked/unlocked + scheduler_output = scheduler.schedule() + model_runner_output = make_output(scheduler) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert request.num_tokens == 3 * block_size + 1 + + # extract request block IDs + req_block_id_groups = kv_cache_manager.get_block_ids(request.request_id) + assert len(req_block_id_groups) == 1 + req_block_ids = req_block_id_groups[0] + + # assert that all request blocks have ref_cnt == 1 + req_blocks = [block_pool.blocks[block_id] for block_id in req_block_ids] + assert [block.ref_cnt for block in req_blocks] == [1, 1, 1] + + # decoded token #2, block #0 locked once, block #2 locked twice + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_lock=[req_block_ids[2], req_block_ids[0], req_block_ids[2]] + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert request.num_tokens == 3 * block_size + 2 + assert [block.ref_cnt for block in req_blocks] == [2, 1, 3] + + # decoded token #3, block #1 locked three times, block #1 unlocked once + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_lock=[req_block_ids[1], req_block_ids[1], 
req_block_ids[1]], + block_ids_to_unlock=[req_block_ids[1]], + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert request.num_tokens == 3 * block_size + 3 + assert [block.ref_cnt for block in req_blocks] == [2, 3, 3] + + # decoded token #4 (last), block #1 unlocked twice, request is freed + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_unlock=[req_block_ids[1], req_block_ids[1]] + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert not scheduler.running + assert not scheduler.waiting + assert request.num_tokens == 3 * block_size + 4 + assert [block.ref_cnt for block in req_blocks] == [1, 0, 2] + assert scheduler.has_work() + + # step with no KV connector output + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = None + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert [block.ref_cnt for block in req_blocks] == [1, 0, 2] + assert free_block_queue.num_free_blocks == num_blocks - 2 + assert not scheduler.has_finished_requests() + assert not scheduler.has_unfinished_requests() + assert scheduler.has_work() + + # block #0 unlocked once, block #2 unlocked once + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_unlock=[req_block_ids[0], req_block_ids[2]] + ) + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert [block.ref_cnt for block in req_blocks] == [0, 0, 1] + assert free_block_queue.num_free_blocks == num_blocks - 1 + assert scheduler.has_work() + + # block #2 unlocked once + scheduler_output = scheduler.schedule() + connector.kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_unlock=[req_block_ids[2]] + ) + scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT) + assert [block.ref_cnt for block in req_blocks] == 
[0, 0, 0] + assert free_block_queue.num_free_blocks == num_blocks + assert not scheduler.has_work() + + def make_output(scheduler: Scheduler): return ModelRunnerOutput( req_ids=[req.request_id for req in scheduler.running], diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 4f96ded7ec35..125a28b871dc 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -179,11 +179,11 @@ def test_engine_core(): req0.request_id = req1.request_id = "test" engine_core.add_request(*engine_core.preprocess_add_request(req0)) - while engine_core.scheduler.has_requests(): + while engine_core.scheduler.has_work(): engine_core.step_fn() engine_core.add_request(*engine_core.preprocess_add_request(req1)) - while engine_core.scheduler.has_requests(): + while engine_core.scheduler.has_work(): engine_core.step_fn() assert len(engine_core.scheduler.waiting) == 0 @@ -222,7 +222,7 @@ def _check_engine_state(): assert len(engine_core.scheduler.waiting) == 1 assert len(engine_core.scheduler.running) == 0 # Loop through until they are all done. 
- while engine_core.scheduler.has_requests(): + while engine_core.scheduler.has_work(): engine_core.step_fn() assert len(engine_core.scheduler.waiting) == 0 assert len(engine_core.scheduler.running) == 0 diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 783678e9cefd..a6fe4adadb60 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -5,13 +5,19 @@ import tempfile from pathlib import Path from typing import Any +from unittest.mock import MagicMock import pytest +from tests.v1.kv_connector.unit.utils import create_vllm_config from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory -from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1 +from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorSchedulerOutput, +) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import ( MultiConnector, @@ -21,6 +27,7 @@ NixlKVConnectorStats, ) from vllm.platforms import current_platform +from vllm.v1.kv_cache_interface import KVCacheConfig MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct" @@ -41,7 +48,14 @@ class MockConnectorStats(KVConnectorStats): class MockConnector(KVConnectorBase_V1): - """Mock connector that implements build_kv_connector_stats for testing.""" + """Mock connector for testing.""" + + def __new__(cls, *args, **kwargs): + # mock all KVConnectorBase_V1 functions + mock = MagicMock(spec_set=KVConnectorBase_V1) + # Override just build_kv_connector_stats + mock.build_kv_connector_stats = cls.build_kv_connector_stats + return mock @classmethod def build_kv_connector_stats( @@ -71,16 +85,42 @@ 
def update_state_after_alloc(self, request, blocks, num_tokens) -> None: pass -class MockCrossLayerConnector(MockConnector): - @property - def prefer_cross_layer_blocks(self) -> bool: - return True - - # Register the mock connector KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__) +@pytest.fixture +def mc() -> MultiConnector: + """MultiConnector using two mocked connectors""" + vllm_config = create_vllm_config() + + mock_connector_config = { + "kv_connector": "MockConnector", + "kv_role": "kv_both", + "kv_connector_module_path": "tests.v1.kv_connector.unit.test_multi_connector", + } + + vllm_config.kv_transfer_config = KVTransferConfig( + kv_connector="MultiConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "connectors": [mock_connector_config, mock_connector_config], + }, + ) + + kv_cache_config = KVCacheConfig( + num_blocks=0, kv_cache_tensors=[], kv_cache_groups=[] + ) + + mc = MultiConnector( + vllm_config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, + ) + + return mc + + # Helper function to compare directories recursively def _compare_directories(dir1: Path, dir2: Path) -> bool: """Compares two directories recursively for identical content.""" @@ -630,19 +670,57 @@ def test_is_empty_with_multiple_connectors(self): assert not stats.is_empty() -class TestMultiConnectorPreferCrossLayerBlocks: - def test_all_connectors_prefer_cross_layer_blocks(self): - mc = MultiConnector.__new__(MultiConnector) - mc._connectors = [ - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - ] - assert mc.prefer_cross_layer_blocks is True - - def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self): - mc = MultiConnector.__new__(MultiConnector) - mc._connectors = [ - MockCrossLayerConnector.__new__(MockCrossLayerConnector), - MockConnector.__new__(MockConnector), # default False - ] - assert mc.prefer_cross_layer_blocks is 
False +def test_multi_connector_prefer_cross_layer_blocks(mc): + mc._connectors[0].prefer_cross_layer_blocks = False + mc._connectors[1].prefer_cross_layer_blocks = True + assert mc.prefer_cross_layer_blocks is False + + mc._connectors[0].prefer_cross_layer_blocks = True + mc._connectors[1].prefer_cross_layer_blocks = True + assert mc.prefer_cross_layer_blocks is True + + +def test_multi_connector_report_to_scheduler(mc): + # both return None + mc._connectors[0].report_to_scheduler.return_value = None + mc._connectors[1].report_to_scheduler.return_value = None + output = mc.report_to_scheduler() + assert output is None + + # only first is None + kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_lock=[1, 2, 3], + block_ids_to_unlock=[4, 5, 6], + ) + mc._connectors[0].report_to_scheduler.return_value = None + mc._connectors[1].report_to_scheduler.return_value = kv_connector_scheduler_output + + output = mc.report_to_scheduler() + assert output is not None + assert output.block_ids_to_lock == [1, 2, 3] + assert output.block_ids_to_unlock == [4, 5, 6] + + # only second is None + kv_connector_scheduler_output = KVConnectorSchedulerOutput( + block_ids_to_lock=[1, 2, 3], + block_ids_to_unlock=[4, 5, 6], + ) + mc._connectors[0].report_to_scheduler.return_value = kv_connector_scheduler_output + mc._connectors[1].report_to_scheduler.return_value = None + output = mc.report_to_scheduler() + assert output is not None + assert output.block_ids_to_lock == [1, 2, 3] + assert output.block_ids_to_unlock == [4, 5, 6] + + # two outputs + kv_connector_scheduler_output2 = KVConnectorSchedulerOutput( + block_ids_to_lock=[7, 1, 8], + block_ids_to_unlock=[9, 2, 10], + ) + mc._connectors[0].report_to_scheduler.return_value = kv_connector_scheduler_output + mc._connectors[1].report_to_scheduler.return_value = kv_connector_scheduler_output2 + + output = mc.report_to_scheduler() + assert output is not None + assert output.block_ids_to_lock == [1, 2, 3, 7, 1, 8] + 
assert output.block_ids_to_unlock == [4, 5, 6, 9, 2, 10] diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index e754a09179a9..17d01bddf258 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -24,6 +24,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + KVConnectorSchedulerOutput, ) from vllm.distributed.kv_transfer.kv_connector.v1.example_connector import ( # noqa ExampleConnector, @@ -358,6 +359,7 @@ def __init__( matched_tokens=extra_config["matched_tokens"], is_async=extra_config["is_async"], ) + self.kv_connector_scheduler_output: KVConnectorSchedulerOutput | None = None def get_num_new_matched_tokens( self, @@ -402,6 +404,9 @@ def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs): def wait_for_save(self): pass + def report_to_scheduler(self) -> KVConnectorSchedulerOutput | None: + return self.kv_connector_scheduler_output + KVConnectorFactory.register_connector( "TestExampleConnector", __name__, TestExampleConnector.__name__ diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 01b606b28dff..a23e2106d8ac 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -14,6 +14,8 @@ temporary buffer alloc by the CacheManager. update_connector_output() - update KVConnector state after output is received from worker-side connectors. + report_to_scheduler() - report back to scheduler on blocks + being transferred or finished transferring. request_finished() - called once when a request is finished, with the computed kv cache blocks for the request. 
Returns whether KV cache should be freed now or if the @@ -41,6 +43,7 @@ import enum from abc import ABC, abstractmethod from collections.abc import Callable, Iterable +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional import torch @@ -144,6 +147,24 @@ class KVConnectorMetadata(ABC): # noqa: B024 pass +@dataclass +class KVConnectorSchedulerOutput: + """ + Output of scheduler-side connector back to the core scheduler. + """ + + # list of block IDs to prevent evicting from the GPU KV cache + # repetitions allowed + block_ids_to_lock: list[int] | None = None + # list of previously locked block IDs to be released + # repetitions allowed + # denote by ref_cnt(block_id) = + # # times block_id appeared in block_ids_to_lock + # - # times block_id appeared in block_ids_to_unlock + # ref_cnt must be >= 0 and the block will not be freed unless ref_cnt == 0 + block_ids_to_unlock: list[int] | None = None + + class KVConnectorBase_V1(ABC): """ Base class for KV connectors. @@ -494,6 +515,17 @@ def update_connector_output(self, connector_output: KVConnectorOutput): """ return + def report_to_scheduler(self) -> KVConnectorSchedulerOutput | None: + """ + Update scheduler on transfers being made, so that the relevant + KV blocks can be protected from eviction or freed. + + Returns: + An optional KVConnectorSchedulerOutput, + which can be used to hold / release specific GPU blocks. 
+ """ + return None + def request_finished( self, request: "Request", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 412e2c57133f..1a0492fee9f9 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -15,6 +15,7 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole, + KVConnectorSchedulerOutput, ) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -332,6 +333,34 @@ def update_connector_output(self, connector_output: KVConnectorOutput): for c in self._connectors: c.update_connector_output(connector_output) + def report_to_scheduler(self) -> KVConnectorSchedulerOutput | None: + block_ids_to_lock: list[int] | None = None + block_ids_to_unlock: list[int] | None = None + for c in self._connectors: + output = c.report_to_scheduler() + if output is None: + continue + + if output.block_ids_to_lock: + if not block_ids_to_lock: + block_ids_to_lock = output.block_ids_to_lock + else: + block_ids_to_lock.extend(output.block_ids_to_lock) + + if output.block_ids_to_unlock: + if not block_ids_to_unlock: + block_ids_to_unlock = output.block_ids_to_unlock + else: + block_ids_to_unlock.extend(output.block_ids_to_unlock) + + if not block_ids_to_lock and not block_ids_to_unlock: + return None + + return KVConnectorSchedulerOutput( + block_ids_to_lock=block_ids_to_lock, + block_ids_to_unlock=block_ids_to_unlock, + ) + def request_finished( self, request: "Request", diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index ce7e396d8a9a..42b9cfc7dbf5 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Sequence +from collections.abc import Iterable from 
typing import Any from vllm.distributed.kv_events import ( @@ -369,7 +369,7 @@ def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool: ) return True - def touch(self, blocks: Sequence[KVCacheBlock]) -> None: + def touch(self, blocks: Iterable[KVCacheBlock]) -> None: """Touch a block increases its reference count by 1, and may remove the block from the free queue. This is used when a block is hit by another request with the same prefix. @@ -395,12 +395,13 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: priority. """ # Materialize the iterable to allow multiple passes. - blocks_list = list(ordered_blocks) - for block in blocks_list: + blocks_to_free = [] + for block in ordered_blocks: block.ref_cnt -= 1 - self.free_block_queue.append_n( - [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null] - ) + assert block.ref_cnt >= 0 + if block.ref_cnt == 0 and not block.is_null: + blocks_to_free.append(block) + self.free_block_queue.append_n(blocks_to_free) def evict_blocks(self, block_ids: set[int]) -> None: """evict blocks from the prefix cache by their block IDs. diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index 92d8d929287b..d0e375845786 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -162,10 +162,17 @@ def has_finished_requests(self) -> bool: """ raise NotImplementedError - def has_requests(self) -> bool: - """Returns True if there are unfinished requests, or finished requests - not yet returned in SchedulerOutputs.""" - return self.has_unfinished_requests() or self.has_finished_requests() + def has_work(self) -> bool: + """Returns True if one of the below exist: + 1. Unfinished requests + 2. Finished requests not yet returned in SchedulerOutputs + 3. KV Connector work (e.g. 
KV blocks being transferred) + """ + return ( + self.has_unfinished_requests() + or self.has_finished_requests() + or self.has_kv_connector_work() + ) @abstractmethod def reset_prefix_cache( @@ -203,3 +210,9 @@ def shutdown(self) -> None: def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]: return None + + def has_kv_connector_work(self) -> bool: + """Returns True if there's KV Connector work + (e.g. KV blocks being transferred) + """ + return False diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 30a459386a73..61bd066ada44 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -171,6 +171,7 @@ def __init__( # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() self.failed_recving_kv_req_ids: set[str] = set() + self.pending_kv_block_transfers = 0 # Encoder-related. # Calculate encoder cache size if applicable @@ -1419,6 +1420,25 @@ def update_from_output( # KV Connector: update state for finished KV Transfers. 
if kv_connector_output: self._update_from_kv_xfer_finished(kv_connector_output) + if self.connector is not None: + kv_connector_scheduler_output = self.connector.report_to_scheduler() + if kv_connector_scheduler_output is not None: + block_pool = self.kv_cache_manager.block_pool + + block_ids_to_lock = kv_connector_scheduler_output.block_ids_to_lock + if block_ids_to_lock: + block_pool.touch( + block_pool.blocks[block_id] for block_id in block_ids_to_lock + ) + self.pending_kv_block_transfers += len(block_ids_to_lock) + + block_ids_to_unlock = kv_connector_scheduler_output.block_ids_to_unlock + if block_ids_to_unlock: + block_pool.free_blocks( + block_pool.blocks[block_id] for block_id in block_ids_to_unlock + ) + self.pending_kv_block_transfers -= len(block_ids_to_unlock) + assert self.pending_kv_block_transfers >= 0 # collect KV cache events from KV cache manager events = self.kv_cache_manager.take_events() @@ -1707,6 +1727,9 @@ def get_num_unfinished_requests(self) -> int: def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 + def has_kv_connector_work(self) -> bool: + return self.pending_kv_block_transfers > 0 + def reset_prefix_cache( self, reset_running_requests: bool = False, reset_connector: bool = False ) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index d5e75824d2e3..0eaedfca5465 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -373,9 +373,9 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]: was executed. """ - # Check for any requests remaining in the scheduler - unfinished, - # or finished and not yet removed from the batch. - if not self.scheduler.has_requests(): + # Check for any work remaining in the scheduler - unfinished requests, + # finished requests and not yet removed from the batch, or KV connector work. 
+ if not self.scheduler.has_work(): return {}, False scheduler_output = self.scheduler.schedule() future = self.model_executor.execute_model(scheduler_output, non_block=True) @@ -433,7 +433,7 @@ def step_with_batch_queue( model_executed = False deferred_scheduler_output = None - if self.scheduler.has_requests(): + if self.scheduler.has_work(): scheduler_output = self.scheduler.schedule() exec_future = self.model_executor.execute_model( scheduler_output, non_block=True @@ -971,7 +971,7 @@ def _process_input_queue(self): waited = False while ( not self.engines_running - and not self.scheduler.has_requests() + and not self.scheduler.has_work() and not self.batch_queue ): if self.input_queue.empty():