diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 8455746cd56d..eef8fcf985f7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -448,7 +448,9 @@ def schedule(self) -> SchedulerOutput: self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue - + # Keep track of number of tokens to load from remote + # for the request so that we can compute actual throughput + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens # Total computed tokens (local + external). @@ -1081,6 +1083,7 @@ def update_from_output( trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, num_nans_in_logits=request.num_nans_in_logits, + num_external_computed_tokens=request.num_external_computed_tokens, ) ) else: @@ -1533,9 +1536,12 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += ( + num_affected_tokens = ( req_num_computed_tokens - request.num_computed_tokens ) + total_affected_tokens += num_affected_tokens + # Prefill is to be recomputed locally, track its performance. + request.num_external_computed_tokens -= num_affected_tokens if is_affected: if not marked_invalid_block: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 058a4bcaecb5..c2a68f708a28 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -121,6 +121,8 @@ class EngineCoreOutput( trace_headers: Mapping[str, str] | None = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 + # The number of tokens that have been computed remotely. + num_external_computed_tokens: int = 0 # The number of NaNs in logits. # A value greater than 0 indicates that the output is corrupted. 
diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 1a175e9e110b..8f75d6db17b0 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -121,7 +121,7 @@ def _reset(self, now): def _track_iteration_stats(self, iteration_stats: IterationStats): # Save tracked stats for token counters. - self.num_prompt_tokens += iteration_stats.num_prompt_tokens + self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens self.num_generation_tokens += iteration_stats.num_generation_tokens self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 4e9db98db0bc..c27a9fd51ba0 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -220,6 +220,8 @@ def __init__(self): self.num_generation_tokens = 0 self.num_prompt_tokens = 0 self.num_preempted_reqs = 0 + # Num of prompt tokens that have been computed locally. + self.num_local_prompt_tokens = 0 self.finished_requests: list[FinishedRequestStats] = [] self.max_num_generation_tokens_iter: list[int] = [] self.n_params_iter: list[int] = [] @@ -250,6 +252,9 @@ def update_from_output( self.num_generation_tokens += num_new_generation_tokens if is_prefilling: self.num_prompt_tokens += prompt_len + self.num_local_prompt_tokens += ( + prompt_len - output.num_external_computed_tokens + ) first_token_latency = self._time_since(req_stats.arrival_time) self.time_to_first_tokens_iter.append(first_token_latency) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7a5f1183ed48..2ce35e492970 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -118,9 +118,12 @@ def __init__( # indicates that the output is corrupted self.num_nans_in_logits = 0 - # The number of requests being preempted by the scheduler + # The number of times the request was preempted by the scheduler. self.num_preemptions = 0 + # The number of tokens that have been computed remotely. 
+ self.num_external_computed_tokens = 0 + self.block_hashes: list[BlockHash] = [] self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None if block_hasher is not None: