diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
index 53fe599849b6..af523a814d6c 100644
--- a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
+++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py
@@ -15,9 +15,18 @@ from unittest.mock import Mock
 import pytest
+import torch
 
+from vllm.distributed.kv_transfer.kv_connector.v1 import SupportsHMA
 from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+    SlidingWindowSpec,
+)
 from vllm.v1.request import FinishReason, Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager
 
 from .utils import (
     create_model_runner_output,
@@ -478,3 +487,110 @@ def cache_blocks_spy(req, num_tokens):
 
     # request should be in the running queue
     assert request in recompute_scheduler.running
+
+
+@pytest.fixture
+def multi_group_fail_scheduler():
+    vllm_config = create_vllm_config(kv_load_failure_policy="fail")
+    kv_cache_config = KVCacheConfig(
+        num_blocks=1000,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer1"],
+                FullAttentionSpec(
+                    block_size=16,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                ),
+            ),
+            KVCacheGroupSpec(
+                ["layer2"],
+                SlidingWindowSpec(
+                    block_size=32,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float16,
+                    sliding_window=32 * 4,
+                ),
+            ),
+        ],
+    )
+    vllm_config.cache_config.num_gpu_blocks = 1000
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+        block_size=16,
+    )
+
+
+def test_sync_fail_multi_group_invalid_blocks_evicted(multi_group_fail_scheduler):
+    """
+    Test the sync-fail case with multiple KV cache groups.
+
+    Same as test_sync_fail_invalid_blocks_evicted but for multiple KV groups.
+    """
+    scheduler = multi_group_fail_scheduler
+
+    num_prompt_tokens = 100 * scheduler.block_size
+    num_external_computed_tokens = 75 * scheduler.block_size
+
+    request = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request)
+
+    req_num_new_matched_tokens = {
+        request.request_id: num_external_computed_tokens,
+    }
+
+    class HMAMockConnector(Mock, SupportsHMA):
+        def request_finished_all_groups(self, _, __):
+            return False, None
+
+    scheduler.connector = HMAMockConnector()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, False)
+    )
+    scheduler.connector.request_finished.return_value = (False, None)
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+
+    # get allocated block IDs
+    req_block_ids = scheduler_output.scheduled_new_reqs[0].block_ids[-1]
+    req_block_ids = [block_id for block_id in req_block_ids if block_id]
+    invalid_block_id = req_block_ids[1]
+    invalid_block_ids = {invalid_block_id}
+
+    # report invalid blocks - request should fail
+    model_runner_output = create_model_runner_output(
+        [request],
+        invalid_block_ids=invalid_block_ids,
+    )
+
+    _ = scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    # verify the request was removed from the scheduler
+    assert request.request_id not in scheduler.requests
+    assert len(scheduler.running) == 0
+
+    # verify the invalid block was actually freed from the cache
+    try:
+        block_id_groups = scheduler.kv_cache_manager.get_block_ids(request.request_id)
+        # if we get here, check whether the blocks were actually freed
+        if block_id_groups is not None:
+            num_allocated_blocks = 0
+            for block_ids in block_id_groups:
+                num_allocated_blocks += sum(block_id != 0 for block_id in block_ids)
+            if num_allocated_blocks > 0:
+                pytest.fail(
+                    f"Invalid blocks still tracked for finished request! "
+                    f"Request {request.request_id} should have been freed but "
+                    f"still has {num_allocated_blocks} blocks allocated."
+                )
+        # blocks list exists but is empty - this is fine, they were freed
+    except KeyError:
+        # expected - request completely removed from tracking
+        pass
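Note on the HMAMockConnector pattern above: mixing SupportsHMA into a Mock makes the scheduler treat the connector as hybrid-memory-allocator-aware, so the per-group finish path (request_finished_all_groups) is exercised instead of the flat request_finished one. Below is a minimal sketch of the dispatch this implies; finish_request is a hypothetical helper and the isinstance check is an illustration of the idea, not vLLM's exact code.

from vllm.distributed.kv_transfer.kv_connector.v1 import SupportsHMA


def finish_request(connector, request, block_id_groups):
    # Hypothetical helper, for illustration only.
    if isinstance(connector, SupportsHMA):
        # HMA-aware connectors receive one block-ID list per KV cache group.
        return connector.request_finished_all_groups(request, block_id_groups)
    # Non-HMA connectors receive a single flat block-ID list.
    return connector.request_finished(request, block_id_groups[0])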
diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
index fcdb2869d7dc..b39f5dca92f0 100644
--- a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
+++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py
@@ -5,9 +5,17 @@ from unittest.mock import Mock
 
 import pytest
+import torch
 
 from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheConfig,
+    KVCacheGroupSpec,
+    SlidingWindowSpec,
+)
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager
 
 from .utils import (
     create_model_runner_output,
@@ -333,3 +341,164 @@ def test_async_progressive_load_failure(
     assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
     assert scheduler.failed_recving_kv_req_ids == {request.request_id}
     assert scheduler.connector.get_num_new_matched_tokens.call_count == 1
+
+
+@pytest.fixture
+def multi_group_scheduler():
+    vllm_config = create_vllm_config(kv_load_failure_policy="recompute")
+    kv_cache_config = KVCacheConfig(
+        num_blocks=1000,
+        kv_cache_tensors=[],
+        kv_cache_groups=[
+            KVCacheGroupSpec(
+                ["layer1"],
+                FullAttentionSpec(
+                    block_size=16,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float32,
+                ),
+            ),
+            KVCacheGroupSpec(
+                ["layer2"],
+                SlidingWindowSpec(
+                    block_size=32,
+                    num_kv_heads=1,
+                    head_size=1,
+                    dtype=torch.float16,
+                    sliding_window=32 * 4,
+                ),
+            ),
+        ],
+    )
+    vllm_config.cache_config.num_gpu_blocks = 1000
+    return Scheduler(
+        vllm_config=vllm_config,
+        kv_cache_config=kv_cache_config,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+        block_size=16,
+    )
+
+
+def test_async_load_failure_multiple_kv_groups(multi_group_scheduler):
+    scheduler = multi_group_scheduler
+
+    num_prompt_tokens = 100 * scheduler.block_size
+    num_external_computed_tokens = 50 * scheduler.block_size
+
+    request1 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request1)
+    request2 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request2)
+    request3 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request3)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: num_external_computed_tokens,
+        request3.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=True)
+    )
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 3
+
+    # Simulate a failure in loading one of request2's full-attention blocks
+    # and a failure in loading one of request3's sliding-window blocks.
+    req2_block_ids = scheduler.kv_cache_manager.get_block_ids(request2.request_id)[0]
+    req3_block_ids = scheduler.kv_cache_manager.get_block_ids(request3.request_id)[1]
+    invalid_block_ids = {req2_block_ids[-1], req3_block_ids[-1]}
+    model_runner_output = create_model_runner_output(
+        reqs=[],
+        finished_recving={request1.request_id},
+        invalid_block_ids=invalid_block_ids,
+        use_eos=True,
+    )
+
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    assert len(scheduler.waiting) == 3
+    for request in scheduler.waiting:
+        if request.request_id != request1.request_id:
+            assert request.num_computed_tokens == 0
+            assert request.num_external_computed_tokens == 0
+        else:
+            assert request.num_computed_tokens == num_external_computed_tokens
+            assert request.num_external_computed_tokens == num_external_computed_tokens
+    assert scheduler.failed_recving_kv_req_ids == {
+        request2.request_id,
+        request3.request_id,
+    }
+
+
+def test_sync_load_failure_multiple_kv_groups(multi_group_scheduler):
+    scheduler = multi_group_scheduler
+    scheduler.max_num_scheduled_tokens = 10000
+
+    num_prompt_tokens = 100 * scheduler.block_size
+    num_external_computed_tokens = 75 * scheduler.block_size
+    common_prefix_len = 50 * scheduler.block_size
+
+    request1 = create_request(
+        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
+    )
+    scheduler.add_request(request=request1)
+    request2 = create_request(
+        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
+    )
+    scheduler.add_request(request=request2)
+    request3 = create_request(
+        num_tokens=num_prompt_tokens, common_prefix_len=common_prefix_len
+    )
+    scheduler.add_request(request=request3)
+    request4 = create_request(num_tokens=num_prompt_tokens)
+    scheduler.add_request(request=request4)
+
+    # Mock KV connector method.
+    # req_id -> num_external_computed_tokens
+    req_num_new_matched_tokens = {
+        request1.request_id: num_external_computed_tokens,
+        request2.request_id: num_external_computed_tokens,
+        request3.request_id: num_external_computed_tokens,
+        request4.request_id: num_external_computed_tokens,
+    }
+
+    scheduler.connector = Mock()
+    scheduler.connector.get_num_new_matched_tokens.side_effect = (
+        _make_get_num_new_matched_tokens(req_num_new_matched_tokens, async_load=False)
+    )
+    scheduler.connector.take_events.return_value = ()
+
+    scheduler_output = scheduler.schedule()
+    assert scheduler.connector.get_num_new_matched_tokens.call_count == 4
+
+    # Simulate a failure in loading:
+    # - one of request2's full-attention blocks (shared block).
+    # - one of request3's sliding-window blocks (non-shared block).
+    req2_block_ids = scheduler.kv_cache_manager.get_block_ids(request2.request_id)[0]
+    req3_block_ids = scheduler.kv_cache_manager.get_block_ids(request3.request_id)[1]
+    req3_block_ids = [block_id for block_id in req3_block_ids if block_id]
+    invalid_block_ids = {req2_block_ids[1], req3_block_ids[1]}
+    model_runner_output = create_model_runner_output(
+        reqs=[request1, request2, request3, request4],
+        invalid_block_ids=invalid_block_ids,
+    )
+
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+
+    assert len(scheduler.running) == 4
+    for request in scheduler.running:
+        if request.request_id != request4.request_id:
+            assert request.num_computed_tokens == 0
+            assert request.num_external_computed_tokens == 0
+        else:
+            assert request.num_computed_tokens == num_external_computed_tokens
+            assert request.num_external_computed_tokens == num_external_computed_tokens
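Note on the indexing in the two tests above: kv_cache_manager.get_block_ids(req_id) returns one block-ID list per KV cache group, so request2 is probed with [0] (the full-attention group, block_size 16) and request3 with [1] (the sliding-window group, block_size 32). The same number of computed tokens therefore spans a different number of blocks in each group. A self-contained sketch of that arithmetic, where blocks_per_group is a hypothetical helper:

def blocks_per_group(num_tokens: int, block_sizes: list[int]) -> list[int]:
    # Ceiling division per group: a partially filled block still counts.
    return [-(-num_tokens // block_size) for block_size in block_sizes]


# 75 blocks' worth of tokens at block_size 16 (75 * 16 = 1200 tokens)
# covers 75 full-attention blocks but only 38 sliding-window blocks.
assert blocks_per_group(75 * 16, [16, 32]) == [75, 38]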
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index cb99de93b6fb..29c1ce986c5e 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -32,6 +32,7 @@
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.encoder_budget import MultiModalBudget
+from vllm.utils.math_utils import cdiv
 from vllm.v1.core.encoder_cache_manager import (
     EncoderCacheManager,
     EncoderDecoderCacheManager,
@@ -2094,8 +2095,7 @@ def _update_requests_with_invalid_blocks(
             is_affected = False
             marked_invalid_block = False
             req_id = request.request_id
-            # TODO (davidb): add support for hybrid memory allocator
-            (req_block_ids,) = self.kv_cache_manager.get_block_ids(req_id)
+
             # We iterate only over blocks that may contain externally computed
             # tokens
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
@@ -2105,15 +2105,42 @@
                 # Sync loading. num_computed_tokens includes new tokens
                 req_num_computed_tokens = request.num_cached_tokens
 
-            req_num_computed_blocks = (
-                req_num_computed_tokens + self.block_size - 1
-            ) // self.block_size
-            for idx, block_id in zip(range(req_num_computed_blocks), req_block_ids):
+            block_sizes = (
+                kv_cache_group.kv_cache_spec.block_size
+                for kv_cache_group in self.kv_cache_config.kv_cache_groups
+            )
+
+            req_block_ids = itertools.chain.from_iterable(
+                group_block_ids[: cdiv(req_num_computed_tokens, block_size)]
+                for group_block_ids, block_size in zip(
+                    self.kv_cache_manager.get_block_ids(req_id), block_sizes
+                )
+            )
+
+            for idx, block_id in enumerate(req_block_ids):
                 if block_id not in invalid_block_ids:
                     continue
 
                 is_affected = True
 
+                if len(self.kv_cache_config.kv_cache_groups) > 1:
+                    # More than one KV cache group means that not all layers
+                    # are full attention; some use sliding-window attention
+                    # or SSM. In that case the entire request has to be
+                    # recomputed.
+                    total_affected_tokens += request.num_computed_tokens
+                    request.num_computed_tokens = 0
+                    request.num_external_computed_tokens = 0
+                    marked_invalid_block = True
+                    if evict_blocks:
+                        # evict the entire request
+                        blocks_to_evict.update(
+                            itertools.chain.from_iterable(
+                                self.kv_cache_manager.get_block_ids(req_id)
+                            )
+                        )
+                    break
+
                 if block_id in marked_invalid_block_ids:
                     # This invalid block is shared with a previous request
                     # and was already marked for recomputation.
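Note on the hunk above: each KV cache group can use a different block size, so the rewritten loop truncates every group's block list to just the blocks that may hold computed tokens (ceiling division via cdiv) and then walks them as one flat stream. A standalone sketch of the same pattern with made-up block IDs; itertools and vllm.utils.math_utils.cdiv are the real imports used by the hunk:

import itertools

from vllm.utils.math_utils import cdiv  # cdiv(a, b) == ceil(a / b)

num_computed_tokens = 48
group_block_ids = [[10, 11, 12, 13], [20, 21]]  # one list per KV cache group
block_sizes = [16, 32]  # per-group block sizes, as in the test fixtures

# 48 tokens cover cdiv(48, 16) = 3 blocks of group 0 and
# cdiv(48, 32) = 2 blocks of group 1.
flat_block_ids = itertools.chain.from_iterable(
    block_ids[: cdiv(num_computed_tokens, block_size)]
    for block_ids, block_size in zip(group_block_ids, block_sizes)
)
assert list(flat_block_ids) == [10, 11, 12, 20, 21]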
@@ -2140,7 +2167,9 @@
                     request.num_external_computed_tokens -= num_affected_tokens
                 # collect invalid block and all downstream dependent blocks
                 if evict_blocks:
-                    blocks_to_evict.update(req_block_ids[idx:])
+                    blocks_to_evict.update(
+                        self.kv_cache_manager.get_block_ids(req_id)[0][idx:]
+                    )
 
         if is_affected:
             if not marked_invalid_block:
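A closing note on this last hunk: indexing get_block_ids(req_id)[0] is safe here because the multi-group case breaks out of the block loop in the first hunk, so this eviction path only runs with a single KV cache group. The rule it implements is to evict the invalid block plus every later block of the request, since later blocks were computed on top of the invalid one. A toy sketch with invented block IDs:

# Single KV cache group: one flat list of block IDs for the request.
req_block_ids = [101, 102, 103, 104, 105]

idx = 2  # position of the first invalid block
blocks_to_evict = set(req_block_ids[idx:])

# The invalid block and all downstream dependent blocks are evicted.
assert blocks_to_evict == {103, 104, 105}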