diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py
index d49874adc998..fd47bbba8c70 100644
--- a/tests/v1/metrics/test_stats.py
+++ b/tests/v1/metrics/test_stats.py
@@ -1,7 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from vllm.v1.engine import FinishReason
-from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats
+from vllm.v1.engine import EngineCoreOutput, FinishReason
+from vllm.v1.metrics.stats import (
+    CompletedTiming,
+    IterationStats,
+    LoRARequestStates,
+    PromptTokenStats,
+    RequestStateStats,
+    ScheduledTiming,
+)
 
 
 def test_iteration_stats_repr():
@@ -209,3 +216,101 @@ def test_prompt_token_stats_full_external_transfer_recompute():
     assert stats.local_cache_hit == 0
     assert stats.external_kv_transfer == 1000
     assert stats.recomputed_tokens == 1
+
+
+def test_no_tokens_generated_during_prefill_no_timestamp_update():
+    """When prefill produces no tokens (e.g., KV load failure), don't update timestamps.
+
+    This can happen when is_prefilling=True but the request fails to generate
+    tokens due to KV-cache load failures or other errors.
+    """
+    stats = IterationStats()
+    req_stats = RequestStateStats(arrival_time=100.0)
+    lora_states = LoRARequestStates(log_stats=False)
+
+    # call update_from_output with is_prefilling=True but no tokens
+    stats.update_from_output(
+        output=EngineCoreOutput(request_id="test-req", new_token_ids=[]),
+        engine_core_timestamp=200.0,
+        is_prefilling=True,
+        prompt_len=100,
+        req_stats=req_stats,
+        lora_states=lora_states,
+        lora_name=None,
+    )
+
+    # regression assertions
+    assert req_stats.first_token_ts == 0.0, (
+        "first_token_ts should not be set when no tokens are generated"
+    )
+    assert req_stats.last_token_ts == 0.0, (
+        "last_token_ts should not be set when no tokens are generated"
+    )
+    assert len(stats.time_to_first_tokens_iter) == 0, (
+        "No TTFT should be recorded when no tokens are generated"
+    )
+
+
+def test_no_tokens_generated_during_decode_no_itl():
+    """When decode produces no tokens, don't calculate ITL or update timestamps."""
+    stats = IterationStats()
+    req_stats = RequestStateStats(arrival_time=100.0)
+    req_stats.first_token_ts = 150.0  # Already got first token
+    req_stats.last_token_ts = 150.0
+    lora_states = LoRARequestStates(log_stats=False)
+
+    # call update_from_output with is_prefilling=False (decode) but no tokens
+    stats.update_from_output(
+        output=EngineCoreOutput(request_id="test-req", new_token_ids=[]),
+        engine_core_timestamp=200.0,
+        is_prefilling=False,
+        prompt_len=100,
+        req_stats=req_stats,
+        lora_states=lora_states,
+        lora_name=None,
+    )
+
+    # regression assertions
+    assert req_stats.last_token_ts == 150.0, (
+        "last_token_ts should not change when no tokens are generated"
+    )
+    assert len(stats.inter_token_latencies_iter) == 0, (
+        "No ITL should be recorded when no tokens are generated"
+    )
+
+
+def test_timing_reflects_request_progress():
+    """Timing type reflects how far request progressed before finishing."""
+    # completed request -> CompletedTiming
+    req = RequestStateStats(arrival_time=0.0)
+    req.queued_ts = 0.1
+    req.scheduled_ts = 0.2
+    req.first_token_ts = 0.5
+    req.last_token_ts = 2.0
+    req.num_generation_tokens = 10
+
+    stats = IterationStats()
+    stats.update_from_finished_request(
+        FinishReason.STOP, num_prompt_tokens=100, max_tokens_param=50, req_stats=req
+    )
+    assert isinstance(stats.finished_requests[0].timing, CompletedTiming)
+
+    # scheduled but aborted -> ScheduledTiming
+    req2 = RequestStateStats(arrival_time=0.0)
+    req2.queued_ts = 0.1
+    req2.scheduled_ts = 0.2
+
+    stats2 = IterationStats()
+    stats2.update_from_finished_request(
+        FinishReason.ABORT, num_prompt_tokens=100, max_tokens_param=50, req_stats=req2
+    )
+    assert isinstance(stats2.finished_requests[0].timing, ScheduledTiming)
+
+    # rejected before scheduling -> None
+    req3 = RequestStateStats(arrival_time=0.0)
+
+    stats3 = IterationStats()
+    stats3.update_from_finished_request(
+        FinishReason.ABORT, num_prompt_tokens=100, max_tokens_param=50, req_stats=req3
+    )
+    assert stats3.finished_requests[0].timing is None
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 49b97e8f37a0..dd917fab5dec 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -23,6 +23,7 @@
 from vllm.v1.metrics.prometheus import unregister_vllm_metrics
 from vllm.v1.metrics.stats import (
     CachingMetrics,
+    CompletedTiming,
     IterationStats,
     MultiModalCacheStats,
     PromptTokenStats,
@@ -1140,39 +1141,45 @@ def record(
             self.histogram_e2e_time_request[engine_idx].observe(
                 finished_request.e2e_latency
             )
-            self.histogram_queue_time_request[engine_idx].observe(
-                finished_request.queued_time
-            )
-            self.histogram_prefill_time_request[engine_idx].observe(
-                finished_request.prefill_time
-            )
-            self.histogram_inference_time_request[engine_idx].observe(
-                finished_request.inference_time
-            )
-            self.histogram_decode_time_request[engine_idx].observe(
-                finished_request.decode_time
-            )
-            # Calculate prefill KV compute (excludes cached tokens)
-            prefill_kv_computed = finished_request.num_prompt_tokens - max(
-                finished_request.num_cached_tokens, 0
-            )
-            self.histogram_prefill_kv_computed_request[engine_idx].observe(
-                prefill_kv_computed
-            )
             self.histogram_num_prompt_tokens_request[engine_idx].observe(
                 finished_request.num_prompt_tokens
             )
             self.histogram_num_generation_tokens_request[engine_idx].observe(
                 finished_request.num_generation_tokens
             )
-            self.histogram_request_time_per_output_token[engine_idx].observe(
-                finished_request.mean_time_per_output_token
-            )
             if finished_request.max_tokens_param:
                 self.histogram_max_tokens_request[engine_idx].observe(
                     finished_request.max_tokens_param
                 )
+            timing = finished_request.timing
+            if timing is not None:
+                self.histogram_queue_time_request[engine_idx].observe(
+                    timing.queued_time
+                )
+
+                if isinstance(timing, CompletedTiming):
+                    self.histogram_prefill_time_request[engine_idx].observe(
+                        timing.prefill_time
+                    )
+                    self.histogram_inference_time_request[engine_idx].observe(
+                        timing.inference_time
+                    )
+                    self.histogram_decode_time_request[engine_idx].observe(
+                        timing.decode_time
+                    )
+                    # prefill KV compute (excludes cached tokens)
+                    prefill_kv_computed = finished_request.num_prompt_tokens - max(
+                        finished_request.num_cached_tokens, 0
+                    )
+                    self.histogram_prefill_kv_computed_request[engine_idx].observe(
+                        prefill_kv_computed
+                    )
+                    if timing.mean_time_per_output_token is not None:
+                        self.histogram_request_time_per_output_token[engine_idx].observe(
+                            timing.mean_time_per_output_token
+                        )
+
 
     def record_sleep_state(self, sleep: int = 0, level: int = 0):
         awake = 1
         discard_all = 0
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 1b7ee105ebf2..098b291be771 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -213,20 +213,49 @@ class RequestStateStats:
     is_corrupted: bool = False
 
 
+@dataclass
+class ScheduledTiming:
+    """Timing for a request that was scheduled but didn't generate tokens.
+
+    This occurs when a request is aborted during prefill or fails due to
+    errors like KV load failures.
+    """
+
+    queued_time: float
+
+
+@dataclass
+class CompletedTiming:
+    """Timing for a request that generated at least one token."""
+
+    queued_time: float
+    prefill_time: float
+    decode_time: float
+    inference_time: float
+    # None if request generated only a single token
+    mean_time_per_output_token: float | None = None
+
+
+RequestTiming = ScheduledTiming | CompletedTiming | None
+
+
 @dataclass
 class FinishedRequestStats:
-    """Stats associated with a finished request."""
+    """Stats associated with a finished request.
+
+    The timing field uses its type to encode how far the
+    request progressed before finishing:
+    - None: rejected before scheduling (e.g., client abort before scheduling)
+    - ScheduledTiming: scheduled but no tokens generated (abort/error during prefill)
+    - CompletedTiming: generated at least one token (normal completion)
+    """
 
     finish_reason: "FinishReason"
     e2e_latency: float = 0.0
     num_prompt_tokens: int = 0
     num_generation_tokens: int = 0
     max_tokens_param: int | None = None
-    queued_time: float = 0.0
-    prefill_time: float = 0.0
-    inference_time: float = 0.0
-    decode_time: float = 0.0
-    mean_time_per_output_token: float = 0.0
+    timing: RequestTiming = None
     is_corrupted: bool = False
     num_cached_tokens: int = 0
@@ -342,6 +371,10 @@ def update_from_output(
                 prompt_len=prompt_len,
             )
 
+        # Only record first token latency when a token was actually generated.
+        # is_prefilling can be True even when no tokens are produced (e.g.,
+        # KV-cache load failures, aborts during prefill).
+        if is_prefilling and num_new_generation_tokens > 0:
             first_token_latency = self._time_since(req_stats.arrival_time)
             self.time_to_first_tokens_iter.append(first_token_latency)
             req_stats.first_token_latency = first_token_latency
@@ -368,14 +401,16 @@ def update_from_output(
                 lora_name,
             )
 
-        # Process the batch-level "new tokens" engine core event
-        if is_prefilling:
-            req_stats.first_token_ts = engine_core_timestamp
-        else:
-            itl = engine_core_timestamp - req_stats.last_token_ts
-            self.inter_token_latencies_iter.append(itl)
+        # Process the batch-level "new tokens" engine core event.
+        # Only update timestamps when tokens were actually generated.
+        if num_new_generation_tokens > 0:
+            if is_prefilling:
+                req_stats.first_token_ts = engine_core_timestamp
+            else:
+                itl = engine_core_timestamp - req_stats.last_token_ts
+                self.inter_token_latencies_iter.append(itl)
 
-        req_stats.last_token_ts = engine_core_timestamp
+            req_stats.last_token_ts = engine_core_timestamp
 
     def update_from_events(
         self,
@@ -411,27 +446,49 @@ def update_from_finished_request(
     ):
         e2e_latency = self._time_since(req_stats.arrival_time)
 
-        # Queued interval is from first QUEUED event to first SCHEDULED
-        queued_time = req_stats.scheduled_ts - req_stats.queued_ts
-
-        # Prefill interval is from first SCHEDULED to first NEW_TOKEN
-        # Any preemptions during prefill is included in the interval
-        prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts
-
-        # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
-        # Any preemptions during decode are included
-        decode_time = req_stats.last_token_ts - req_stats.first_token_ts
-
-        # Inference interval is from first SCHEDULED to last NEW_TOKEN
-        # Any preemptions during prefill or decode are included
-        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
-
-        # Do not count the token generated by the prefill phase
-        mean_time_per_output_token = (
-            decode_time / (req_stats.num_generation_tokens - 1)
-            if req_stats.num_generation_tokens - 1 > 0
-            else 0
-        )
+        # build timing based on how far the request progressed
+        timing: RequestTiming = None
+        was_queued = req_stats.queued_ts > 0
+        was_scheduled = req_stats.scheduled_ts > 0
+        got_first_token = req_stats.first_token_ts > 0
+        got_last_token = req_stats.last_token_ts > 0
+
+        if was_queued and was_scheduled:
+            # queued: from first QUEUED event to first SCHEDULED
+            queued_time = req_stats.scheduled_ts - req_stats.queued_ts
+
+            if got_first_token and got_last_token:
+                # request generated tokens - full timing available
+
+                # prefill: from first SCHEDULED to first NEW_TOKEN
+                # (any preemptions during prefill are included)
+                prefill_time = req_stats.first_token_ts - req_stats.scheduled_ts
+
+                # decode: from first NEW_TOKEN to last NEW_TOKEN
+                # (any preemptions during decode are included)
+                decode_time = req_stats.last_token_ts - req_stats.first_token_ts
+
+                # inference: from first SCHEDULED to last NEW_TOKEN
+                # (any preemptions during prefill or decode are included)
+                inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
+
+                # don't count the token generated by the prefill phase
+                mean_tpot = (
+                    decode_time / (req_stats.num_generation_tokens - 1)
+                    if req_stats.num_generation_tokens > 1
+                    else None
+                )
+                timing = CompletedTiming(
+                    queued_time=queued_time,
+                    prefill_time=prefill_time,
+                    decode_time=decode_time,
+                    inference_time=inference_time,
+                    mean_time_per_output_token=mean_tpot,
+                )
+            else:
+                # scheduled but no tokens (abort during prefill, KV error, etc.)
+                timing = ScheduledTiming(queued_time=queued_time)
+        # else: timing stays None (rejected before scheduling)
 
         finished_req = FinishedRequestStats(
@@ -439,11 +496,7 @@ def update_from_finished_request(
             finish_reason=finish_reason,
             e2e_latency=e2e_latency,
             num_prompt_tokens=num_prompt_tokens,
             num_generation_tokens=req_stats.num_generation_tokens,
             max_tokens_param=max_tokens_param,
-            queued_time=queued_time,
-            prefill_time=prefill_time,
-            inference_time=inference_time,
-            decode_time=decode_time,
-            mean_time_per_output_token=mean_time_per_output_token,
+            timing=timing,
             is_corrupted=req_stats.is_corrupted,
             num_cached_tokens=num_cached_tokens,
         )
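
For reviewers, a minimal sketch of how a consumer dispatches on the new timing union, mirroring the isinstance check added in loggers.py above. This is illustrative only: it assumes the patch is applied to an importable vLLM build, prints values where the Prometheus histograms would observe them, and the record_timing helper name is made up for the example.

from vllm.v1.engine import FinishReason
from vllm.v1.metrics.stats import (
    CompletedTiming,
    FinishedRequestStats,
    ScheduledTiming,
)


def record_timing(finished_request: FinishedRequestStats) -> None:
    timing = finished_request.timing
    if timing is None:
        # Rejected before scheduling: no intervals to report.
        return
    # Both ScheduledTiming and CompletedTiming carry queued_time.
    print(f"queued_time={timing.queued_time:.3f}s")
    if isinstance(timing, CompletedTiming):
        # Only requests that produced at least one token have these intervals.
        print(f"prefill_time={timing.prefill_time:.3f}s")
        print(f"decode_time={timing.decode_time:.3f}s")
        print(f"inference_time={timing.inference_time:.3f}s")
        if timing.mean_time_per_output_token is not None:
            print(f"mean_tpot={timing.mean_time_per_output_token:.3f}s")


# A request that generated tokens carries a CompletedTiming...
record_timing(
    FinishedRequestStats(
        finish_reason=FinishReason.STOP,
        timing=CompletedTiming(
            queued_time=0.1,
            prefill_time=0.3,
            decode_time=1.5,
            inference_time=1.8,
            mean_time_per_output_token=0.15,
        ),
    )
)
# ...while one aborted after scheduling only carries queued_time.
record_timing(
    FinishedRequestStats(
        finish_reason=FinishReason.ABORT,
        timing=ScheduledTiming(queued_time=0.1),
    )
)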