104 changes: 103 additions & 1 deletion tests/v1/metrics/test_stats.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats


def test_iteration_stats_repr():
@@ -107,3 +107,105 @@ def test_prefill_kv_computed_edge_cases():
finished_req2.num_cached_tokens, 0
)
assert prefill_kv_computed2 == 0 # All cached, nothing computed


def test_prompt_token_stats_all_computed():
"""Test all tokens computed locally, no caching."""
stats = PromptTokenStats()

# Case 1: No caching (all tokens computed locally)
stats.update_from_output(
num_cached_tokens=0,
num_external_computed_tokens=0,
prompt_len=1000,
)

assert stats.computed == 1000
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 0
assert stats.total == 1000


def test_prompt_token_stats_partial_local_cache():
"""Test partial local prefix cache hit."""
stats = PromptTokenStats()

# Case 2: Partial local cache
stats.update_from_output(
num_cached_tokens=300,
num_external_computed_tokens=0,
prompt_len=1000,
)

assert stats.computed == 700
assert stats.local_cache_hit == 300
assert stats.external_kv_transfer == 0


def test_prompt_token_stats_partial_external_transfer():
"""Test partial external KV transfer."""
stats = PromptTokenStats()

# Case 3: Partial external transfer
stats.update_from_output(
num_cached_tokens=500,
num_external_computed_tokens=500,
prompt_len=1000,
)

assert stats.computed == 500
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 500


def test_prompt_token_stats_mixed_sources():
"""Test mix of local cache and external transfer."""
stats = PromptTokenStats()

# Case 4: Mixed sources
stats.update_from_output(
num_cached_tokens=600,
num_external_computed_tokens=200,
prompt_len=1000,
)

assert stats.computed == 400
assert stats.local_cache_hit == 400
assert stats.external_kv_transfer == 200


def test_prompt_token_stats_full_local_cache_recompute():
"""Test full local cache triggers last token recomputation.

When all tokens are cached, the scheduler reduces num_cached_tokens by 1
to force the model to recompute the last token.
"""
stats = PromptTokenStats()

# Case 5: Full local cache (999 cached after reduction, 1 recomputed)
stats.update_from_output(
num_cached_tokens=999,
num_external_computed_tokens=0,
prompt_len=1000,
)

assert stats.computed == 1
assert stats.local_cache_hit == 1000
assert stats.recomputed_tokens == 1


def test_prompt_token_stats_full_external_transfer_recompute():
"""Test full external transfer triggers last token recomputation."""
stats = PromptTokenStats()

# Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
stats.update_from_output(
num_cached_tokens=999,
num_external_computed_tokens=1000,
prompt_len=1000,
)

assert stats.computed == 1
assert stats.local_cache_hit == 0
assert stats.external_kv_transfer == 1000
assert stats.recomputed_tokens == 1
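
The invariants documented on PromptTokenStats (see the stats.py hunk below) can be checked directly against the edge case above. A minimal standalone sketch, not part of the diff, reusing the import from the top of this test file:

from vllm.v1.metrics.stats import PromptTokenStats

stats = PromptTokenStats()
# Replay case 5: full local cache, scheduler leaves the last token to recompute.
stats.update_from_output(
num_cached_tokens=999,
num_external_computed_tokens=0,
prompt_len=1000,
)
# Invariant 1: computed + local_cache_hit + external_kv_transfer - recomputed_tokens == total
assert (
stats.computed
+ stats.local_cache_hit
+ stats.external_kv_transfer
- stats.recomputed_tokens
== stats.total
)
# Invariant 2: local_cache_hit + external_kv_transfer - recomputed_tokens == cached_tokens
assert (
stats.local_cache_hit + stats.external_kv_transfer - stats.recomputed_tokens
== stats.cached_tokens
)
# Note that local_cache_hit (1000) exceeds cached_tokens (999): the recomputed
# token is still attributed to the cache-hit source even though it is re-run.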
1 change: 1 addition & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -1378,6 +1378,7 @@ def update_from_output(
kv_transfer_params=kv_transfer_params,
trace_headers=request.trace_headers,
num_cached_tokens=request.num_cached_tokens,
num_external_computed_tokens=request.num_external_computed_tokens,
routed_experts=routed_experts,
num_nans_in_logits=request.num_nans_in_logits,
)
4 changes: 3 additions & 1 deletion vllm/v1/engine/__init__.py
@@ -139,8 +139,10 @@ class EngineCoreOutput(
kv_transfer_params: dict[str, Any] | None = None

trace_headers: Mapping[str, str] | None = None
# The number of tokens with prefix cache hits.
# The number of tokens with prefix cache hits (local + external).
num_cached_tokens: int = 0
# The number of tokens computed remotely (original count from connector).
num_external_computed_tokens: int = 0
routed_experts: np.ndarray | None = None
# The number of NaNs in logits.
# A value greater than 0 indicates that the output is corrupted.
47 changes: 46 additions & 1 deletion vllm/v1/metrics/loggers.py
@@ -25,6 +25,7 @@
CachingMetrics,
IterationStats,
MultiModalCacheStats,
PromptTokenStats,
SchedulerStats,
)
from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
@@ -136,7 +137,8 @@ def _enable_perf_stats(self) -> bool:

def _track_iteration_stats(self, iteration_stats: IterationStats):
# Save tracked stats for token counters.
self.num_prompt_tokens += iteration_stats.num_prompt_tokens
# Use computed tokens for prompt throughput (excludes cached/transferred)
self.num_prompt_tokens += iteration_stats.prompt_token_stats.computed
self.num_generation_tokens += iteration_stats.num_generation_tokens
self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs
self.num_preemptions += iteration_stats.num_preempted_reqs
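
Note the throughput semantics change here: with illustrative figures, a 1000-token prompt that hits 600 cached tokens now adds 400 to num_prompt_tokens (compute only) rather than 1000.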
@@ -590,6 +592,41 @@ def __init__(
counter_prompt_tokens, engine_indexes, model_name
)

# Labeled prompt token counters by source
counter_prompt_tokens_by_source = self._counter_cls(
name="vllm:prompt_tokens_by_source",
documentation="Number of prompt tokens by source.",
labelnames=labelnames + ["source"],
)
self.counter_prompt_tokens_by_source: dict[str, dict[int, Counter]] = {}
for source in PromptTokenStats.ALL_SOURCES:
self.counter_prompt_tokens_by_source[source] = {
idx: counter_prompt_tokens_by_source.labels(
model_name, str(idx), source
)
for idx in engine_indexes
}

# Cached prompt tokens counter
counter_prompt_tokens_cached = self._counter_cls(
name="vllm:prompt_tokens_cached",
documentation="Number of cached prompt tokens (local + external).",
labelnames=labelnames,
)
self.counter_prompt_tokens_cached = make_per_engine(
counter_prompt_tokens_cached, engine_indexes, model_name
)

# Recomputed tokens (last token recomputed when the entire prompt is cached)
counter_prompt_tokens_recomputed = self._counter_cls(
name="vllm:prompt_tokens_recomputed",
documentation="Number of cached tokens recomputed for the forward pass.",
labelnames=labelnames,
)
self.counter_prompt_tokens_recomputed = make_per_engine(
counter_prompt_tokens_recomputed, engine_indexes, model_name
)

counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens",
documentation="Number of generation tokens processed.",
@@ -1070,6 +1107,14 @@ def record(
iteration_stats.num_preempted_reqs
)
self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens)
# Labeled prompt token counters by source
pts = iteration_stats.prompt_token_stats
for source in PromptTokenStats.ALL_SOURCES:
self.counter_prompt_tokens_by_source[source][engine_idx].inc(
pts.get_by_source(source)
)
self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens)
self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens)
self.counter_generation_tokens[engine_idx].inc(
iteration_stats.num_generation_tokens
)
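
For reference, once scraped these counters surface as series along the following lines. This is a hypothetical exposition snippet: the model_name/engine label names are assumed from the existing labelnames, the values are illustrative, and prometheus_client appends the _total suffix to counter names on exposition:

vllm:prompt_tokens_by_source_total{model_name="...",engine="0",source="local_compute"} 400.0
vllm:prompt_tokens_by_source_total{model_name="...",engine="0",source="local_cache_hit"} 400.0
vllm:prompt_tokens_by_source_total{model_name="...",engine="0",source="external_kv_transfer"} 200.0
vllm:prompt_tokens_cached_total{model_name="...",engine="0"} 600.0
vllm:prompt_tokens_recomputed_total{model_name="...",engine="0"} 0.0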
76 changes: 74 additions & 2 deletions vllm/v1/metrics/stats.py
@@ -231,13 +231,76 @@ class FinishedRequestStats:
num_cached_tokens: int = 0


@dataclass
class PromptTokenStats:
"""Breakdown of prompt tokens by source.

Fields:
computed: Tokens prefilled locally (actual compute work).
local_cache_hit: Tokens from local prefix cache.
external_kv_transfer: Tokens from external KV transfer.
cached_tokens: Tokens skipped during prefill (from scheduler).
recomputed_tokens: Cached tokens that were recomputed (see below).
total: Total prompt tokens.

Invariants:
computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
local_cache_hit + external_kv_transfer - recomputed_tokens = cached_tokens
"""

ALL_SOURCES: tuple[str, ...] = (
"local_compute",
"local_cache_hit",
"external_kv_transfer",
)

computed: int = 0
local_cache_hit: int = 0
external_kv_transfer: int = 0
cached_tokens: int = 0
recomputed_tokens: int = 0
total: int = 0

def update_from_output(
self,
num_cached_tokens: int,
num_external_computed_tokens: int,
prompt_len: int,
) -> None:
"""Update stats from a prefill output."""
# When all tokens are cached, the scheduler reduces num_cached_tokens
# by 1 to force the model to recompute the last token, since the model
# needs at least one input token to run a forward pass.
recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0

self.computed += prompt_len - num_cached_tokens
self.external_kv_transfer += num_external_computed_tokens
self.local_cache_hit += (
num_cached_tokens + recomputed - num_external_computed_tokens
)
self.cached_tokens += num_cached_tokens
self.recomputed_tokens += recomputed
self.total += prompt_len

def get_by_source(self, source: str) -> int:
"""Get token count by source label."""
source_map = {
"local_compute": self.computed,
"local_cache_hit": self.local_cache_hit,
"external_kv_transfer": self.external_kv_transfer,
}
if source not in source_map:
raise ValueError(f"Unknown source: {source}")
return source_map[source]


class IterationStats:
"""Stats associated with a single set of EngineCoreOutputs."""

def __init__(self):
self.iteration_timestamp = time.time()
self.num_generation_tokens = 0
self.num_prompt_tokens = 0
self.prompt_token_stats = PromptTokenStats()
self.num_preempted_reqs = 0
self.finished_requests: list[FinishedRequestStats] = []
self.max_num_generation_tokens_iter: list[int] = []
@@ -250,6 +313,11 @@ def __repr__(self) -> str:
field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
return f"{self.__class__.__name__}({field_to_value_str})"

@property
def num_prompt_tokens(self) -> int:
"""Total prompt tokens (for backward compatibility)."""
return self.prompt_token_stats.total

def _time_since(self, start: float) -> float:
"""Calculate an interval relative to this iteration's timestamp."""
return self.iteration_timestamp - start
Expand All @@ -268,7 +336,11 @@ def update_from_output(

self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
self.num_prompt_tokens += prompt_len
self.prompt_token_stats.update_from_output(
num_cached_tokens=output.num_cached_tokens,
num_external_computed_tokens=output.num_external_computed_tokens,
prompt_len=prompt_len,
)

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
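
Because the PromptTokenStats fields accumulate across calls, per-iteration aggregation composes naturally over multiple prefills, and the num_prompt_tokens property preserves the old total. A standalone sketch, not part of the diff, with illustrative figures:

from vllm.v1.metrics.stats import PromptTokenStats

stats = PromptTokenStats()
# Request A: 1000-token prompt, 300 tokens served from the local prefix cache.
stats.update_from_output(
num_cached_tokens=300,
num_external_computed_tokens=0,
prompt_len=1000,
)
# Request B: 500-token prompt, 200 cached tokens, all via external KV transfer.
stats.update_from_output(
num_cached_tokens=200,
num_external_computed_tokens=200,
prompt_len=500,
)
assert stats.total == 1500  # what num_prompt_tokens reported before this change
assert stats.computed == 1000  # 700 from A + 300 from B
assert stats.local_cache_hit == 300  # all from A
assert stats.external_kv_transfer == 200  # all from B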