[WIP][Bugfix] Fix negative prompt token counter crash under KV offloading #36638

Closed
haosdent wants to merge 1 commit into vllm-project:main from haosdent:fix-36533

Conversation

@haosdent
Contributor

Purpose

Fix engine crash (ValueError: Counters can only be incremented by non-negative amounts) that occurs under high concurrency with CPU KV offloading enabled (GitHub issue #36533).

Root cause: In _update_requests_with_invalid_blocks (scheduler.py), when KV cache blocks fail to load, num_affected_tokens — which includes both locally-cached and externally-loaded tokens — is subtracted entirely from request.num_external_computed_tokens, driving it negative. The negative value propagates from EngineCoreOutput through PromptTokenStats to PrometheusStatLogger.record(), where Counter.inc() crashes.

Secondary issue: request.num_cached_tokens is set once during initial scheduling and never reset when a request is freed for retry after KV failure or preemption. On reschedule, num_external_computed_tokens may be re-queried to a new value while num_cached_tokens stays stale, creating a mismatch that can also produce negative local_cache_hit.

Fix (sketched below):

  1. Clamp num_external_computed_tokens to non-negative after subtraction in _update_requests_with_invalid_blocks
  2. Reset num_cached_tokens to -1 (sentinel) in preemption and KV failure retry paths so it is re-captured consistently on reschedule
  3. Add defensive guards in PromptTokenStats.update_from_output to clamp inputs and maintain invariants
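
As a minimal sketch (assuming the field names from the description above; the actual scheduler diff may differ), fixes (1) and (2) take roughly this shape:

```python
# Hypothetical sketch only; mirrors the PR description, not the actual diff.

def on_kv_load_failure(request, num_affected_tokens: int) -> None:
    # Fix (1): clamp after subtraction so the count can never go negative
    # when num_affected_tokens also covers locally-cached tokens.
    request.num_external_computed_tokens = max(
        0, request.num_external_computed_tokens - num_affected_tokens
    )
    # Fix (2): reset to the -1 sentinel (also done on preemption) so
    # num_cached_tokens is re-captured when the request is rescheduled.
    request.num_cached_tokens = -1
```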

Test Plan

  • Added test_prompt_token_stats_negative_external_clamped — verifies negative num_external_computed_tokens is clamped to 0
  • Added test_prompt_token_stats_external_exceeds_cached — verifies num_external_computed_tokens > num_cached_tokens is clamped
  • Added test_prompt_token_stats_negative_cached_clamped — verifies negative num_cached_tokens is clamped to 0
  • Added test_prompt_token_stats_all_non_negative (parametrized, 6 cases) — fuzz-style check that all PromptTokenStats fields remain non-negative across edge cases (see the sketch after this list)
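
A hypothetical reconstruction of the parametrized test's shape, using the same six parameter tuples as the test IDs below; the clamp helper here is an assumed mirror of the guard, not the actual PromptTokenStats code:

```python
import pytest

def clamp_prompt_token_inputs(num_cached: int, num_external: int,
                              prompt_len: int) -> tuple[int, int]:
    # Assumed mirror of the defensive guard in
    # PromptTokenStats.update_from_output, using the reviewer-recommended
    # ordering (clamp num_cached before deriving `recomputed`).
    num_cached = max(0, num_cached)
    recomputed = 1 if num_cached + 1 == prompt_len else 0
    num_external = max(0, num_external)
    num_external = min(num_external, num_cached + recomputed)
    return num_cached, num_external

@pytest.mark.parametrize(
    "num_cached,num_external,prompt_len",
    [(0, 0, 100), (50, -50, 100), (-10, 0, 100),
     (50, 100, 100), (-5, -5, 100), (99, 100, 100)],
)
def test_all_non_negative(num_cached, num_external, prompt_len):
    cached, external = clamp_prompt_token_inputs(
        num_cached, num_external, prompt_len)
    assert cached >= 0
    assert external >= 0
    assert external <= prompt_len
```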

Test Result

```
tests/v1/metrics/test_stats.py::test_prompt_token_stats_negative_external_clamped PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_external_exceeds_cached PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_negative_cached_clamped PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[0-0-100] PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[50--50-100] PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[-10-0-100] PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[50-100-100] PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[-5--5-100] PASSED
tests/v1/metrics/test_stats.py::test_prompt_token_stats_all_non_negative[99-100-100] PASSED
```

All 19 tests in test_stats.py pass. All 3 tests in test_invalid_blocks_correctness.py pass.

Clamp num_external_computed_tokens to non-negative in
_update_requests_with_invalid_blocks, reset stale num_cached_tokens
on preemption/retry paths, and add defensive guards in
PromptTokenStats.update_from_output to prevent Prometheus Counter.inc()
from receiving negative values.

Fixes vllm-project#36533

Signed-off-by: haosdent <haosdent@gmail.com>

@gemini-code-assist (bot) left a comment


Code Review

This pull request addresses a crash caused by negative token counters during KV offloading by introducing clamping in the scheduler and defensive guards in PromptTokenStats. It also correctly resets request.num_cached_tokens on preemption and KV failure to prevent using stale values. The changes are well-tested with new unit tests covering various edge cases. My review identified one potential logic issue in vllm/v1/metrics/stats.py where a variable is calculated before a value it depends on is sanitized, which could lead to incorrect metrics. I've provided a recommendation to reorder the operations for improved robustness.

Comment thread on vllm/v1/metrics/stats.py, lines +282 to +288:
```python
# Guard against inconsistent values from KV load failures or
# stale num_cached_tokens after request retry/preemption.
num_cached_tokens = max(0, num_cached_tokens)
num_external_computed_tokens = max(0, num_external_computed_tokens)
num_external_computed_tokens = min(
    num_external_computed_tokens, num_cached_tokens + recomputed
)
```

Severity: high

There's a potential logic issue with the order of operations. The recomputed variable is calculated on line 280 using num_cached_tokens before it is sanitized on line 284. If num_cached_tokens is negative (e.g., the sentinel value -1), this could lead to an incorrect value for recomputed and subsequently incorrect metrics, although it may not cause a crash in most scenarios.

To ensure correctness and improve robustness, the recomputed calculation should occur after num_cached_tokens has been clamped to a non-negative value.

I recommend reordering the logic like this:

```python
# Guard against inconsistent values from KV load failures or
# stale num_cached_tokens after request retry/preemption.
num_cached_tokens = max(0, num_cached_tokens)

# When all tokens are cached, the scheduler reduces num_cached_tokens
# by 1 to force the model to recompute the last token, since the model
# needs at least one input token to run a forward pass.
recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0

num_external_computed_tokens = max(0, num_external_computed_tokens)
num_external_computed_tokens = min(
    num_external_computed_tokens, num_cached_tokens + recomputed
)
```

This would involve moving line 280 to after line 284.


@gambletan (Contributor) left a comment


Good defensive fix for the negative prompt token counter crash. Clamping values at the stats layer is a pragmatic approach.

A few observations:

  1. num_cached_tokens = -1 as sentinel (in scheduler.py): Using -1 as a sentinel value for "needs to be re-captured" is functional but fragile. If any code path reads num_cached_tokens before it's re-set (e.g., during a race between scheduling and metrics collection), it will silently propagate -1 into calculations. Consider using Optional[int] with None as the sentinel instead — this would cause a TypeError rather than silently producing wrong metrics if the value is used before being re-initialized (see the sketch after this list).

  2. Clamping order in stats.py: The clamping logic is:

    ```python
    num_cached_tokens = max(0, num_cached_tokens)
    num_external_computed_tokens = max(0, num_external_computed_tokens)
    num_external_computed_tokens = min(num_external_computed_tokens, num_cached_tokens + recomputed)
    ```

    Note that recomputed is computed before clamping num_cached_tokens. If num_cached_tokens was originally negative, recomputed could be 0 when it should be 1 (or vice versa). For example, if num_cached_tokens=-1 and prompt_len=0 (edge case), the recomputed check num_cached_tokens + 1 == prompt_len would be True (since -1 + 1 == 0), incorrectly setting recomputed = 1. Moving the clamping before the recomputed calculation would be safer.

  3. Test coverage: The parametrized test test_prompt_token_stats_all_non_negative is a nice touch — it ensures no field goes negative regardless of input. The case (50, 100, 100) (external > cached) is particularly useful for catching the overflow scenario.

  4. _update_requests_with_invalid_blocks: The max(0, ...) clamping here is the right fix for the root cause. Good that it's applied at the source rather than only at the metrics layer.
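
For observation (1), a minimal sketch of the Optional[int] alternative (hypothetical, not part of this PR; the Request class below is a stand-in, not vLLM's):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class Request:  # hypothetical stand-in, not vLLM's Request class
    # None means "needs re-capture on the next scheduling pass".
    num_cached_tokens: Optional[int] = None

req = Request()
try:
    _ = req.num_cached_tokens + 1  # fails loudly instead of propagating -1
except TypeError:
    print("num_cached_tokens used before re-initialization")
```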

@markmc
Member

markmc commented Mar 10, 2026

xref #34079

@haosdent
Contributor Author

Thx @markmc, let me close mine

@haosdent closed this Mar 11, 2026
markmc added a commit to markmc/vllm that referenced this pull request Mar 11, 2026
… non-negative amounts"

Since `num_computed_tokens`, `num_cached_tokens`, and
`num_external_computed_tokens` accounting seems quite brittle currently -
with preemption reset bugs and P/D disaggregation accounting issues -
add a defensive check to detect and prevent instances of Prometheus
counter errors:

```
ValueError: Counters can only be incremented by non-negative amounts
```

The invariant check enforces:

```
prompt_len >= num_cached_tokens >= num_external_computed_tokens >= 0
```

with the additional nuance that when all tokens are cached, the scheduler
forces recomputation of the last token, so the bound becomes:

```
num_external_computed_tokens <= num_cached_tokens + recomputed
```

When the invariant is violated, we log a warning once with diagnostic
details, and discard suspect cache metrics.

Obviously, the accounting should be fixed and made more robust and
future-proof, at which point we can remove this check (perhaps replacing it
with a simple assertion).

Related to issues vllm-project#36533, vllm-project#36755 and PRs vllm-project#36638, vllm-project#36752, vllm-project#36757.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
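
A sketch of the invariant check described above, under the stated invariants (assumed shape; the actual commit may implement it differently):

```python
import logging

logger = logging.getLogger(__name__)
_warned_once = False

def cache_metrics_look_sane(prompt_len: int, num_cached_tokens: int,
                            num_external_computed_tokens: int) -> bool:
    """Hypothetical reconstruction of the defensive check; returns False
    (warning once) when the suspect cache metrics should be discarded."""
    global _warned_once
    # When the whole prompt is cached, the scheduler recomputes the last
    # token, so the external bound is relaxed by one.
    recomputed = 1 if num_cached_tokens + 1 == prompt_len else 0
    ok = (
        0 <= num_cached_tokens <= prompt_len
        and 0 <= num_external_computed_tokens <= num_cached_tokens + recomputed
    )
    if not ok and not _warned_once:
        _warned_once = True
        logger.warning(
            "Prompt token accounting invariant violated: prompt_len=%d "
            "num_cached_tokens=%d num_external_computed_tokens=%d; "
            "discarding cache metrics.",
            prompt_len, num_cached_tokens, num_external_computed_tokens,
        )
    return ok
```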
@markmc moved this from Backlog to In Review in Metrics & Tracing, Apr 8, 2026
@markmc moved this from In Review to Not planned in Metrics & Tracing, Apr 8, 2026

Labels

bug (Something isn't working), v1

Projects

Status: Not planned

Development

Successfully merging this pull request may close these issues.

3 participants