diff --git a/tests/v1/kv_connector/unit/test_error_propagation.py b/tests/v1/kv_connector/unit/test_error_propagation.py index 20e181f379f5..11286611ecdb 100644 --- a/tests/v1/kv_connector/unit/test_error_propagation.py +++ b/tests/v1/kv_connector/unit/test_error_propagation.py @@ -121,7 +121,7 @@ def test_error_propagation_async_load(fail_scheduler: Scheduler): assert len(fail_scheduler.waiting) == 1 assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == num_external_computed_tokens (req_block_ids,) = fail_scheduler.kv_cache_manager.get_block_ids(request.request_id) invalid_block_ids = {req_block_ids[invalid_block_idx]} diff --git a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py index 6cb2d3ea4d97..53fe599849b6 100644 --- a/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py +++ b/tests/v1/kv_connector/unit/test_invalid_blocks_correctness.py @@ -339,7 +339,7 @@ def test_async_recompute_blocks_not_cached_when_invalid( # request should be waiting for remote KVs assert len(recompute_scheduler.waiting) == 1 assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == num_external_computed_tokens # get the allocated block IDs (req_block_ids,) = recompute_scheduler.kv_cache_manager.get_block_ids( diff --git a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py index 364eabb96a31..fcdb2869d7dc 100644 --- a/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py +++ b/tests/v1/kv_connector/unit/test_kv_load_failure_recovery.py @@ -78,7 +78,7 @@ def test_async_load_failure( assert len(scheduler.waiting) == 3 for request in scheduler.waiting: - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == num_external_computed_tokens assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 @@ -103,7 +103,7 @@ def test_async_load_failure( min_invalid_block_idx * scheduler.block_size ) else: - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == num_external_computed_tokens assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.failed_recving_kv_req_ids == {request2.request_id} assert scheduler.connector.get_num_new_matched_tokens.call_count == 3 @@ -305,7 +305,7 @@ def test_async_progressive_load_failure( assert len(scheduler.waiting) == 1 assert scheduler.waiting.peek_request().request_id == request.request_id - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == num_external_computed_tokens assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS assert scheduler.connector.get_num_new_matched_tokens.call_count == 1 diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index b9588ebcd211..f0ff216be664 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -57,7 +57,7 @@ def test_basic_lifecycle(): assert len(scheduler.waiting) == 1 assert request in scheduler.waiting assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS - assert request.num_computed_tokens == 0 + assert request.num_computed_tokens == NUM_TOKENS # ... but should have (uncached) blocks allocated to it. block_pool = scheduler.kv_cache_manager.block_pool diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index e44702b99e8e..cb99de93b6fb 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -638,6 +638,7 @@ def schedule(self) -> SchedulerOutput: num_computed_tokens = ( num_new_local_computed_tokens + num_external_computed_tokens ) + assert num_computed_tokens <= request.num_tokens else: # KVTransfer: WAITING reqs have num_computed_tokens > 0 # after async KV recvs are completed. @@ -773,6 +774,20 @@ def schedule(self) -> SchedulerOutput: # into the WAITING_FOR_REMOTE_KV state. skipped_waiting_requests.prepend_request(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + # Set num_computed_tokens even though KVs are not yet loaded. + # request.num_computed_tokens will not be used anywhere until + # the request finished the KV transfer. + # + # If a transfer error is reported by the connector, + # request.num_computed_tokens will be re-set accordingly in + # _update_requests_with_invalid_blocks. + # + # When the transfer is finished, either successfully or not, + # request.num_computed_tokens will correctly reflect the number + # of computed tokens. + # _update_waiting_for_remote_kv will then cache + # only the successfully loaded tokens. + request.num_computed_tokens = num_computed_tokens continue self.running.append(request) @@ -1994,17 +2009,17 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: self.failed_recving_kv_req_ids.remove(request.request_id) else: # Now that the blocks are ready, actually cache them. - (block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id) - num_computed_tokens = len(block_ids) * self.block_size - # Handle the case where num request tokens less than one block. - num_computed_tokens = min(num_computed_tokens, request.num_tokens) - if num_computed_tokens == request.num_tokens: - num_computed_tokens -= 1 # This will cache the blocks iff caching is enabled. - self.kv_cache_manager.cache_blocks(request, num_computed_tokens) + self.kv_cache_manager.cache_blocks(request, request.num_computed_tokens) - # Update the request state for scheduling. - request.num_computed_tokens = num_computed_tokens + # on a full prompt hit, we need to re-compute the last token + # in order to be able to sample the next token + if request.num_computed_tokens == request.num_tokens: + request.num_computed_tokens = request.num_tokens - 1 + + # Count the number of prefix cached tokens. + if request.num_cached_tokens < 0: + request.num_cached_tokens = request.num_computed_tokens # Return that we are ready. self.finished_recving_kv_req_ids.remove(request.request_id) @@ -2084,13 +2099,8 @@ def _update_requests_with_invalid_blocks( # We iterate only over blocks that may contain externally computed # tokens if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: - # Async loading. If num_computed_tokens is set it implies we - # already processed some block failures for it in a prior step - req_num_computed_tokens = ( - request.num_computed_tokens - if req_id in self.failed_recving_kv_req_ids - else len(req_block_ids) * self.block_size - ) + # Async loading. num_computed_tokens does not include new tokens + req_num_computed_tokens = request.num_computed_tokens else: # Sync loading. num_computed_tokens includes new tokens req_num_computed_tokens = request.num_cached_tokens