diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py
index d49874adc998..9ee9993926e8 100644
--- a/tests/v1/metrics/test_stats.py
+++ b/tests/v1/metrics/test_stats.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+
 from vllm.v1.engine import FinishReason
 from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats
 
@@ -209,3 +211,82 @@ def test_prompt_token_stats_full_external_transfer_recompute():
     assert stats.local_cache_hit == 0
     assert stats.external_kv_transfer == 1000
     assert stats.recomputed_tokens == 1
+
+
+def test_prompt_token_stats_negative_external_clamped():
+    """Negative num_external_computed_tokens is clamped to 0.
+
+    This can happen when KV load failures cause
+    num_external_computed_tokens to be over-subtracted in
+    _update_requests_with_invalid_blocks (GitHub issue #36533).
+    """
+    stats = PromptTokenStats()
+    stats.update_from_output(
+        num_cached_tokens=500,
+        num_external_computed_tokens=-100,
+        prompt_len=1000,
+    )
+    assert stats.external_kv_transfer == 0
+    assert stats.local_cache_hit == 500
+    assert stats.computed == 500
+    assert stats.total == 1000
+
+
+def test_prompt_token_stats_external_exceeds_cached():
+    """num_external_computed_tokens > num_cached_tokens is clamped.
+
+    This can happen when num_cached_tokens is stale after a retry
+    while num_external_computed_tokens is re-queried from the connector.
+    """
+    stats = PromptTokenStats()
+    stats.update_from_output(
+        num_cached_tokens=300,
+        num_external_computed_tokens=500,
+        prompt_len=1000,
+    )
+    # external clamped to cached (no recomputed token in this case)
+    assert stats.external_kv_transfer == 300
+    assert stats.local_cache_hit == 0
+    assert stats.computed == 700
+    assert stats.total == 1000
+
+
+def test_prompt_token_stats_negative_cached_clamped():
+    """Negative num_cached_tokens (e.g. the -1 sentinel) is clamped to 0."""
+    stats = PromptTokenStats()
+    stats.update_from_output(
+        num_cached_tokens=-1,
+        num_external_computed_tokens=0,
+        prompt_len=1000,
+    )
+    assert stats.cached_tokens == 0
+    assert stats.computed == 1000
+    assert stats.local_cache_hit == 0
+    assert stats.external_kv_transfer == 0
+
+
+@pytest.mark.parametrize(
+    "num_cached,num_external,prompt_len",
+    [
+        (0, 0, 100),
+        (50, -50, 100),  # negative external
+        (-10, 0, 100),  # negative cached
+        (50, 100, 100),  # external > cached
+        (-5, -5, 100),  # both negative
+        (99, 100, 100),  # full cache, external > cached
+    ],
+)
+def test_prompt_token_stats_all_non_negative(num_cached, num_external, prompt_len):
+    """All PromptTokenStats fields must be non-negative after update."""
+    stats = PromptTokenStats()
+    stats.update_from_output(
+        num_cached_tokens=num_cached,
+        num_external_computed_tokens=num_external,
+        prompt_len=prompt_len,
+    )
+    assert stats.computed >= 0
+    assert stats.local_cache_hit >= 0
+    assert stats.external_kv_transfer >= 0
+    assert stats.cached_tokens >= 0
+    assert stats.recomputed_tokens >= 0
+    assert stats.total >= 0
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index bf397ad681ca..b4fe0c663c67 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -922,6 +922,7 @@ def _preempt_request(self, request: Request, timestamp: float) -> None:
         self.encoder_cache_manager.free(request)
         request.status = RequestStatus.PREEMPTED
         request.num_computed_tokens = 0
+        request.num_cached_tokens = -1
         if request.spec_token_ids:
             request.spec_token_ids = []
         request.num_preemptions += 1
@@ -1987,6 +1988,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
             # There may be a local cache hit on retry.
             self.kv_cache_manager.free(request)
 
+            # Reset so num_cached_tokens is re-captured on reschedule,
+            # consistent with the (potentially changed) external token count.
+            request.num_cached_tokens = -1
             self.failed_recving_kv_req_ids.remove(request.request_id)
         else:
             # Now that the blocks are ready, actually cache them.
@@ -2123,7 +2127,10 @@ def _update_requests_with_invalid_blocks(
                     req_num_computed_tokens - request.num_computed_tokens
                 )
                 total_affected_tokens += num_affected_tokens
-                request.num_external_computed_tokens -= num_affected_tokens
+                request.num_external_computed_tokens = max(
+                    0,
+                    request.num_external_computed_tokens - num_affected_tokens,
+                )
                 # collect invalid block and all downstream dependent blocks
                 if evict_blocks:
                     blocks_to_evict.update(req_block_ids[idx:])
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 4a1e8b6f35ce..a4c774836048 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -279,6 +279,14 @@ def update_from_output(
         # needs at least one input token to run a forward pass.
         recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
 
+        # Guard against inconsistent values from KV load failures or
+        # stale num_cached_tokens after request retry/preemption.
+        num_cached_tokens = max(0, num_cached_tokens)
+        num_external_computed_tokens = max(0, num_external_computed_tokens)
+        num_external_computed_tokens = min(
+            num_external_computed_tokens, num_cached_tokens + recomputed
+        )
+
         self.computed += prompt_len - num_cached_tokens
         self.external_kv_transfer += num_external_computed_tokens
         self.local_cache_hit += (
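
Note on the stats.py guard: below is a minimal, self-contained sketch of the clamping order this patch establishes, runnable without vLLM. `PromptTokenStatsSketch` is a hypothetical pared-down stand-in (the real `PromptTokenStats` in `vllm/v1/metrics/stats.py` also tracks `total`, `cached_tokens`, and `recomputed_tokens`), and the trailing `local_cache_hit` expression is inferred from the tests above, since the diff is truncated mid-statement.

```python
# Hypothetical, pared-down model of PromptTokenStats.update_from_output(),
# illustrating the clamping order introduced by this patch.


class PromptTokenStatsSketch:
    def __init__(self) -> None:
        self.computed = 0
        self.external_kv_transfer = 0
        self.local_cache_hit = 0

    def update_from_output(
        self,
        num_cached_tokens: int,
        num_external_computed_tokens: int,
        prompt_len: int,
    ) -> None:
        # A fully cached prompt still recomputes one token, since a
        # forward pass needs at least one input token.
        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0

        # Clamp before deriving any counters: negative inputs can arrive
        # after KV load failures (over-subtraction) or via the -1 sentinel
        # set on preemption/retry.
        num_cached_tokens = max(0, num_cached_tokens)
        num_external_computed_tokens = max(0, num_external_computed_tokens)
        num_external_computed_tokens = min(
            num_external_computed_tokens, num_cached_tokens + recomputed
        )

        self.computed += prompt_len - num_cached_tokens
        self.external_kv_transfer += num_external_computed_tokens
        # Inferred from the tests: tokens that were cached or recomputed
        # but not externally transferred count as local cache hits.
        self.local_cache_hit += (
            num_cached_tokens + recomputed - num_external_computed_tokens
        )


# Mirrors test_prompt_token_stats_negative_cached_clamped: the -1
# preemption sentinel no longer yields negative counters.
stats = PromptTokenStatsSketch()
stats.update_from_output(
    num_cached_tokens=-1, num_external_computed_tokens=0, prompt_len=1000
)
assert (stats.computed, stats.local_cache_hit, stats.external_kv_transfer) == (
    1000,
    0,
    0,
)
```

Clamping `num_external_computed_tokens` last, against `num_cached_tokens + recomputed`, preserves the invariant that externally transferred tokens are a subset of the cached/recomputed tokens, which is what keeps `local_cache_hit` non-negative in `test_prompt_token_stats_external_exceeds_cached`.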