81 changes: 81 additions & 0 deletions tests/v1/metrics/test_stats.py
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest

from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats

@@ -209,3 +211,82 @@ def test_prompt_token_stats_full_external_transfer_recompute():
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 1000
assert stats.recomputed_tokens == 1


def test_prompt_token_stats_negative_external_clamped():
"""Negative num_external_computed_tokens is clamped to 0.

This can happen when KV load failures cause
num_external_computed_tokens to be over-subtracted in
_update_requests_with_invalid_blocks (GitHub issue #36533).
"""
stats = PromptTokenStats()
stats.update_from_output(
num_cached_tokens=500,
num_external_computed_tokens=-100,
prompt_len=1000,
)
assert stats.external_kv_transfer == 0
assert stats.local_cache_hit == 500
assert stats.computed == 500
assert stats.total == 1000


def test_prompt_token_stats_external_exceeds_cached():
"""num_external_computed_tokens > num_cached_tokens is clamped.

    This can happen when num_cached_tokens is stale after a retry
    while num_external_computed_tokens is re-queried from the connector.
"""
stats = PromptTokenStats()
stats.update_from_output(
num_cached_tokens=300,
num_external_computed_tokens=500,
prompt_len=1000,
)
# external clamped to cached (no recomputed token in this case)
assert stats.external_kv_transfer == 300
assert stats.local_cache_hit == 0
assert stats.computed == 700
assert stats.total == 1000


def test_prompt_token_stats_negative_cached_clamped():
"""Negative num_cached_tokens (e.g. sentinel -1) is clamped to 0."""
stats = PromptTokenStats()
stats.update_from_output(
num_cached_tokens=-1,
num_external_computed_tokens=0,
prompt_len=1000,
)
assert stats.cached_tokens == 0
assert stats.computed == 1000
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 0


@pytest.mark.parametrize(
"num_cached,num_external,prompt_len",
[
(0, 0, 100),
(50, -50, 100), # negative external
(-10, 0, 100), # negative cached
(50, 100, 100), # external > cached
(-5, -5, 100), # both negative
(99, 100, 100), # full cache, external > cached
],
)
def test_prompt_token_stats_all_non_negative(num_cached, num_external, prompt_len):
"""All PromptTokenStats fields must be non-negative after update."""
stats = PromptTokenStats()
stats.update_from_output(
num_cached_tokens=num_cached,
num_external_computed_tokens=num_external,
prompt_len=prompt_len,
)
assert stats.computed >= 0
assert stats.local_cache_hit >= 0
assert stats.external_kv_transfer >= 0
assert stats.cached_tokens >= 0
assert stats.recomputed_tokens >= 0
assert stats.total >= 0
9 changes: 8 additions & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -922,6 +922,7 @@ def _preempt_request(self, request: Request, timestamp: float) -> None:
self.encoder_cache_manager.free(request)
request.status = RequestStatus.PREEMPTED
request.num_computed_tokens = 0
request.num_cached_tokens = -1
if request.spec_token_ids:
request.spec_token_ids = []
request.num_preemptions += 1
@@ -1987,6 +1988,9 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool:
# There may be a local cache hit on retry.
self.kv_cache_manager.free(request)

# Reset so num_cached_tokens is re-captured on reschedule,
# consistent with the (potentially changed) external token count.
request.num_cached_tokens = -1
self.failed_recving_kv_req_ids.remove(request.request_id)
else:
# Now that the blocks are ready, actually cache them.
@@ -2123,7 +2127,10 @@ def _update_requests_with_invalid_blocks(
req_num_computed_tokens - request.num_computed_tokens
)
total_affected_tokens += num_affected_tokens
request.num_external_computed_tokens -= num_affected_tokens
request.num_external_computed_tokens = max(
0,
request.num_external_computed_tokens - num_affected_tokens,
)
# collect invalid block and all downstream dependent blocks
if evict_blocks:
blocks_to_evict.update(req_block_ids[idx:])
8 changes: 8 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -279,6 +279,14 @@ def update_from_output(
# needs at least one input token to run a forward pass.
recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0

# Guard against inconsistent values from KV load failures or
# stale num_cached_tokens after request retry/preemption.
num_cached_tokens = max(0, num_cached_tokens)
num_external_computed_tokens = max(0, num_external_computed_tokens)
num_external_computed_tokens = min(
num_external_computed_tokens, num_cached_tokens + recomputed
)
Contributor review comment on lines +282 to +288 (severity: high):

There's a potential logic issue with the order of operations. The recomputed variable is calculated on line 280 using num_cached_tokens before it is sanitized on line 284. If num_cached_tokens is negative (e.g., the sentinel value -1), this could lead to an incorrect value for recomputed and subsequently incorrect metrics, although it may not cause a crash in most scenarios.

To ensure correctness and improve robustness, the recomputed calculation should occur after num_cached_tokens has been clamped to a non-negative value.

I recommend reordering the logic like this:

        # Guard against inconsistent values from KV load failures or
        # stale num_cached_tokens after request retry/preemption.
        num_cached_tokens = max(0, num_cached_tokens)

        # When all tokens are cached, the scheduler reduces num_cached_tokens
        # by 1 to force the model to recompute the last token, since the model
        # needs at least one input token to run a forward pass.
        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0

        num_external_computed_tokens = max(0, num_external_computed_tokens)
        num_external_computed_tokens = min(
            num_external_computed_tokens, num_cached_tokens + recomputed
        )

This would involve moving line 280 to after line 284.
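
To make the ordering concern concrete, here is a minimal, self-contained sketch. The two helper functions below are hypothetical and only mirror the recomputed arithmetic from update_from_output; they are not the vLLM implementation.

    def recomputed_then_clamp(num_cached_tokens: int, prompt_len: int) -> int:
        """Order in the current diff: derive recomputed, then clamp."""
        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
        # The clamp mirrors the diff, but it runs too late to affect recomputed.
        num_cached_tokens = max(0, num_cached_tokens)
        return recomputed

    def clamp_then_recomputed(num_cached_tokens: int, prompt_len: int) -> int:
        """Order suggested in this review: clamp first, then derive recomputed."""
        num_cached_tokens = max(0, num_cached_tokens)
        return 1 if (num_cached_tokens + 1 == prompt_len) else 0

    if __name__ == "__main__":
        # With the preemption sentinel -1 and a one-token prompt, the two
        # orderings disagree, which is the inconsistency described above.
        assert recomputed_then_clamp(-1, prompt_len=1) == 0
        assert clamp_then_recomputed(-1, prompt_len=1) == 1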


self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
self.local_cache_hit += (