From 45f732ff0a7120726d2fb9431f626acf331c650b Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 27 Oct 2025 11:17:09 +0000 Subject: [PATCH 1/6] propagate num_external_tokens to logger Signed-off-by: NickLucche --- vllm/v1/core/sched/scheduler.py | 2 ++ vllm/v1/engine/__init__.py | 2 ++ vllm/v1/metrics/loggers.py | 2 +- vllm/v1/metrics/stats.py | 5 +++++ vllm/v1/request.py | 5 ++++- 5 files changed, 14 insertions(+), 2 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 46dc1071b839..e2b4732465d7 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -538,6 +538,7 @@ def schedule(self) -> SchedulerOutput: self._update_connector_prefix_cache_stats( request, num_external_computed_tokens ) + request.num_external_computed_tokens += num_external_computed_tokens # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. @@ -1033,6 +1034,7 @@ def update_from_output( trace_headers=request.trace_headers, num_cached_tokens=request.num_cached_tokens, num_nans_in_logits=request.num_nans_in_logits, + num_external_computed_tokens=request.num_external_computed_tokens, ) ) else: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 058a4bcaecb5..c2a68f708a28 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -121,6 +121,8 @@ class EngineCoreOutput( trace_headers: Mapping[str, str] | None = None # The number of tokens with prefix cache hits. num_cached_tokens: int = 0 + # The number of tokens that have been computed remotely. + num_external_computed_tokens: int = 0 # The number of NaNs in logits. # A value greater than 0 indicates that the output is corrupted. diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 1a175e9e110b..8f75d6db17b0 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -121,7 +121,7 @@ def _reset(self, now): def _track_iteration_stats(self, iteration_stats: IterationStats): # Save tracked stats for token counters. - self.num_prompt_tokens += iteration_stats.num_prompt_tokens + self.num_prompt_tokens += iteration_stats.num_local_prompt_tokens self.num_generation_tokens += iteration_stats.num_generation_tokens self.num_corrupted_reqs += iteration_stats.num_corrupted_reqs diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 4e9db98db0bc..c27a9fd51ba0 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -220,6 +220,8 @@ def __init__(self): self.num_generation_tokens = 0 self.num_prompt_tokens = 0 self.num_preempted_reqs = 0 + # Num of prompt tokens that have been computed locally. + self.num_local_prompt_tokens = 0 self.finished_requests: list[FinishedRequestStats] = [] self.max_num_generation_tokens_iter: list[int] = [] self.n_params_iter: list[int] = [] @@ -250,6 +252,9 @@ def update_from_output( self.num_generation_tokens += num_new_generation_tokens if is_prefilling: self.num_prompt_tokens += prompt_len + self.num_local_prompt_tokens += ( + prompt_len - output.num_external_computed_tokens + ) first_token_latency = self._time_since(req_stats.arrival_time) self.time_to_first_tokens_iter.append(first_token_latency) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 7a5f1183ed48..d06eb5cdee2b 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -118,9 +118,12 @@ def __init__( # indicates that the output is corrupted self.num_nans_in_logits = 0 - # The number of requests being preempted by the scheduler + # The number of requests being preempted by the scheduler. self.num_preemptions = 0 + # The number of tokens that have been computed remotely. + self.num_external_computed_tokens = 0 + self.block_hashes: list[BlockHash] = [] self.get_hash_new_full_blocks: Callable[[], list[BlockHash]] | None = None if block_hasher is not None: From 96f91ba0d7a0eed760f1d867f15964112c480150 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 27 Oct 2025 14:36:09 +0000 Subject: [PATCH 2/6] assign remote token count earlier Signed-off-by: NickLucche --- vllm/v1/core/sched/scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e2b4732465d7..7927ddebcef8 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -430,6 +430,11 @@ def schedule(self) -> SchedulerOutput: self.waiting.pop_request() skipped_waiting_requests.prepend_request(request) continue + # Keep track of number of tokens to load from remote + # for the request st we can compute actual throughput + request.num_external_computed_tokens = ( + num_external_computed_tokens + ) num_external_computed_tokens = ext_tokens @@ -538,7 +543,6 @@ def schedule(self) -> SchedulerOutput: self._update_connector_prefix_cache_stats( request, num_external_computed_tokens ) - request.num_external_computed_tokens += num_external_computed_tokens # Request was already popped from self.waiting # unless it was re-added above due to new_blocks being None. From d4fc2181b25eaa0619a56793f84fe2492a73f383 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 27 Oct 2025 14:53:17 +0000 Subject: [PATCH 3/6] reset num external tokens when having to recompute prefill Signed-off-by: NickLucche --- vllm/v1/core/sched/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7927ddebcef8..c3931f488e60 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1506,6 +1506,8 @@ def _update_requests_with_invalid_blocks( request.num_computed_tokens - request.num_cached_tokens ) request.num_computed_tokens = request.num_cached_tokens + # Prefill is to be recomputed locally. + request.num_external_computed_tokens = 0 affected_req_ids.add(request.request_id) From d038f1ce9f52c9617ab770f78bb09770c1eda420 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 29 Oct 2025 15:37:46 +0000 Subject: [PATCH 4/6] fix num external tokens on failures Co-authored-by: David Ben-David davidb@pliops.com Signed-off-by: NickLucche --- vllm/v1/core/sched/scheduler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c3931f488e60..26c24e13432b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1491,9 +1491,12 @@ def _update_requests_with_invalid_blocks( marked_invalid_block = True # Truncate the computed tokens at the first failed block request.num_computed_tokens = idx * self.block_size - total_affected_tokens += ( + num_affected_tokens = ( req_num_computed_tokens - request.num_computed_tokens ) + total_affected_tokens += num_affected_tokens + # Prefill is to be recomputed locally, track its performance. + request.num_external_computed_tokens -= num_affected_tokens if is_affected: if not marked_invalid_block: @@ -1506,8 +1509,6 @@ def _update_requests_with_invalid_blocks( request.num_computed_tokens - request.num_cached_tokens ) request.num_computed_tokens = request.num_cached_tokens - # Prefill is to be recomputed locally. - request.num_external_computed_tokens = 0 affected_req_ids.add(request.request_id) From a1b82fb9a0bc3e8734b6c42558b5a7d2fed6dbe9 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 29 Oct 2025 16:03:55 +0000 Subject: [PATCH 5/6] comment Signed-off-by: NickLucche --- vllm/v1/request.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index d06eb5cdee2b..2ce35e492970 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -118,7 +118,7 @@ def __init__( # indicates that the output is corrupted self.num_nans_in_logits = 0 - # The number of requests being preempted by the scheduler. + # The number of times the request was preempted by the scheduler. self.num_preemptions = 0 # The number of tokens that have been computed remotely. From 4d2bacbe295b075d22dc1431598ab7a1af4c142a Mon Sep 17 00:00:00 2001 From: NickLucche Date: Tue, 11 Nov 2025 18:00:57 +0000 Subject: [PATCH 6/6] fix ext_tokens Signed-off-by: NickLucche --- vllm/v1/core/sched/scheduler.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 26c24e13432b..310aef49578e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -432,10 +432,7 @@ def schedule(self) -> SchedulerOutput: continue # Keep track of number of tokens to load from remote # for the request st we can compute actual throughput - request.num_external_computed_tokens = ( - num_external_computed_tokens - ) - + request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens # Total computed tokens (local + external).