103 changes: 102 additions & 1 deletion tests/v1/metrics/test_stats.py
@@ -1,8 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.v1.metrics.stats import IterationStats
+from vllm.v1.engine import FinishReason
+from vllm.v1.metrics.stats import IterationStats, RequestStateStats


def test_iteration_stats_repr():
    iteration_stats = IterationStats()
    assert repr(iteration_stats).startswith("IterationStats(")


def test_prefill_kv_computed_with_cache():
"""Test that prefill KV compute correctly excludes cached tokens."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 5.0
req_stats.num_generation_tokens = 50

# Case 1: With prefix cache (1200 tokens cached)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=10000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=1200,
)

finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 10000
assert finished_req.num_cached_tokens == 1200

# Verify calculation: prefill KV = prompt tokens - cached tokens
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 8800 # 10000 - 1200


def test_prefill_kv_computed_no_cache():
"""Test prefill KV compute without prefix caching."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 2.0
req_stats.num_generation_tokens = 10

# Case 2: No prefix cache
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=2000,
max_tokens_param=100,
req_stats=req_stats,
num_cached_tokens=0,
)

finished_req = iteration_stats.finished_requests[0]
assert finished_req.num_prompt_tokens == 2000
assert finished_req.num_cached_tokens == 0

# Verify calculation: prefill KV = full prompt when no cache
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 2000


def test_prefill_kv_computed_edge_cases():
"""Test edge cases for prefill KV compute calculation."""
iteration_stats = IterationStats()
req_stats = RequestStateStats(arrival_time=0.0)
req_stats.scheduled_ts = 0.1
req_stats.first_token_ts = 0.5
req_stats.last_token_ts = 1.0
req_stats.num_generation_tokens = 1

# Case 3: Negative num_cached_tokens (shouldn't happen, but handle gracefully)
iteration_stats.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=-1,
)

finished_req = iteration_stats.finished_requests[0]
# max() should handle negative values
prefill_kv_computed = finished_req.num_prompt_tokens - max(
finished_req.num_cached_tokens, 0
)
assert prefill_kv_computed == 100 # Should treat negative as 0

# Case 4: All tokens cached (shouldn't happen in practice)
iteration_stats2 = IterationStats()
iteration_stats2.update_from_finished_request(
finish_reason=FinishReason.STOP,
num_prompt_tokens=100,
max_tokens_param=10,
req_stats=req_stats,
num_cached_tokens=100,
)

finished_req2 = iteration_stats2.finished_requests[0]
prefill_kv_computed2 = finished_req2.num_prompt_tokens - max(
finished_req2.num_cached_tokens, 0
)
assert prefill_kv_computed2 == 0 # All cached, nothing computed
1 change: 1 addition & 0 deletions vllm/v1/engine/output_processor.py
@@ -650,6 +650,7 @@ def _update_stats_from_finished(
            ),
            max_tokens_param=req_state.max_tokens_param,
            req_stats=req_state.stats,
            num_cached_tokens=req_state.num_cached_tokens,
        )
        self.lora_states.request_finished(req_state.request_id, req_state.lora_name)

20 changes: 20 additions & 0 deletions vllm/v1/metrics/loggers.py
@@ -870,6 +870,19 @@ def __init__(
            histogram_decode_time_request, engine_indexes, model_name
        )

        histogram_prefill_kv_computed_request = self._histogram_cls(
            name="vllm:request_prefill_kv_computed_tokens",
            documentation=(
                "Histogram of new KV tokens computed during prefill "
                "(excluding cached tokens)."
            ),
            buckets=build_1_2_5_buckets(max_model_len),
            labelnames=labelnames,
        )
        self.histogram_prefill_kv_computed_request = make_per_engine(
            histogram_prefill_kv_computed_request, engine_indexes, model_name
        )

        #
        # KV Cache residency metrics
        #
@@ -1115,6 +1128,13 @@ def record(
            self.histogram_decode_time_request[engine_idx].observe(
                finished_request.decode_time
            )
            # Calculate prefill KV compute (excludes cached tokens)
            prefill_kv_computed = finished_request.num_prompt_tokens - max(
                finished_request.num_cached_tokens, 0
            )
            self.histogram_prefill_kv_computed_request[engine_idx].observe(
                prefill_kv_computed
            )
            self.histogram_num_prompt_tokens_request[engine_idx].observe(
                finished_request.num_prompt_tokens
            )
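For readers scanning the diff, the value observed by the new histogram is simply `num_prompt_tokens - max(num_cached_tokens, 0)`. The sketch below is not part of the PR: it uses `prometheus_client` directly rather than vLLM's `_histogram_cls`/`make_per_engine` wiring, and the bucket boundaries are illustrative stand-ins for `build_1_2_5_buckets(max_model_len)`.

```python
# Standalone sketch of what vllm:request_prefill_kv_computed_tokens ends up recording.
# Not the vLLM implementation; buckets here are illustrative only.
from prometheus_client import Histogram

prefill_kv_computed_tokens = Histogram(
    "vllm:request_prefill_kv_computed_tokens",
    "New KV tokens computed during prefill (excluding cached tokens).",
    buckets=[1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000],
)


def observe_finished_request(num_prompt_tokens: int, num_cached_tokens: int) -> None:
    # Same formula as record() above: negative cached-token counts are clamped to zero.
    prefill_kv_computed_tokens.observe(num_prompt_tokens - max(num_cached_tokens, 0))


# A request with 10,000 prompt tokens and 1,200 prefix-cache hits contributes 8,800.
observe_finished_request(num_prompt_tokens=10000, num_cached_tokens=1200)
```

The clamp mirrors the negative-value edge case exercised in test_prefill_kv_computed_edge_cases above.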
3 changes: 3 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -224,6 +224,7 @@ class FinishedRequestStats:
    decode_time: float = 0.0
    mean_time_per_output_token: float = 0.0
    is_corrupted: bool = False
    num_cached_tokens: int = 0


class IterationStats:
@@ -330,6 +331,7 @@ def update_from_finished_request(
        num_prompt_tokens: int,
        max_tokens_param: int | None,
        req_stats: RequestStateStats,
        num_cached_tokens: int = 0,
    ):
        e2e_latency = self._time_since(req_stats.arrival_time)

@@ -367,6 +369,7 @@
            decode_time=decode_time,
            mean_time_per_output_token=mean_time_per_output_token,
            is_corrupted=req_stats.is_corrupted,
            num_cached_tokens=num_cached_tokens,
        )
        self.finished_requests.append(finished_req)
