From a167d11e7e234e3252d5d86c21e6c44c3207d1d7 Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Wed, 1 Apr 2026 04:49:53 -0400
Subject: [PATCH] [Core][Metrics] Remove `vllm:prompt_tokens_recomputed` metric

In the case of a full local prefix cache hit (prompt length N), we
actually use only N-1 of the cached tokens. The
`vllm:prompt_tokens_recomputed` metric was intended to count how many
cached tokens we are effectively discarding because of this:

```
KVCacheManager.get_computed_blocks():
    ...
    # NOTE: When all tokens hit the cache, we must recompute the last token
    # to obtain logits.
    [...]
    max_cache_hit_length = request.num_tokens - 1
```

However, even here, we can't assume the last token would have been a
cache hit and should be counted as "recomputed". Given this, the
metric seems misguided in retrospect.

The metric was added as a side effect in #33290 in order to make sense
of the fact that:

```
vllm:prompt_tokens_by_source_total{source="external_kv_transfer"}
```

will include a token that is recomputed. See this comment:

> Note: external_kv_transfer reports the actual number of tokens
> transferred (e.g., prompt length N), while prompt_tokens_cached_total
> reports the adjusted count (e.g., N-1). The last token is both
> transferred AND recomputed locally, so there's overlap.

However, it makes more sense for the `external_kv_transfer` count to
reflect only tokens we actually used, not any recomputed tokens. This
will be done in #37460.

I'm not aware of any user demand for this metric, or of anyone relying
on it now, so it seems safe to remove it rather than go through a
deprecation period.

Signed-off-by: Mark McLoughlin
---
 tests/v1/metrics/test_stats.py |  8 +++-----
 vllm/v1/metrics/loggers.py     | 11 -----------
 vllm/v1/metrics/stats.py       | 14 +++-----------
 3 files changed, 6 insertions(+), 27 deletions(-)

diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py
index d49874adc998..48f6caefdbff 100644
--- a/tests/v1/metrics/test_stats.py
+++ b/tests/v1/metrics/test_stats.py
@@ -190,8 +190,7 @@ def test_prompt_token_stats_full_local_cache_recompute():
     )
 
     assert stats.computed == 1
-    assert stats.local_cache_hit == 1000
-    assert stats.recomputed_tokens == 1
+    assert stats.local_cache_hit == 999
 
 
 def test_prompt_token_stats_full_external_transfer_recompute():
@@ -201,11 +200,10 @@
     # Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
     stats.update_from_output(
         num_cached_tokens=999,
-        num_external_computed_tokens=1000,
+        num_external_computed_tokens=999,
         prompt_len=1000,
     )
 
     assert stats.computed == 1
     assert stats.local_cache_hit == 0
-    assert stats.external_kv_transfer == 1000
-    assert stats.recomputed_tokens == 1
+    assert stats.external_kv_transfer == 999
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 5d5877d1692e..e85f6a75892a 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -622,16 +622,6 @@ def __init__(
             counter_prompt_tokens_cached, per_engine_labelvalues
         )
 
-        # Recomputed tokens (last token recomputed when entire prompt is cached)
-        counter_prompt_tokens_recomputed = self._counter_cls(
-            name="vllm:prompt_tokens_recomputed",
-            documentation="Number of cached tokens recomputed for forward pass.",
-            labelnames=labelnames,
-        )
-        self.counter_prompt_tokens_recomputed = create_metric_per_engine(
-            counter_prompt_tokens_recomputed, per_engine_labelvalues
-        )
-
         counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens",
             documentation="Number of generation tokens processed.",
@@ -1122,7 +1112,6 @@ def record(
                 pts.get_by_source(source)
             )
         self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens)
-        self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens)
         self.counter_generation_tokens[engine_idx].inc(
             iteration_stats.num_generation_tokens
         )
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 45f002e01edb..79955815d582 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -246,12 +246,11 @@ class PromptTokenStats:
         local_cache_hit: Tokens from local prefix cache.
         external_kv_transfer: Tokens from external KV transfer.
         cached_tokens: Tokens skipped during prefill (from scheduler).
-        recomputed_tokens: Cached tokens that were recomputed (see below).
         total: Total prompt tokens.
 
     Invariants:
-        computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
-        local_cache_hit + external_kv_transfer - recomputed_tokens = cached_tokens
+        computed + local_cache_hit + external_kv_transfer = total
+        local_cache_hit + external_kv_transfer = cached_tokens
     """
 
     ALL_SOURCES: tuple[str, ...] = (
@@ -264,7 +263,6 @@ class PromptTokenStats:
     local_cache_hit: int = 0
     external_kv_transfer: int = 0
     cached_tokens: int = 0
-    recomputed_tokens: int = 0
     total: int = 0
 
     def update_from_output(
@@ -274,11 +272,6 @@
         prompt_len: int,
     ) -> None:
         """Update stats from a prefill output."""
-        # When all tokens are cached, the scheduler reduces num_cached_tokens
-        # by 1 to force the model to recompute the last token, since the model
-        # needs at least one input token to run a forward pass.
-        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
-
         self.computed += prompt_len - num_cached_tokens
         self.external_kv_transfer += num_external_computed_tokens
         # FIXME(yifan): local_cache_hit can go negative after preemption.
@@ -290,10 +283,9 @@
         # as a separate metric rather than reusing num_external_computed_tokens
         # for metric directly.
         self.local_cache_hit += max(
-            0, (num_cached_tokens + recomputed - num_external_computed_tokens)
+            0, (num_cached_tokens - num_external_computed_tokens)
        )
         self.cached_tokens += num_cached_tokens
-        self.recomputed_tokens += recomputed
         self.total += prompt_len
 
     def get_by_source(self, source: str) -> int:
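
Reviewer note, not part of the patch: a minimal sketch checking the
simplified invariants from the `PromptTokenStats` docstring against the
full-local-cache-hit case exercised by the updated tests above, where
the scheduler reports `num_cached_tokens = prompt_len - 1`:

```
from vllm.v1.metrics.stats import PromptTokenStats

# Full local prefix cache hit: the scheduler reports N-1 cached tokens
# (the last token is recomputed to obtain logits), and no tokens come
# from an external KV transfer.
stats = PromptTokenStats()
stats.update_from_output(
    num_cached_tokens=999,
    num_external_computed_tokens=0,
    prompt_len=1000,
)

assert stats.computed == 1
assert stats.local_cache_hit == 999

# The two docstring invariants now hold with no recomputed_tokens term:
assert (
    stats.computed + stats.local_cache_hit + stats.external_kv_transfer
    == stats.total
)
assert stats.local_cache_hit + stats.external_kv_transfer == stats.cached_tokens
```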