Merged
tests/v1/metrics/test_stats.py (8 changes: 3 additions & 5 deletions)

@@ -190,8 +190,7 @@ def test_prompt_token_stats_full_local_cache_recompute():
     )
 
     assert stats.computed == 1
-    assert stats.local_cache_hit == 1000
-    assert stats.recomputed_tokens == 1
+    assert stats.local_cache_hit == 999
 
 
 def test_prompt_token_stats_full_external_transfer_recompute():
@@ -201,11 +200,10 @@ def test_prompt_token_stats_full_external_transfer_recompute():
     # Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
     stats.update_from_output(
         num_cached_tokens=999,
-        num_external_computed_tokens=1000,
+        num_external_computed_tokens=999,
         prompt_len=1000,
     )
 
     assert stats.computed == 1
     assert stats.local_cache_hit == 0
-    assert stats.external_kv_transfer == 1000
-    assert stats.recomputed_tokens == 1
+    assert stats.external_kv_transfer == 999
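
A quick way to sanity-check the new accounting (a minimal sketch, assuming the updated PromptTokenStats from this PR is importable; the numbers mirror Case 6 above):

    from vllm.v1.metrics.stats import PromptTokenStats

    stats = PromptTokenStats()
    # The scheduler holds back the last token of a fully cached prompt so the
    # forward pass has at least one input token, hence it reports 999 cached.
    stats.update_from_output(
        num_cached_tokens=999,
        num_external_computed_tokens=999,
        prompt_len=1000,
    )
    assert stats.computed == 1               # 1000 - 999
    assert stats.local_cache_hit == 0        # max(0, 999 - 999)
    assert stats.external_kv_transfer == 999
    assert stats.cached_tokens == 999
    assert stats.total == 1000               # 1 + 0 + 999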

vllm/v1/metrics/loggers.py (11 changes: 0 additions & 11 deletions)

@@ -622,16 +622,6 @@ def __init__(
             counter_prompt_tokens_cached, per_engine_labelvalues
         )
 
-        # Recomputed tokens (last token recomputed when entire prompt is cached)
-        counter_prompt_tokens_recomputed = self._counter_cls(
-            name="vllm:prompt_tokens_recomputed",
-            documentation="Number of cached tokens recomputed for forward pass.",
-            labelnames=labelnames,
-        )
-        self.counter_prompt_tokens_recomputed = create_metric_per_engine(
-            counter_prompt_tokens_recomputed, per_engine_labelvalues
-        )
-
         counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens",
             documentation="Number of generation tokens processed.",
@@ -1122,7 +1112,6 @@ def record(
             pts.get_by_source(source)
         )
         self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens)
-        self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens)
         self.counter_generation_tokens[engine_idx].inc(
             iteration_stats.num_generation_tokens
         )
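
For context, the removed counter followed the same per-engine pattern as the surviving ones. The rough shape in plain prometheus_client (a hedged sketch: _counter_cls and create_metric_per_engine are vLLM internals, simplified here to a labels() lookup, and the metric name is illustrative):

    from prometheus_client import Counter

    # One counter family; one child per engine index, selected via labels().
    counter = Counter(
        "vllm:prompt_tokens_cached",
        "Number of prefill tokens served from cache.",
        labelnames=["engine"],
    )
    per_engine = {idx: counter.labels(engine=str(idx)) for idx in range(2)}
    per_engine[0].inc(999)  # e.g. pts.cached_tokens on the record() path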

vllm/v1/metrics/stats.py (14 changes: 3 additions & 11 deletions)

@@ -246,12 +246,11 @@ class PromptTokenStats:
         local_cache_hit: Tokens from local prefix cache.
         external_kv_transfer: Tokens from external KV transfer.
         cached_tokens: Tokens skipped during prefill (from scheduler).
-        recomputed_tokens: Cached tokens that were recomputed (see below).
         total: Total prompt tokens.
 
     Invariants:
-        computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
-        local_cache_hit + external_kv_transfer - recomputed_tokens = cached_tokens
+        computed + local_cache_hit + external_kv_transfer = total
+        local_cache_hit + external_kv_transfer = cached_tokens
     """
 
     ALL_SOURCES: tuple[str, ...] = (
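
The simplification is easiest to see with the Case 6 numbers from the test above (a worked check, not part of the diff):

    # Old invariant, with num_external_computed_tokens reported as 1000:
    #   computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
    #   1        + 0               + 1000                 - 1                 = 1000
    # New invariant, with the scheduler-reduced 999 used directly:
    #   computed + local_cache_hit + external_kv_transfer = total
    #   1        + 0               + 999                  = 1000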
@@ -264,7 +263,6 @@ class PromptTokenStats:
     local_cache_hit: int = 0
     external_kv_transfer: int = 0
     cached_tokens: int = 0
-    recomputed_tokens: int = 0
     total: int = 0
 
     def update_from_output(
@@ -274,11 +272,6 @@ def update_from_output(
         prompt_len: int,
     ) -> None:
         """Update stats from a prefill output."""
-        # When all tokens are cached, the scheduler reduces num_cached_tokens
-        # by 1 to force the model to recompute the last token, since the model
-        # needs at least one input token to run a forward pass.
-        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
-
         self.computed += prompt_len - num_cached_tokens
         self.external_kv_transfer += num_external_computed_tokens
         # FIXME(yifan): local_cache_hit can go negative after preemption.
@@ -290,10 +283,9 @@ def update_from_output(
         # as a separate metric rather than reusing num_external_computed_tokens
         # for metric directly.
         self.local_cache_hit += max(
-            0, (num_cached_tokens + recomputed - num_external_computed_tokens)
+            0, (num_cached_tokens - num_external_computed_tokens)
         )
         self.cached_tokens += num_cached_tokens
-        self.recomputed_tokens += recomputed
         self.total += prompt_len
 
     def get_by_source(self, source: str) -> int:
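
The max(0, ...) clamp kept in the new code guards the preemption case flagged in the FIXME (an illustrative sketch with made-up numbers, not taken from the diff):

    # After preemption the scheduler can report fewer cached tokens than the
    # external transfer supplied, so the naive difference would go negative.
    num_cached_tokens = 500
    num_external_computed_tokens = 800
    delta = max(0, num_cached_tokens - num_external_computed_tokens)
    assert delta == 0  # clamped, instead of attributing -300 local cache hits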
Expand Down
Loading