From 7878d58a8f9e3600414c02df81ebceec7b641900 Mon Sep 17 00:00:00 2001 From: Guan Luo Date: Wed, 16 Jul 2025 10:43:52 -0700 Subject: [PATCH 1/9] fix: NIXL connector transfers partial block to transfer complete multi-modal context to downstream worker Signed-off-by: GuanLuo --- .../kv_transfer/kv_connector/v1/nixl_connector.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c06cda356f57..2d0a291f22f3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -257,9 +257,7 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. assert num_computed_tokens % self.block_size == 0 - rounded_num_prompt_tokens = round_down( - len(request.prompt_token_ids), self.block_size) - count = max(rounded_num_prompt_tokens - num_computed_tokens, 0) + count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) if count > 0: return count, True @@ -382,12 +380,8 @@ def request_finished( or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None - # Get computed blocks. - all_full = request.num_computed_tokens % self.block_size == 0 - computed_block_ids = block_ids if all_full else block_ids[:-1] - # If prompt < block_size, no xfer so free blocks immediately. - delay_free_blocks = len(computed_block_ids) > 0 + delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion @@ -397,7 +391,7 @@ def request_finished( return delay_free_blocks, dict( do_remote_prefill=True, do_remote_decode=False, - remote_block_ids=computed_block_ids, + remote_block_ids=block_ids, remote_engine_id=self.engine_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, From 451b70ded16c97b5caf6ac4d4a28e0b4fbfeb216 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Wed, 16 Jul 2025 12:51:34 -0700 Subject: [PATCH 2/9] fix: remove assert that may fail Signed-off-by: GuanLuo --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2d0a291f22f3..6ec221d7e34e 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -256,7 +256,6 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. - assert num_computed_tokens % self.block_size == 0 count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) if count > 0: return count, True From c9cbab6ed643e6c104373b7953f9637f6631b7f8 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Wed, 16 Jul 2025 14:27:09 -0700 Subject: [PATCH 3/9] chore: style Signed-off-by: GuanLuo --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 6ec221d7e34e..a7fc89f0169a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -29,7 +29,7 @@ from vllm.forward_context import ForwardContext from vllm.logger import init_logger from vllm.platforms import _Backend, current_platform -from vllm.utils import make_zmq_path, make_zmq_socket, round_down +from vllm.utils import make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.request import RequestStatus From c2d374abace4ffc21fb244f5f6cf2539435eee0b Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Thu, 31 Jul 2025 04:02:52 -0700 Subject: [PATCH 4/9] chore: address comment Signed-off-by: GuanLuo --- .../kv_connector/v1/nixl_connector.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index a7fc89f0169a..7988a60b3e6d 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -256,9 +256,8 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. - count = max(len(request.prompt_token_ids) - num_computed_tokens, 0) - if count > 0: - return count, True + count = len(request.prompt_token_ids) - num_computed_tokens + return count, count > 0 # No remote prefill for this request. return 0, False @@ -279,18 +278,16 @@ def update_state_after_alloc(self, request: "Request", # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. - # figure out full computed blocks to save + # save all blocks block_ids = blocks.get_block_ids()[0] - all_full = request.num_tokens % self.block_size == 0 - full_block_ids = (block_ids if all_full else block_ids[:-1]) # TODO: skip the blocks that are already in the host xfer buffer. # Currently, the host xfer buffer block is 1-to-1 mapped to device # kv blocks, so host blocks won't be flushed as long as its device # block is not overwritten; and it will be safe to skip saving them # to host xfer buffer. - if full_block_ids: + if block_ids: self._reqs_need_save[request.request_id] = \ - (request, full_block_ids) + (request, block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all(p in params for p in ("remote_engine_id", "remote_host", @@ -379,7 +376,8 @@ def request_finished( or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None - # If prompt < block_size, no xfer so free blocks immediately. + # [TODO] check whether block_ids actually ever be 0. If not we could + # remove the conditional below delay_free_blocks = len(block_ids) > 0 if delay_free_blocks: From 2089edde479cf3a3b24abcc35de1746fa1788ced Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Thu, 31 Jul 2025 04:16:31 -0700 Subject: [PATCH 5/9] chore: add back sanity check Signed-off-by: GuanLuo --- vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 7988a60b3e6d..2bbd37f74764 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -257,7 +257,8 @@ def get_num_new_matched_tokens( if params is not None and params.get("do_remote_prefill"): # Remote prefill: get all prompt blocks from remote. count = len(request.prompt_token_ids) - num_computed_tokens - return count, count > 0 + if count > 0: + return count, True # No remote prefill for this request. return 0, False From 3eb776af67e1217b697dfdea82eee72796f3e924 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Fri, 1 Aug 2025 13:03:20 -0700 Subject: [PATCH 6/9] test: update test to reflect new behavior Signed-off-by: GuanLuo --- .../kv_connector/unit/test_nixl_connector.py | 12 +++++----- .../unit/test_remote_decode_lifecycle.py | 22 ++++++++++++++----- .../unit/test_remote_prefill_lifecycle.py | 22 ++++++++++++++----- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index c5ca7df83685..e97a9659c633 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -173,9 +173,9 @@ def test_prompt_less_than_block_size(): """ Test that we can handle case where prompt is < block. - In this case, the P worker will send empty remote_block_ids. - The D worker should not schedule an async read in this case, - since there is nothing to pull. + In this case, the P worker will still send remote_block_ids of the + partial block. The D worker should schedule an async read + in this case. """ vllm_config = create_vllm_config() scheduler = create_scheduler(vllm_config) @@ -188,7 +188,7 @@ def test_prompt_less_than_block_size(): request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True, - num_remote_blocks=0) + num_remote_blocks=1) scheduler.add_request(request) scheduler_output = scheduler.schedule() @@ -196,10 +196,10 @@ def test_prompt_less_than_block_size(): kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) - assert len(kv_connector_metadata.reqs_to_recv) == 0 + assert len(kv_connector_metadata.reqs_to_recv) == 1 # This request should be scheduled regularly. - assert len(scheduler_output.scheduled_new_reqs) == 1 + assert len(scheduler_output.scheduled_new_reqs) == 0 class FakeNixlConnectorWorker(NixlConnectorWorker): diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 12a71d97e8d2..de969dd72d28 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy +import math from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.request import FinishReason, RequestStatus @@ -120,12 +121,20 @@ def test_short_prompt_lifecycle(): model_runner_output = create_model_runner_output(reqs=[request]) # (1c): update_from_output() - # Since tokens < block_size, there will be no kv xfer. - # So this should be cleaned up immediately. - _ = scheduler.update_from_output(scheduler_output, model_runner_output) + # Even though tokens < block_size, there will be kv xfer for partial block. + eco = scheduler.update_from_output(scheduler_output, model_runner_output) + kv_transfer_params = eco[0].outputs[0].kv_transfer_params + + assert (len( + kv_transfer_params["remote_block_ids"]) == 1) # Confirm we do not have any memory leaks after req lifecycle. - # We need one more call to schedule() to clear data for persistent batch. + # We need to mark sending finish to clear data for persistent batch. + scheduler_output = scheduler.schedule() + scheduler.schedule() + model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) + model_runner_output.finished_sending = [request.request_id] + scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) @@ -168,9 +177,10 @@ def test_prefix_cache_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params - # Ensure we send all block ids, even if there is a cache hit. + # Ensure we send all block ids, including the partial blocks, + # even if there is a cache hit. assert (len( - kv_transfer_params["remote_block_ids"]) == NUM_EXTERNAL_FULL_BLOCKS) + kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)) # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index f89970bf2c80..056a1c3461ee 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -360,7 +360,7 @@ def test_cannot_schedule_after_recv(): BLOCK_SIZE = vllm_config.cache_config.block_size # Prompt will use 2 blocks + 1 block after we schedule. NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) - NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) request_remote = create_request(request_id=2, @@ -391,14 +391,23 @@ def test_cannot_schedule_after_recv(): assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - # Step 4: try to schedule, not enough blocks. + # Step 4: try to schedule, remote request is put to running list + # because the transfer is completed. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal, request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 2 + assert len(scheduler.waiting) == 0 + + # Step 5: Remote request will be put back to waiting list + # because it needs new block to hold generated token. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_normal]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 1 - # Step 5: finish the request, free it. + # Step 6: finish the request, free it. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_normal], use_eos=True) @@ -406,16 +415,17 @@ def test_cannot_schedule_after_recv(): assert len(scheduler.running) == 0 assert len(scheduler.waiting) == 1 - # Step 6: now we can schedule (with 2 blocks computed). + # Step 7: now we can schedule (with 2 blocks computed), + # request is retrieved from preempted list. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote]) - assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + assert (scheduler_output.scheduled_cached_reqs.num_computed_tokens[0] == NUM_PROMPT_BLOCKS * BLOCK_SIZE) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 1 assert len(scheduler.waiting) == 0 - # Step 7: free everything. + # Step 8: free everything. scheduler_output = scheduler.schedule() model_runner_output = create_model_runner_output(reqs=[request_remote], use_eos=True) From 4dc41b9ac09d1707743e852ff9cd04ef47883380 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Fri, 1 Aug 2025 14:00:34 -0700 Subject: [PATCH 7/9] style: format Signed-off-by: GuanLuo --- tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py | 7 +++---- .../v1/kv_connector/unit/test_remote_prefill_lifecycle.py | 3 ++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index de969dd72d28..07f9477361e3 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -import math from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT from vllm.v1.request import FinishReason, RequestStatus @@ -125,8 +124,7 @@ def test_short_prompt_lifecycle(): eco = scheduler.update_from_output(scheduler_output, model_runner_output) kv_transfer_params = eco[0].outputs[0].kv_transfer_params - assert (len( - kv_transfer_params["remote_block_ids"]) == 1) + assert (len(kv_transfer_params["remote_block_ids"]) == 1) # Confirm we do not have any memory leaks after req lifecycle. # We need to mark sending finish to clear data for persistent batch. @@ -180,7 +178,8 @@ def test_prefix_cache_lifecycle(): # Ensure we send all block ids, including the partial blocks, # even if there is a cache hit. assert (len( - kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + 1)) + kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS + + 1)) # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 056a1c3461ee..066a28c7a961 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -394,7 +394,8 @@ def test_cannot_schedule_after_recv(): # Step 4: try to schedule, remote request is put to running list # because the transfer is completed. scheduler_output = scheduler.schedule() - model_runner_output = create_model_runner_output(reqs=[request_normal, request_remote]) + model_runner_output = create_model_runner_output( + reqs=[request_normal, request_remote]) scheduler.update_from_output(scheduler_output, model_runner_output) assert len(scheduler.running) == 2 assert len(scheduler.waiting) == 0 From 00a1bcb0b182449d102666f0ace6fda282d60df9 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Mon, 4 Aug 2025 17:15:22 -0700 Subject: [PATCH 8/9] test: add test case of insufficient block for transfer Signed-off-by: GuanLuo --- .../unit/test_remote_prefill_lifecycle.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index 066a28c7a961..6474491c5b64 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -433,3 +433,86 @@ def test_cannot_schedule_after_recv(): scheduler.update_from_output(scheduler_output, model_runner_output) _ = scheduler.schedule() assert_scheduler_empty(scheduler) + + +def test_cannot_recv(): + """ + Test that we can handle no schedule KV block transfer due to not + enough remaining KV blocks. + """ + + # NOTE: the KVCacheManager will use 1 null block. + # So there are 5 total working blocks. + TOTAL_NUM_BLOCKS = 6 + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS) + + # Prime the KVCache. + NUM_PROMPT_BLOCKS = 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + # Prompt will use 2 blocks + 1 block after we schedule. + NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_remote = create_request(request_id=2, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True) + + # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 2: 3 blocks are in use, + # need 3 new for remote blocks but only 2 are available. + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + # Should not have KV transfer in progress. + assert (request_remote.status != RequestStatus.WAITING_FOR_REMOTE_KVS) + + # Step 3: finish the request, free it. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 4: now we can initiate KV transfer (with 2 blocks computed). + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + assert (request_remote.status == RequestStatus.WAITING_FOR_REMOTE_KVS) + + # Step 5: finish recving (5 blocks in use) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=[], finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 6: schedule remote request + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 7: free everything. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) From c8e8eac1bbd81aba17f55e074e62bb24e37bbbf5 Mon Sep 17 00:00:00 2001 From: GuanLuo Date: Mon, 4 Aug 2025 17:53:06 -0700 Subject: [PATCH 9/9] chore: address comment Signed-off-by: GuanLuo --- tests/v1/kv_connector/unit/test_nixl_connector.py | 6 ++---- tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py | 4 ---- .../kv_transfer/kv_connector/v1/nixl_connector.py | 2 +- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index e97a9659c633..c6739832355f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -184,7 +184,7 @@ def test_prompt_less_than_block_size(): BLOCK_SIZE = vllm_config.cache_config.block_size NUM_TOKENS = int(BLOCK_SIZE * 0.5) - # Request will have 0 remote blocks. + # Request will have 1 partial remote block. request = create_request(request_id=1, num_tokens=NUM_TOKENS, do_remote_prefill=True, @@ -192,13 +192,11 @@ def test_prompt_less_than_block_size(): scheduler.add_request(request) scheduler_output = scheduler.schedule() - # This request should not have to read async. + # This request will read async. kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, NixlConnectorMetadata) assert len(kv_connector_metadata.reqs_to_recv) == 1 - - # This request should be scheduled regularly. assert len(scheduler_output.scheduled_new_reqs) == 0 diff --git a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py index 07f9477361e3..f0893c89e12a 100644 --- a/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py @@ -129,11 +129,9 @@ def test_short_prompt_lifecycle(): # Confirm we do not have any memory leaks after req lifecycle. # We need to mark sending finish to clear data for persistent batch. scheduler_output = scheduler.schedule() - scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.finished_sending = [request.request_id] scheduler.update_from_output(scheduler_output, model_runner_output) - _ = scheduler.schedule() assert_scheduler_empty(scheduler) @@ -183,9 +181,7 @@ def test_prefix_cache_lifecycle(): # STEP (2): Ensure it is freed. scheduler_output = scheduler.schedule() - scheduler.schedule() model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT) model_runner_output.finished_sending = [request_remote.request_id] scheduler.update_from_output(scheduler_output, model_runner_output) - _ = scheduler.schedule() assert_scheduler_empty(scheduler) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 2bbd37f74764..020162346f34 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -377,7 +377,7 @@ def request_finished( or request.status != RequestStatus.FINISHED_LENGTH_CAPPED): return False, None - # [TODO] check whether block_ids actually ever be 0. If not we could + # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = len(block_ids) > 0