From 206e640012dc1a57329d7b7551256550cd22b319 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 15 May 2025 16:23:33 -0700 Subject: [PATCH 1/4] [BugFix] Fix handling of num_computed_tokens with connector https://github.com/vllm-project/vllm/pull/18001 changed the behaviour subtly and broke some multi-connector cases. This change ensures we don't call the connector get_num_new_matched_tokens method a second time for a given request after an async load has completed. Signed-off-by: Nick Hill --- vllm/v1/core/sched/scheduler.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f338e4ba1440..97784fa60573 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests.appendleft(request) continue + num_external_computed_tokens = 0 + load_kv_async = False + # Get already-cached tokens. if num_prealloc_computed_tokens == 0: new_computed_blocks, num_native_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_native_computed_tokens)) + + # Total computed tokens (local + external). + num_computed_tokens = (num_native_computed_tokens + + num_external_computed_tokens) else: # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() num_native_computed_tokens = 0 - # Get externally-cached tokens if using a KVConnector. - num_external_computed_tokens, load_kv_async = ( - (0, False) if self.connector is None else - self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) - - # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + - num_external_computed_tokens + - num_prealloc_computed_tokens) + # Total computed tokens (allocated in prior step). + num_computed_tokens = num_prealloc_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget # P/D: loading remote KV, do not allocate for new work. if load_kv_async: + assert num_external_computed_tokens > 0 num_new_tokens = 0 # Number of tokens to be scheduled. else: @@ -411,7 +417,7 @@ def schedule(self) -> SchedulerOutput: # KVConnector: update internal state after allocation. # This information is used to determine if a load is # needed for this request. - if self.connector is not None: + if num_external_computed_tokens: self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, From 45de6c7251fceb13baeefe74ce04b6315c84e295 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 15 May 2025 16:52:39 -0700 Subject: [PATCH 2/4] fix linting Signed-off-by: Nick Hill --- vllm/v1/core/sched/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 97784fa60573..69afc066439b 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -418,6 +418,7 @@ def schedule(self) -> SchedulerOutput: # This information is used to determine if a load is # needed for this request. if num_external_computed_tokens: + assert self.connector is not None self.connector.update_state_after_alloc( request, new_computed_blocks + new_blocks, From 2bcad95a94499c4fa1fde26c2b26fc4037c17b5e Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 15 May 2025 18:39:00 -0700 Subject: [PATCH 3/4] handle full cache hit on P/D decode worker case Signed-off-by: Nick Hill --- .../kv_connector/v1/nixl_connector.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index abd1ea2bea82..bea6bdc48426 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -208,7 +208,17 @@ def get_num_new_matched_tokens( rounded_num_prompt_tokens = round_down( len(request.prompt_token_ids), self.block_size) count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) - return count, count > 0 + if count > 0: + return count, True + + # NOTE: if count is 0 here, we have less than block_size + # tokens to pull after subtracting the local prefix cache hit. + # The remote only sends fully computed blocks, so there is + # nothing to transfer but we need still need to notify the + # prefill worker so that the remote blocks are freed. + if all(p in params for p in ("remote_engine_id", "remote_host", + "remote_port")): + self._reqs_need_recv[request.request_id] = (request, []) # No remote prefill for this request. return 0, False @@ -224,10 +234,6 @@ def update_state_after_alloc(self, request: "Request", num_external_tokens, params) if params is not None and params.get("do_remote_prefill"): - # NOTE(rob): if prompt < block_size, no remote blocks - # since the remote only sends fully computed blocks, so - # skip recving for this request. num_external_tokens - # should be 0 if there are no remote blocks. if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")): From c71c3c376426b929c2f9f905ca99f9b4b0d0f873 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 16 May 2025 06:36:35 -0700 Subject: [PATCH 4/4] fix comment wording MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Nicolò Lucchesi --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index bea6bdc48426..086dbeb90f34 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -214,7 +214,7 @@ def get_num_new_matched_tokens( # NOTE: if count is 0 here, we have less than block_size # tokens to pull after subtracting the local prefix cache hit. # The remote only sends fully computed blocks, so there is - # nothing to transfer but we need still need to notify the + # nothing to transfer but we still need to notify the # prefill worker so that the remote blocks are freed. if all(p in params for p in ("remote_engine_id", "remote_host", "remote_port")):