diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py
index e821e47172ce..c05e979a9ffc 100644
--- a/tests/v1/core/test_async_scheduler.py
+++ b/tests/v1/core/test_async_scheduler.py
@@ -153,7 +153,6 @@ def test_prefix_caching_for_prefill_dedup():
         same_prompt=True,
         block_size=BLOCK_SIZE,
     )
-    requests_copy = requests.copy()
 
     # Two requests with the same prompt.
     req0 = requests.pop(0)
@@ -167,26 +166,31 @@ def test_prefix_caching_for_prefill_dedup():
     # Make sure prefix caching de-duplicates the prompts in the same step,
     # so all the blocks except the last are shared between the two requests.
     assert len(sched_output.num_scheduled_tokens) == 2
-    num_blocks = num_prompt_tokens // BLOCK_SIZE
-    assert req0.num_cached_tokens == 0
-    assert req1.num_cached_tokens >= num_blocks * BLOCK_SIZE
+    assert sched_output.num_scheduled_tokens[req0.request_id] == num_prompt_tokens
+    assert (
+        sched_output.num_scheduled_tokens[req1.request_id]
+        == num_prompt_tokens % BLOCK_SIZE
+    )
     sched_outputs.append(scheduler.schedule())
 
     while sched_outputs:
+        added_req = None
         if requests:
-            scheduler.add_request(requests.pop(0))
+            added_req = requests.pop(0)
+            scheduler.add_request(added_req)
         sched_output = sched_outputs.popleft()
         model_runner_output = _make_model_runner_output(sched_output)
         scheduler.update_from_output(sched_output, model_runner_output)
         sched_output = scheduler.schedule()
         if sched_output.num_scheduled_tokens:
             sched_outputs.append(sched_output)
+            if added_req:
+                assert (
+                    sched_output.num_scheduled_tokens[added_req.request_id]
+                    == num_prompt_tokens % BLOCK_SIZE
+                )
 
-    # Other requests scheduled after the two requests should also get
-    # prefix cache hit.
     assert scheduler.get_num_unfinished_requests() == 0
-    for req in requests_copy[1:]:
-        assert req.num_cached_tokens >= num_blocks * BLOCK_SIZE
 
 
 def test_prefix_caching_for_multi_turn():
@@ -243,12 +247,15 @@ def test_prefix_caching_for_multi_turn():
     # Schedule the next-turn requests.
     for req in next_turn_requests:
         scheduler.add_request(req)
-    sched_outputs.append(scheduler.schedule())
+    sched_output = scheduler.schedule()
+    sched_outputs.append(sched_output)
 
     # Make sure the next-turn requests get prefix cache hit by the previous
     # requests.
     for req in next_turn_requests:
-        assert req.num_cached_tokens == req.num_prompt_tokens // BLOCK_SIZE * BLOCK_SIZE
+        assert sched_output.num_scheduled_tokens[req.request_id] == (
+            req.num_prompt_tokens % BLOCK_SIZE
+        )
 
 
 def test_abort_request_when_structured_output_fsm_cannot_advance():
diff --git a/tests/v1/engine/test_output_processor.py b/tests/v1/engine/test_output_processor.py
index ece48e009d27..1919349790fa 100644
--- a/tests/v1/engine/test_output_processor.py
+++ b/tests/v1/engine/test_output_processor.py
@@ -84,6 +84,7 @@ def test_incremental_detokenization(
 
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
+        prompts_list=dummy_test_vectors.prompt_tokens,
         request_ids=[req.request_id for req in requests],
     )
 
@@ -506,6 +507,7 @@ def test_logprobs_processor(
 
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
+        prompts_list=dummy_test_vectors.prompt_tokens,
         generated_logprobs_raw=None
         if num_sample_logprobs is None
         else dummy_test_vectors.generation_logprobs,
@@ -691,6 +693,7 @@ def test_stop_token(
 
     engine_core = MockEngineCore(
         tokens_list=[generation_tokens],
+        prompts_list=dummy_test_vectors.prompt_tokens,
         generated_logprobs_raw=[generation_logprobs] if do_logprobs else None,
         prompt_logprobs_raw=None,
         eos_token_id=sampling_params.eos_token_id,
@@ -794,6 +797,7 @@ def test_stop_string(
 
     engine_core = MockEngineCore(
         tokens_list=dummy_test_vectors.generation_tokens,
+        prompts_list=dummy_test_vectors.prompt_tokens,
         generated_logprobs_raw=dummy_test_vectors.generation_logprobs
         if num_sample_logprobs
         else None,
@@ -917,6 +921,7 @@ def test_iteration_stats(dummy_test_vectors):
 
     engine_core = MockEngineCore(
         dummy_test_vectors.generation_tokens,
+        dummy_test_vectors.prompt_tokens,
         request_ids=[req.request_id for req in requests],
     )
 
@@ -927,7 +932,7 @@ def test_iteration_stats(dummy_test_vectors):
     inactive_request = requests[num_active]
 
     # First iteration has 2 prefills.
-    outputs = engine_core.get_outputs()[:num_active]
+    outputs = engine_core.get_outputs(num_active)
     iteration_stats = IterationStats()
     output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
     total_prompt_tokens = sum(
@@ -941,7 +946,7 @@ def test_iteration_stats(dummy_test_vectors):
     assert iteration_stats.num_generation_tokens == num_active
 
     # Just decodes in this step.
-    outputs = engine_core.get_outputs()[:num_active]
+    outputs = engine_core.get_outputs(num_active)
     iteration_stats = IterationStats()
     output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
 
@@ -951,7 +956,7 @@ def test_iteration_stats(dummy_test_vectors):
     # Add a new request - prefill and 2 decodes in this step.
     output_processor.add_request(inactive_request, None)
     num_active += 1
-    outputs = engine_core.get_outputs()[:num_active]
+    outputs = engine_core.get_outputs(num_active)
     iteration_stats = IterationStats()
     output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
     total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
@@ -960,7 +965,7 @@ def test_iteration_stats(dummy_test_vectors):
     assert iteration_stats.num_generation_tokens == num_active
 
     # Just decodes in this step.
-    outputs = engine_core.get_outputs()[:num_active]
+    outputs = engine_core.get_outputs(num_active)
     iteration_stats = IterationStats()
     output_processor.process_outputs(outputs, engine_core_timestamp, iteration_stats)
 
@@ -1003,6 +1008,7 @@ def test_lora_request_tracking(log_stats: bool, dummy_test_vectors):
 
     engine_core = MockEngineCore(
         dummy_test_vectors.generation_tokens,
+        dummy_test_vectors.prompt_tokens,
         request_ids=[req.request_id for req in requests],
     )
 
diff --git a/tests/v1/engine/utils.py b/tests/v1/engine/utils.py
index de953a58843e..013e73bd8e48 100644
--- a/tests/v1/engine/utils.py
+++ b/tests/v1/engine/utils.py
@@ -11,6 +11,7 @@
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.v1.engine import EngineCoreOutput, FinishReason
+from vllm.v1.metrics.stats import PrefillStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 
 GeneralTokenizerType: TypeAlias = PreTrainedTokenizer | PreTrainedTokenizerFast
@@ -330,6 +331,7 @@ class MockEngineCore:
     def __init__(
         self,
         tokens_list: list[list[int]],
+        prompts_list: list[list[int]],
         # For each request, for each sampled token offset,
         # a tuple of
         # (list of topk token ids, list of sample logprob vals, rank)
@@ -346,12 +348,13 @@ def __init__(
     ) -> None:
         self.num_requests = len(tokens_list)
         self.tokens_list = tokens_list
-        self.current_idx = 0
+        self.prompts_list = prompts_list
         self.generated_logprobs_raw = generated_logprobs_raw
         self.do_logprobs = generated_logprobs_raw is not None
         self.prompt_logprobs_raw = prompt_logprobs_raw
         self.do_prompt_logprobs = prompt_logprobs_raw is not None
         self.request_finished = [False for _ in range(self.num_requests)]
+        self.request_token_idx = [0 for _ in range(self.num_requests)]
         self.eos_token_id = eos_token_id
         self.stop_token_ids = stop_token_ids
         self.request_ids = (
@@ -360,14 +363,18 @@ def __init__(
             else [f"request-{i}" for i in range(self.num_requests)]
         )
 
-    def get_outputs(self) -> list[EngineCoreOutput]:
+    def get_outputs(self, num_active: int = -1) -> list[EngineCoreOutput]:
         do_logprobs = self.do_logprobs
         do_prompt_logprobs = self.do_prompt_logprobs
-        token_idx = self.current_idx
 
         outputs = []
-        for req_idx, token_ids in enumerate(self.tokens_list):
+        for req_idx, (token_ids, prompt_token_ids) in enumerate(
+            zip(self.tokens_list, self.prompts_list)
+        ):
+            if num_active != -1 and req_idx >= num_active:
+                break
             if not self.request_finished[req_idx]:
+                token_idx = self.request_token_idx[req_idx]
                 if do_logprobs:
                     assert self.generated_logprobs_raw is not None
                     (logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
@@ -381,19 +388,32 @@ def get_outputs(self) -> list[EngineCoreOutput]:
                 else:
                     logprobs = None
                 if do_prompt_logprobs:
-                    if self.current_idx == 0:
+                    if token_idx == 0:
                         assert self.prompt_logprobs_raw is not None
                         prompt_logprobs = self.prompt_logprobs_raw[req_idx]
                     else:
                         prompt_logprobs = None
                 else:
                     prompt_logprobs = None
+
+                # Add prefill_stats on first output (prefill) for this request
+                if token_idx == 0:
+                    prefill_stats = PrefillStats()
+                    prefill_stats.set(
+                        num_prompt_tokens=len(prompt_token_ids),
+                        num_local_cached_tokens=0,
+                        num_external_cached_tokens=0,
+                    )
+                else:
+                    prefill_stats = None
+
                 new_token_id = token_ids[token_idx]
                 output = EngineCoreOutput(
                     request_id=self.request_ids[req_idx],
                     new_token_ids=[new_token_id],
                     new_logprobs=logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs,
+                    prefill_stats=prefill_stats,
                 )
                 if token_idx == len(token_ids) - 1:
                     output.finish_reason = FinishReason.LENGTH
@@ -407,5 +427,6 @@ def get_outputs(self) -> list[EngineCoreOutput]:
                     self.request_finished[req_idx] = True
                 outputs.append(output)
 
-        self.current_idx += 1
+            self.request_token_idx[req_idx] += 1
+
         return outputs
diff --git a/tests/v1/metrics/test_stats.py b/tests/v1/metrics/test_stats.py
index 48f6caefdbff..3d9315c1ae6c 100644
--- a/tests/v1/metrics/test_stats.py
+++ b/tests/v1/metrics/test_stats.py
@@ -1,7 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm.v1.engine import FinishReason
-from vllm.v1.metrics.stats import IterationStats, PromptTokenStats, RequestStateStats
+from vllm.v1.metrics.stats import (
+    IterationStats,
+    PrefillStats,
+    PromptTokenStats,
+    RequestStateStats,
+)
 
 
 def test_iteration_stats_repr():
@@ -114,15 +119,18 @@ def test_prompt_token_stats_all_computed():
     stats = PromptTokenStats()
 
     # Case 1: No caching (All tokens computed locally)
-    stats.update_from_output(
-        num_cached_tokens=0,
-        num_external_computed_tokens=0,
-        prompt_len=1000,
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=0,
+        num_external_cached_tokens=0,
     )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 1000
     assert stats.local_cache_hit == 0
     assert stats.external_kv_transfer == 0
+    assert stats.cached_tokens == 0
     assert stats.total == 1000
 
 
@@ -131,15 +139,19 @@ def test_prompt_token_stats_partial_local_cache():
     stats = PromptTokenStats()
 
     # Case 2: Partial local cache
-    stats.update_from_output(
-        num_cached_tokens=300,
-        num_external_computed_tokens=0,
-        prompt_len=1000,
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=300,
+        num_external_cached_tokens=0,
    )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 700
     assert stats.local_cache_hit == 300
     assert stats.external_kv_transfer == 0
+    assert stats.cached_tokens == 300
+    assert stats.total == 1000
 
 
 def test_prompt_token_stats_partial_external_transfer():
@@ -147,15 +159,19 @@ def test_prompt_token_stats_partial_external_transfer():
     stats = PromptTokenStats()
 
     # Case 3: Partial external transfer
-    stats.update_from_output(
-        num_cached_tokens=500,
-        num_external_computed_tokens=500,
-        prompt_len=1000,
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=0,
+        num_external_cached_tokens=500,
     )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 500
     assert stats.local_cache_hit == 0
     assert stats.external_kv_transfer == 500
+    assert stats.cached_tokens == 500
+    assert stats.total == 1000
 
 
 def test_prompt_token_stats_mixed_sources():
@@ -163,47 +179,60 @@ def test_prompt_token_stats_mixed_sources():
     stats = PromptTokenStats()
 
     # Case 4: Mixed sources
-    stats.update_from_output(
-        num_cached_tokens=600,
-        num_external_computed_tokens=200,
-        prompt_len=1000,
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=400,
+        num_external_cached_tokens=200,
     )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 400
     assert stats.local_cache_hit == 400
     assert stats.external_kv_transfer == 200
+    assert stats.cached_tokens == 600
+    assert stats.total == 1000
 
 
 def test_prompt_token_stats_full_local_cache_recompute():
     """Test full local cache triggers last token recomputation.
 
-    When all tokens are cached, the scheduler reduces num_cached_tokens by 1
-    to force the model to recompute the last token.
+    When all tokens are cached, the scheduler forces the model to recompute
+    the last token (num_computed_tokens=1), with the rest from cache.
     """
     stats = PromptTokenStats()
 
-    # Case 5: Full local cache (999 cached after reduction, 1 recomputed)
-    stats.update_from_output(
-        num_cached_tokens=999,
-        num_external_computed_tokens=0,
-        prompt_len=1000,
+    # Case 5: Full local cache (999 cached, 1 recomputed)
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=999,
+        num_external_cached_tokens=0,
     )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 1
     assert stats.local_cache_hit == 999
+    assert stats.external_kv_transfer == 0
+    assert stats.cached_tokens == 999
+    assert stats.total == 1000
 
 
 def test_prompt_token_stats_full_external_transfer_recompute():
     """Test full external transfer triggers last token recomputation."""
     stats = PromptTokenStats()
 
-    # Case 6: Full external transfer (999 cached after reduction, 1 recomputed)
-    stats.update_from_output(
-        num_cached_tokens=999,
-        num_external_computed_tokens=999,
-        prompt_len=1000,
+    # Case 6: Full external transfer (999 from external, 1 recomputed)
+    prefill_stats = PrefillStats()
+    prefill_stats.set(
+        num_prompt_tokens=1000,
+        num_local_cached_tokens=0,
+        num_external_cached_tokens=999,
     )
+    stats.update_from_output(prefill_stats)
 
     assert stats.computed == 1
     assert stats.local_cache_hit == 0
     assert stats.external_kv_transfer == 999
+    assert stats.cached_tokens == 999
+    assert stats.total == 1000
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index f61d54faedc7..c9b7abb28d74 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -629,7 +629,6 @@ def schedule(self) -> SchedulerOutput:
                     step_skipped_waiting.prepend_request(request)
                     continue
 
-                request.num_external_computed_tokens = ext_tokens
                 num_external_computed_tokens = ext_tokens
 
                 connector_prefix_cache_queries = (
@@ -642,6 +641,15 @@ def schedule(self) -> SchedulerOutput:
                     num_new_local_computed_tokens + num_external_computed_tokens
                 )
                 assert num_computed_tokens <= request.num_tokens
+
+                # Track first scheduled prefill, not post-preemption repeat prefills
+                if request.prefill_stats is not None:
+                    assert num_computed_tokens <= request.num_prompt_tokens
+                    request.prefill_stats.set(
+                        num_prompt_tokens=request.num_prompt_tokens,
+                        num_local_cached_tokens=num_new_local_computed_tokens,
+                        num_external_cached_tokens=num_external_computed_tokens,
+                    )
             else:
                 # KVTransfer: WAITING reqs have num_computed_tokens > 0
                 # after async KV recvs are completed.
@@ -826,9 +834,6 @@ def schedule(self) -> SchedulerOutput:
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
-            # Count the number of prefix cached tokens.
-            if request.num_cached_tokens < 0:
-                request.num_cached_tokens = num_computed_tokens
 
             # Encoder-related.
             if encoder_inputs_to_schedule:
                 scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule
@@ -1466,10 +1471,9 @@ def update_from_output(
                     pooling_output=pooler_output,
                     stop_reason=request.stop_reason,
                     events=request.take_events(),
+                    prefill_stats=request.take_prefill_stats(),
                     kv_transfer_params=kv_transfer_params,
                     trace_headers=request.trace_headers,
-                    num_cached_tokens=request.num_cached_tokens,
-                    num_external_computed_tokens=request.num_external_computed_tokens,
                     routed_experts=routed_experts,
                     num_nans_in_logits=request.num_nans_in_logits,
                 )
@@ -1496,7 +1500,6 @@ def update_from_output(
                         finish_reason=request.get_finished_reason(),
                         events=request.take_events(),
                         trace_headers=request.trace_headers,
-                        num_cached_tokens=request.num_cached_tokens,
                     )
                 )
 
@@ -2070,10 +2073,6 @@ def _update_waiting_for_remote_kv(self, request: Request) -> None:
         if request.num_computed_tokens == request.num_tokens:
             request.num_computed_tokens = request.num_tokens - 1
 
-        # Count the number of prefix cached tokens.
-        if request.num_cached_tokens < 0:
-            request.num_cached_tokens = request.num_computed_tokens
-
         self.finished_recving_kv_req_ids.remove(request.request_id)
 
     def _try_promote_blocked_waiting_request(self, request: Request) -> bool:
@@ -2220,7 +2219,7 @@ def _update_requests_with_invalid_blocks(
                     req_num_computed_tokens - request.num_computed_tokens
                 )
                 total_affected_tokens += num_affected_tokens
-                request.num_external_computed_tokens -= num_affected_tokens
+
                 # collect invalid block and all downstream dependent blocks
                 if evict_blocks:
                     blocks_to_evict.update(req_block_ids[idx:])
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index c97adfe8f465..d5c5dba63475 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -15,7 +15,7 @@
 from vllm.multimodal.inputs import MultiModalFeatureSpec
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.v1.metrics.stats import SchedulerStats
+from vllm.v1.metrics.stats import PrefillStats, SchedulerStats
 from vllm.v1.outputs import LogprobsLists, LogprobsTensors
 from vllm.v1.serial_utils import UtilityResult
 
@@ -171,10 +171,9 @@ class EngineCoreOutput(
 
     kv_transfer_params: dict[str, Any] | None = None
     trace_headers: Mapping[str, str] | None = None
-    # The number of tokens with prefix cache hits (local + external).
-    num_cached_tokens: int = 0
-    # The number of tokens computed remotely (original count from connector).
-    num_external_computed_tokens: int = 0
+
+    prefill_stats: PrefillStats | None = None
+
     routed_experts: np.ndarray | None = None
     # The number of NaNs in logits.
     # A value greater than 0 indicates that the output is corrupted.
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index f9e965092288..3d1d7a82db30 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -617,8 +617,13 @@ def process_outputs(
             stop_reason = engine_core_output.stop_reason
             kv_transfer_params = engine_core_output.kv_transfer_params
             routed_experts = engine_core_output.routed_experts
-            req_state.num_cached_tokens = engine_core_output.num_cached_tokens
-            req_state.is_prefilling = False
+
+            if req_state.is_prefilling:
+                if engine_core_output.prefill_stats is not None:
+                    req_state.num_cached_tokens = (
+                        engine_core_output.prefill_stats.num_cached_tokens
+                    )
+                req_state.is_prefilling = False
 
             if pooling_output is None:
                 assert req_state.detokenizer is not None
@@ -776,7 +781,6 @@ def _update_stats_from_output(
             engine_core_output,
             engine_core_timestamp,
             req_state.is_prefilling,
-            req_state.prompt_len,
             req_state.stats,
             self.lora_states,
             req_state.lora_name,
diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py
index 79955815d582..bb51656f5095 100644
--- a/vllm/v1/metrics/stats.py
+++ b/vllm/v1/metrics/stats.py
@@ -237,6 +237,40 @@ class FinishedRequestStats:
     num_cached_tokens: int = 0
 
 
+@dataclass
+class PrefillStats:
+    """Breakdown of a scheduled prefill computation.
+
+    Fields:
+        num_prompt_tokens: Total number of tokens to be prefilled.
+        num_computed_tokens: Tokens to be prefilled locally (actual compute work).
+        num_cached_tokens: Tokens to be prefilled without actual compute work.
+        num_local_cached_tokens: Tokens to be prefilled from local prefix cache.
+        num_external_cached_tokens: Tokens to be prefilled from external KV transfer.
+    """
+
+    num_prompt_tokens: int = 0
+    num_computed_tokens: int = 0
+    num_cached_tokens: int = 0
+    num_local_cached_tokens: int = 0
+    num_external_cached_tokens: int = 0
+
+    def set(
+        self,
+        num_prompt_tokens: int,
+        num_local_cached_tokens: int,
+        num_external_cached_tokens: int,
+    ):
+        num_cached_tokens = num_local_cached_tokens + num_external_cached_tokens
+        assert num_cached_tokens <= num_prompt_tokens
+
+        self.num_prompt_tokens = num_prompt_tokens
+        self.num_computed_tokens = num_prompt_tokens - num_cached_tokens
+        self.num_cached_tokens = num_cached_tokens
+        self.num_local_cached_tokens = num_local_cached_tokens
+        self.num_external_cached_tokens = num_external_cached_tokens
+
+
 @dataclass
 class PromptTokenStats:
     """Breakdown of prompt tokens by source.
@@ -265,28 +299,14 @@ class PromptTokenStats:
     cached_tokens: int = 0
     total: int = 0
 
-    def update_from_output(
-        self,
-        num_cached_tokens: int,
-        num_external_computed_tokens: int,
-        prompt_len: int,
-    ) -> None:
+    def update_from_output(self, prefill_stats: PrefillStats) -> None:
         """Update stats from a prefill output."""
-        self.computed += prompt_len - num_cached_tokens
-        self.external_kv_transfer += num_external_computed_tokens
-        # FIXME(yifan): local_cache_hit can go negative after preemption.
-        # num_cached_tokens is a one-time snapshot from first scheduling and
-        # is never reset on preemption, while num_external_computed_tokens is
-        # overwritten on re-scheduling. If CPU offload finds more tokens on
-        # the second pass than the original total, the subtraction underflows.
-        # A fundamental fix is to track the first-time num_external_computed_tokens
-        # as a separate metric rather than reusing num_external_computed_tokens
-        # for metric directly.
-        self.local_cache_hit += max(
-            0, (num_cached_tokens - num_external_computed_tokens)
-        )
-        self.cached_tokens += num_cached_tokens
-        self.total += prompt_len
+        self.computed += prefill_stats.num_computed_tokens
+        self.cached_tokens += prefill_stats.num_cached_tokens
+        self.total += prefill_stats.num_prompt_tokens
+
+        self.local_cache_hit += prefill_stats.num_local_cached_tokens
+        self.external_kv_transfer += prefill_stats.num_external_cached_tokens
 
     def get_by_source(self, source: str) -> int:
         """Get token count by source label."""
@@ -333,7 +353,6 @@ def update_from_output(
         output: "EngineCoreOutput",
         engine_core_timestamp: float,
         is_prefilling: bool,
-        prompt_len: int,
         req_stats: RequestStateStats,
         lora_states: "LoRARequestStates",
         lora_name: str | None,
@@ -342,11 +361,8 @@ def update_from_output(
 
         self.num_generation_tokens += num_new_generation_tokens
         if is_prefilling:
-            self.prompt_token_stats.update_from_output(
-                num_cached_tokens=output.num_cached_tokens,
-                num_external_computed_tokens=output.num_external_computed_tokens,
-                prompt_len=prompt_len,
-            )
+            if output.prefill_stats is not None:
+                self.prompt_token_stats.update_from_output(output.prefill_stats)
 
             first_token_latency = self._time_since(req_stats.arrival_time)
             self.time_to_first_tokens_iter.append(first_token_latency)
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 946e71c15d35..678d57580cc7 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -20,6 +20,7 @@
     EngineCoreRequest,
     FinishReason,
 )
+from vllm.v1.metrics.stats import PrefillStats
 from vllm.v1.structured_output.request import StructuredOutputRequest
 from vllm.v1.utils import ConstantList
 
@@ -145,9 +146,6 @@ def __init__(
         self.all_token_ids = ConstantList(self._all_token_ids)
         # trace_headers
         self.trace_headers = trace_headers
-        # State
-        # The number of tokens with prefix cache hits.
-        self.num_cached_tokens = -1
         # True if this request is scheduled as a non-final prefill chunk.
         self.is_prefill_chunk = False
 
@@ -159,8 +157,7 @@ def __init__(
         # The number of times this request has been preempted by the scheduler.
         self.num_preemptions = 0
 
-        # The number of tokens that have been computed remotely.
-        self.num_external_computed_tokens = 0
+        self.prefill_stats: PrefillStats | None = PrefillStats()
 
         self.block_hashes: list[BlockHash] = []
         # Store the block hasher without binding self to avoid creating a
@@ -278,6 +275,13 @@ def take_events(self) -> list[EngineCoreEvent] | None:
         events, self.events = self.events, []
         return events
 
+    def take_prefill_stats(self) -> PrefillStats | None:
+        if self.prefill_stats is None:
+            return None
+        prefill_stats = self.prefill_stats
+        self.prefill_stats = None
+        return prefill_stats
+
     def __lt__(self, other: "Request") -> bool:
         """
         Compare two requests based on priority, arrival time, and request ID.
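Note (outside the patch itself): a minimal sketch of how the new PrefillStats is intended to feed PromptTokenStats, based only on the definitions added above; the token counts are made-up example values.

    from vllm.v1.metrics.stats import PrefillStats, PromptTokenStats

    # A 1000-token prompt: 300 tokens hit the local prefix cache, 200 arrive
    # via external KV transfer, and the remaining 500 must be computed.
    prefill_stats = PrefillStats()
    prefill_stats.set(
        num_prompt_tokens=1000,
        num_local_cached_tokens=300,
        num_external_cached_tokens=200,
    )
    assert prefill_stats.num_cached_tokens == 500
    assert prefill_stats.num_computed_tokens == 500

    # PromptTokenStats accumulates the per-source breakdown for metrics.
    stats = PromptTokenStats()
    stats.update_from_output(prefill_stats)
    assert stats.computed == 500
    assert stats.local_cache_hit == 300
    assert stats.external_kv_transfer == 200
    assert stats.cached_tokens == 500
    assert stats.total == 1000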