diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py
index 2fe45242153c..01e7856f5daf 100644
--- a/tests/v1/core/test_scheduler.py
+++ b/tests/v1/core/test_scheduler.py
@@ -4140,3 +4140,54 @@ def test_eagle3_mm_encoder_cache_with_shift():
         f"shifted_end={scheduled_end_with_shift}) overlapping MM at "
         f"{start_pos}. The fix must schedule encoder inputs."
     )
+
+
+def test_preempt_resets_num_cached_tokens():
+    """Preemption must reset num_cached_tokens to -1.
+
+    If it is not reset, the sentinel guard ``if request.num_cached_tokens < 0``
+    never fires again on reschedule, leaving a stale cached-token count that
+    can exceed num_external_computed_tokens and produce a negative
+    local_cache_hit in stats.
+    """
+    BLOCK_SIZE = 4
+    MATCHED_TOKENS = BLOCK_SIZE  # connector always reports this many ext tokens
+
+    scheduler = create_scheduler(
+        enable_prefix_caching=True,
+        use_kv_connector=mock_kv(matched_tokens=MATCHED_TOKENS, is_async=False),
+        block_size=BLOCK_SIZE,
+    )
+    [request] = create_requests(
+        num_requests=1,
+        num_tokens=BLOCK_SIZE * 3,
+        max_tokens=16,
+        block_size=BLOCK_SIZE,
+    )
+    scheduler.add_request(request)
+
+    # First schedule: num_cached_tokens is initialised by the sentinel guard.
+    output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    scheduler.update_from_output(output, make_output(scheduler))
+    assert request.num_cached_tokens >= 0, (
+        "num_cached_tokens should be set after first schedule"
+    )
+
+    # Directly preempt (the scheduler pops from running before calling this).
+    scheduler.running.remove(request)
+    request.status = RequestStatus.RUNNING  # guard inside _preempt_request
+    scheduler._preempt_request(request, timestamp=0.0)
+
+    # The sentinel must be restored so the guard fires on the next schedule.
+    assert request.num_cached_tokens == -1
+
+    # Reschedule: num_cached_tokens must be recalculated.
+    output = scheduler.schedule()
+    assert len(scheduler.running) == 1
+    assert request.num_cached_tokens >= 0
+    # Invariant: local_cache_hit = num_cached_tokens - num_external_computed_tokens >= 0
+    assert request.num_cached_tokens >= request.num_external_computed_tokens, (
+        f"negative local_cache_hit: num_cached_tokens={request.num_cached_tokens}, "
+        f"num_external_computed_tokens={request.num_external_computed_tokens}"
+    )
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 4628e634475c..d8b4ae9663e5 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -924,6 +924,7 @@ def _preempt_request(self, request: Request, timestamp: float) -> None:
         self.encoder_cache_manager.free(request)
         request.status = RequestStatus.PREEMPTED
         request.num_computed_tokens = 0
+        request.num_cached_tokens = -1
         if request.spec_token_ids:
             request.spec_token_ids = []
         request.num_preemptions += 1