diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 2fe45242153c..1a0fa3f2c46f 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1549,6 +1549,12 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): # All can be scheduled - 1st token. output = scheduler.schedule() + + # verify request-level cache hit stats are set + for request in requests: + assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS + assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS + if is_async: assert _num_waiting_requests(scheduler) == 2 assert scheduler.running == [] @@ -1607,6 +1613,12 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): # Restarts the preempted request - generate 3rd token. # This will have a local and remote cache hit. output = scheduler.schedule() + + # verify request level hit stats are NOT re-set + for request in requests: + assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS + assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS + if is_async: waiting_req_ids = [ req.request_id @@ -1649,6 +1661,11 @@ def test_kv_connector_handles_preemption(is_async, use_ec_connector, ec_role): # All memory should be freed since nothing is running. assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() == NUM_BLOCKS - 1 + # final verification request-level cache hit stats are NOT re-set + for request in requests: + assert request.num_cached_tokens == NUM_MATCHED_NEW_TOKENS + assert request.num_external_computed_tokens == NUM_MATCHED_NEW_TOKENS + def make_output(scheduler: Scheduler): return ModelRunnerOutput( diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ea2c2a6cd180..a2f3103da258 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -619,7 +619,6 @@ def schedule(self) -> SchedulerOutput: step_skipped_waiting.prepend_request(request) continue - request.num_external_computed_tokens = ext_tokens num_external_computed_tokens = ext_tokens connector_prefix_cache_queries = ( @@ -632,6 +631,16 @@ def schedule(self) -> SchedulerOutput: num_new_local_computed_tokens + num_external_computed_tokens ) assert num_computed_tokens <= request.num_tokens + + if request.num_preemptions == 0: + # For request-level stats, + # track hits only the first time a request gets scheduled. + # If allocation will later fail, we will get back here + # the next time the request re-tries scheduling + request.num_cached_tokens = num_computed_tokens + request.num_external_computed_tokens = ( + num_external_computed_tokens + ) else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -802,9 +811,6 @@ def schedule(self) -> SchedulerOutput: token_budget -= num_new_tokens request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens - # Count the number of prefix cached tokens. - if request.num_cached_tokens < 0: - request.num_cached_tokens = num_computed_tokens # Encoder-related. if encoder_inputs_to_schedule: scheduled_encoder_inputs[request_id] = encoder_inputs_to_schedule @@ -2158,6 +2164,7 @@ def _update_requests_with_invalid_blocks( req_num_computed_tokens = request.num_computed_tokens else: # Sync loading. num_computed_tokens includes new tokens + # TODO(orozery): Bug below! Incorrect for preempted requests! req_num_computed_tokens = request.num_cached_tokens req_num_computed_blocks = ( @@ -2192,7 +2199,19 @@ def _update_requests_with_invalid_blocks( req_num_computed_tokens - request.num_computed_tokens ) total_affected_tokens += num_affected_tokens - request.num_external_computed_tokens -= num_affected_tokens + if request.num_preemptions == 0: + # For request-level stats, + # track hits only the first time a request gets scheduled. + num_local_hit_tokens = ( + request.num_cached_tokens - request.num_external_computed_tokens + ) + assert num_local_hit_tokens >= 0 + request.num_cached_tokens = min( + request.num_cached_tokens, request.num_computed_tokens + ) + request.num_external_computed_tokens = max( + 0, request.num_cached_tokens - num_local_hit_tokens + ) # collect invalid block and all downstream dependent blocks if evict_blocks: blocks_to_evict.update(req_block_ids[idx:]) @@ -2204,6 +2223,8 @@ def _update_requests_with_invalid_blocks( # Revert to considering only cached tokens as computed. # Currently this only applies to sync loading; Async # loading does not yet support block sharing + + # TODO(orozery): Bug! Incorrect computation for preempted requests! total_affected_tokens += ( request.num_computed_tokens - request.num_cached_tokens ) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index f2ee33b49f22..747073ad6e47 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -145,9 +145,6 @@ def __init__( self.all_token_ids = ConstantList(self._all_token_ids) # trace_headers self.trace_headers = trace_headers - # State - # The number of tokens with prefix cache hits. - self.num_cached_tokens = -1 # True if this request is scheduled as a non-final prefill chunk. self.is_prefill_chunk = False @@ -159,7 +156,14 @@ def __init__( # The number of times this request has been preempted by the scheduler. self.num_preemptions = 0 - # The number of tokens that have been computed remotely. + # Fields used for request-level cache stats + # These fields are only set on the first time a request gets scheduled + # Cache hits following request preemption are currently not tracked. + + # Total number of KV cache hit tokens: + # local prefix cache hits + external (connector-based) hits + self.num_cached_tokens = 0 + # Number of external tokens hit (excluding local prefix cache hits) self.num_external_computed_tokens = 0 self.block_hashes: list[BlockHash] = []