diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py
index 7d902bbc6fc2..d49874adc998 100644
--- a/tests/v1/metrics/test_stats.py
+++ b/tests/v1/metrics/test_stats.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.v1.engine import FinishReason
-from vllm.v1.metrics.stats import IterationStats, RequestStateStats
+from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats
 
 
 def test_iteration_stats_repr():
@@ -107,3 +107,105 @@ def test_prefill_kv_computed_edge_cases():
         finished_req2.num_cached_tokens, 0
     )
     assert prefill_kv_computed2 == 0  # All cached, nothing computed
+
+
+def test_prompt_token_stats_all_computed():
+    """Test all tokens computed locally, no caching."""
+    stats = PromptTokenStats()
+
+    # Case 1: No caching (all tokens computed locally)
+    stats.update_from_output(
+        num_cached_tokens=0,
+        num_external_computed_tokens=0,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 1000
+    assert stats.local_cache_hit == 0
+    assert stats.external_kv_transfer == 0
+    assert stats.total == 1000
+
+
+def test_prompt_token_stats_partial_local_cache():
+    """Test partial local prefix cache hit."""
+    stats = PromptTokenStats()
+
+    # Case 2: Partial local cache
+    stats.update_from_output(
+        num_cached_tokens=300,
+        num_external_computed_tokens=0,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 700
+    assert stats.local_cache_hit == 300
+    assert stats.external_kv_transfer == 0
+
+
+def test_prompt_token_stats_partial_external_transfer():
+    """Test partial external KV transfer."""
+    stats = PromptTokenStats()
+
+    # Case 3: Partial external transfer
+    stats.update_from_output(
+        num_cached_tokens=500,
+        num_external_computed_tokens=500,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 500
+    assert stats.local_cache_hit == 0
+    assert stats.external_kv_transfer == 500
+
+
+def test_prompt_token_stats_mixed_sources():
+    """Test mix of local cache and external transfer."""
+    stats = PromptTokenStats()
+
+    # Case 4: Mixed sources
+    stats.update_from_output(
+        num_cached_tokens=600,
+        num_external_computed_tokens=200,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 400
+    assert stats.local_cache_hit == 400
+    assert stats.external_kv_transfer == 200
+
+
+def test_prompt_token_stats_full_local_cache_recompute():
+    """Test full local cache triggers last token recomputation.
+
+    When all tokens are cached, the scheduler reduces num_cached_tokens by 1
+    to force the model to recompute the last token.
+    """
+    stats = PromptTokenStats()
+
+    # Case 5: Full local cache (999 cached after reduction, 1 recomputed)
+    stats.update_from_output(
+        num_cached_tokens=999,
+        num_external_computed_tokens=0,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 1
+    assert stats.local_cache_hit == 1000
+    assert stats.recomputed_tokens == 1
+
+
+def test_prompt_token_stats_full_external_transfer_recompute():
+    """Test full external transfer triggers last token recomputation."""
+    stats = PromptTokenStats()
+
+    # Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
+    stats.update_from_output(
+        num_cached_tokens=999,
+        num_external_computed_tokens=1000,
+        prompt_len=1000,
+    )
+
+    assert stats.computed == 1
+    assert stats.local_cache_hit == 0
+    assert stats.external_kv_transfer == 1000
+    assert stats.recomputed_tokens == 1
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 3f7ac9374e15..b0e1f3bbe8a1 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -1378,6 +1378,7 @@ def update_from_output(
                     kv_transfer_params=kv_transfer_params,
                     trace_headers=request.trace_headers,
                     num_cached_tokens=request.num_cached_tokens,
+                    num_external_computed_tokens=request.num_external_computed_tokens,
                     routed_experts=routed_experts,
                     num_nans_in_logits=request.num_nans_in_logits,
                 )
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index e8e44746bf47..5328a673554e 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -139,8 +139,10 @@ class EngineCoreOutput(
     kv_transfer_params: dict[str, Any] | None = None
     trace_headers: Mapping[str, str] | None = None
 
-    # The number of tokens with prefix cache hits.
+    # The number of tokens with prefix cache hits (local + external).
     num_cached_tokens: int = 0
+    # The number of tokens computed remotely (original count from connector).
+    num_external_computed_tokens: int = 0
     routed_experts: np.ndarray | None = None
     # The number of NaNs in logits.
     # A value greater than 0 indicates that the output is corrupted.
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 3a080f01a4d2..49b97e8f37a0 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -25,6 +25,7 @@
     CachingMetrics,
     IterationStats,
     MultiModalCacheStats,
+    PromptTokenStats,
     SchedulerStats,
 )
 from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
@@ -136,7 +137,8 @@ def _enable_perf_stats(self) -> bool:
     def _track_iteration_stats(self, iteration_stats: IterationStats):
         # Save tracked stats for token counters.
-        self.num_prompt_tokens += iteration_stats.num_prompt_tokens
+        # Use computed tokens for prompt throughput (excludes cached/transferred)
+        self.num_prompt_tokens += iteration_stats.prompt_token_stats.computed
         self.num_generation_tokens += iteration_stats.num_generation_tokens
         self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs
         self.num_preemptions += iteration_stats.num_preempted_reqs
@@ -590,6 +592,41 @@ def __init__(
             counter_prompt_tokens, engine_indexes, model_name
         )
+
+        # Labeled prompt token counters by source
+        counter_prompt_tokens_by_source = self._counter_cls(
+            name="vllm:prompt_tokens_by_source",
+            documentation="Number of prompt tokens by source.",
+            labelnames=labelnames + ["source"],
+        )
+        self.counter_prompt_tokens_by_source: dict[str, dict[int, Counter]] = {}
+        for source in PromptTokenStats.ALL_SOURCES:
+            self.counter_prompt_tokens_by_source[source] = {
+                idx: counter_prompt_tokens_by_source.labels(
+                    model_name, str(idx), source
+                )
+                for idx in engine_indexes
+            }
+
+        # Cached prompt tokens counter
+        counter_prompt_tokens_cached = self._counter_cls(
+            name="vllm:prompt_tokens_cached",
+            documentation="Number of cached prompt tokens (local + external).",
+            labelnames=labelnames,
+        )
+        self.counter_prompt_tokens_cached = make_per_engine(
+            counter_prompt_tokens_cached, engine_indexes, model_name
+        )
+
+        # Recomputed tokens (last token recomputed when entire prompt is cached)
+        counter_prompt_tokens_recomputed = self._counter_cls(
+            name="vllm:prompt_tokens_recomputed",
+            documentation="Number of cached tokens recomputed for forward pass.",
+            labelnames=labelnames,
+        )
+        self.counter_prompt_tokens_recomputed = make_per_engine(
+            counter_prompt_tokens_recomputed, engine_indexes, model_name
+        )
+
         counter_generation_tokens = self._counter_cls(
             name="vllm:generation_tokens",
             documentation="Number of generation tokens processed.",
@@ -1070,6 +1107,14 @@ def record(
             iteration_stats.num_preempted_reqs
         )
         self.counter_prompt_tokens[engine_idx].inc(iteration_stats.num_prompt_tokens)
+        # Labeled prompt token counters by source
+        pts = iteration_stats.prompt_token_stats
+        for source in PromptTokenStats.ALL_SOURCES:
+            self.counter_prompt_tokens_by_source[source][engine_idx].inc(
+                pts.get_by_source(source)
+            )
+        self.counter_prompt_tokens_cached[engine_idx].inc(pts.cached_tokens)
+        self.counter_prompt_tokens_recomputed[engine_idx].inc(pts.recomputed_tokens)
         self.counter_generation_tokens[engine_idx].inc(
             iteration_stats.num_generation_tokens
         )
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 3404a720e968..1b7ee105ebf2 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -231,13 +231,76 @@ class FinishedRequestStats:
     num_cached_tokens: int = 0
+
+
+@dataclass
+class PromptTokenStats:
+    """Breakdown of prompt tokens by source.
+
+    Fields:
+        computed: Tokens prefilled locally (actual compute work).
+        local_cache_hit: Tokens from local prefix cache.
+        external_kv_transfer: Tokens from external KV transfer.
+        cached_tokens: Tokens skipped during prefill (from scheduler).
+        recomputed_tokens: Cached tokens that were recomputed (see below).
+        total: Total prompt tokens.
+
+    Invariants:
+        computed + local_cache_hit + external_kv_transfer - recomputed_tokens = total
+        local_cache_hit + external_kv_transfer - recomputed_tokens = cached_tokens
+    """
+
+    ALL_SOURCES: tuple[str, ...] = (
+        "local_compute",
+        "local_cache_hit",
+        "external_kv_transfer",
+    )
+
+    computed: int = 0
+    local_cache_hit: int = 0
+    external_kv_transfer: int = 0
+    cached_tokens: int = 0
+    recomputed_tokens: int = 0
+    total: int = 0
+
+    def update_from_output(
+        self,
+        num_cached_tokens: int,
+        num_external_computed_tokens: int,
+        prompt_len: int,
+    ) -> None:
+        """Update stats from a prefill output."""
+        # When all tokens are cached, the scheduler reduces num_cached_tokens
+        # by 1 to force the model to recompute the last token, since the model
+        # needs at least one input token to run a forward pass.
+        recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
+
+        self.computed += prompt_len - num_cached_tokens
+        self.external_kv_transfer += num_external_computed_tokens
+        self.local_cache_hit += (
+            num_cached_tokens + recomputed - num_external_computed_tokens
+        )
+        self.cached_tokens += num_cached_tokens
+        self.recomputed_tokens += recomputed
+        self.total += prompt_len
+
+    def get_by_source(self, source: str) -> int:
+        """Get token count by source label."""
+        source_map = {
+            "local_compute": self.computed,
+            "local_cache_hit": self.local_cache_hit,
+            "external_kv_transfer": self.external_kv_transfer,
+        }
+        if source not in source_map:
+            raise ValueError(f"Unknown source: {source}")
+        return source_map[source]
+
+
 class IterationStats:
     """Stats associated with a single set of EngineCoreOutputs."""
 
     def __init__(self):
         self.iteration_timestamp = time.time()
         self.num_generation_tokens = 0
-        self.num_prompt_tokens = 0
+        self.prompt_token_stats = PromptTokenStats()
         self.num_preempted_reqs = 0
         self.finished_requests: list[FinishedRequestStats] = []
         self.max_num_generation_tokens_iter: list[int] = []
@@ -250,6 +313,11 @@ def __repr__(self) -> str:
         field_to_value_str = ", ".join(f"{k}={v}" for k, v in vars(self).items())
         return f"{self.__class__.__name__}({field_to_value_str})"
 
+    @property
+    def num_prompt_tokens(self) -> int:
+        """Total prompt tokens (for backward compatibility)."""
+        return self.prompt_token_stats.total
+
     def _time_since(self, start: float) -> float:
         """Calculate an interval relative to this iteration's timestamp."""
         return self.iteration_timestamp - start
@@ -268,7 +336,11 @@ def update_from_output(
         self.num_generation_tokens += num_new_generation_tokens
 
         if is_prefilling:
-            self.num_prompt_tokens += prompt_len
+            self.prompt_token_stats.update_from_output(
+                num_cached_tokens=output.num_cached_tokens,
+                num_external_computed_tokens=output.num_external_computed_tokens,
+                prompt_len=prompt_len,
+            )
             first_token_latency = self._time_since(req_stats.arrival_time)
             self.time_to_first_tokens_iter.append(first_token_latency)
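
For reference, the accounting performed by PromptTokenStats.update_from_output can be sanity-checked in isolation. The snippet below is a minimal, self-contained sketch, not part of the patch: it re-implements the update logic shown above rather than importing vLLM, and the class name _PromptTokenStatsSketch is illustrative only. It exercises the mixed-sources case and verifies the two invariants stated in the class docstring.

    from dataclasses import dataclass

    @dataclass
    class _PromptTokenStatsSketch:
        """Standalone copy of the accounting added to vllm/v1/metrics/stats.py."""
        computed: int = 0
        local_cache_hit: int = 0
        external_kv_transfer: int = 0
        cached_tokens: int = 0
        recomputed_tokens: int = 0
        total: int = 0

        def update_from_output(
            self,
            num_cached_tokens: int,
            num_external_computed_tokens: int,
            prompt_len: int,
        ) -> None:
            # A fully cached prompt has one token "un-cached" by the scheduler
            # so the model still has an input token for the forward pass.
            recomputed = 1 if (num_cached_tokens + 1 == prompt_len) else 0
            self.computed += prompt_len - num_cached_tokens
            self.external_kv_transfer += num_external_computed_tokens
            self.local_cache_hit += (
                num_cached_tokens + recomputed - num_external_computed_tokens
            )
            self.cached_tokens += num_cached_tokens
            self.recomputed_tokens += recomputed
            self.total += prompt_len

    # Mixed sources: 600 of 1000 prompt tokens were skipped during prefill,
    # 200 of which arrived via an external KV transfer.
    stats = _PromptTokenStatsSketch()
    stats.update_from_output(
        num_cached_tokens=600, num_external_computed_tokens=200, prompt_len=1000
    )
    assert (stats.computed, stats.local_cache_hit, stats.external_kv_transfer) == (400, 400, 200)

    # Invariants from the PromptTokenStats docstring hold:
    assert (
        stats.computed + stats.local_cache_hit + stats.external_kv_transfer
        - stats.recomputed_tokens == stats.total
    )
    assert (
        stats.local_cache_hit + stats.external_kv_transfer
        - stats.recomputed_tokens == stats.cached_tokens
    )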