diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 537a02464d0b..21ef86465e04 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -3413,3 +3413,52 @@ def test_prepend_skipped_requests_order(): # verify waiting order is preserved assert list(scheduler.waiting) == expected_waiting_reqs + + +def test_abort_request_waiting_for_remote_kvs(): + scheduler = create_scheduler(use_kv_connector=True) + + # add a single request + request = create_requests(num_requests=1)[0] + scheduler.add_request(request) + + # set request to waiting for remote KVs, and abort it + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.finish_requests((request.request_id,), RequestStatus.FINISHED_ABORTED) + assert request.status == RequestStatus.FINISHED_ABORTED + + # verify request is not deleted + assert request.request_id in scheduler.requests + + # finish recving request + scheduler_output = scheduler.schedule() + model_runner_output = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + kv_connector_output=KVConnectorOutput(finished_recving={request.request_id}), + ) + scheduler.update_from_output(scheduler_output, model_runner_output) + + # assert request is deleted + assert request.request_id not in scheduler.requests + assert not scheduler.finished_recving_kv_req_ids + + +def test_abort_request_finished_recving(): + scheduler = create_scheduler(use_kv_connector=True) + + # add a single request + request = create_requests(num_requests=1)[0] + scheduler.add_request(request) + + # set request to waiting for remote KVs, finished but not yet updated + request.status = RequestStatus.WAITING_FOR_REMOTE_KVS + scheduler.finished_recving_kv_req_ids.add(request.request_id) + + # abort request + scheduler.finish_requests((request.request_id,), RequestStatus.FINISHED_ABORTED) + assert request.status == RequestStatus.FINISHED_ABORTED + + # verify request is deleted + assert request.request_id not in scheduler.requests + assert not scheduler.finished_recving_kv_req_ids diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 1805f009db0e..5b84202a581a 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -42,7 +42,7 @@ TransferSpec, ) from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput -from vllm.v1.request import Request +from vllm.v1.request import Request, RequestStatus from .utils import ( EOS_TOKEN_ID, @@ -355,7 +355,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool): self.scheduler.update_from_output(scheduler_output, model_runner_output) if ( - prev_token_id is EOS_TOKEN_ID + prev_token_id == EOS_TOKEN_ID and prev_token_id != token_id and self.scheduler.requests ): @@ -730,6 +730,57 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner): assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) +def test_abort_loading_requests(request_runner): + offloaded_block_size = 12 + gpu_block_size = 4 + num_gpu_blocks = 100 + + runner = request_runner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) + + # store 1 blocks + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output(block_hashes) + ) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], + expected_stored_gpu_block_indexes=(0, 1, 2), + ) + + # start a request to load the first block, but don't complete + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.lookup.return_value = 1 + runner.run( + decoded_tokens=[], + complete_transfers=False, + ) + + # request triggered a load + transfer_jobs = list(runner.offloading_spec.handler.transfer_specs) + assert transfer_jobs + + # abort request + req_id = str(runner.req_id) + runner.scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED) + + # verify request is not deleted + assert req_id in runner.scheduler.requests + + # complete loading request + runner.run( + decoded_tokens=[], + expected_loaded_gpu_block_indexes=(0, 1, 2), + ) + + # assert request is deleted + assert req_id not in runner.scheduler.requests + + class TestOffloadingConnectorStats: """Tests for OffloadingConnector stats reconstruction and operations.""" diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 3f7ac9374e15..2bb720ea2677 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1669,19 +1669,30 @@ def finish_requests( # Second pass: set status and free requests for request in valid_requests: + delay_free_blocks = False + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + delay_free_blocks = ( + request.request_id not in self.finished_recving_kv_req_ids + ) + self.finished_recving_kv_req_ids.discard(request.request_id) + self.failed_recving_kv_req_ids.discard(request.request_id) + request.status = finished_status - self._free_request(request) + self._free_request(request, delay_free_blocks=delay_free_blocks) - def _free_request(self, request: Request) -> dict[str, Any] | None: + def _free_request( + self, request: Request, delay_free_blocks: bool = False + ) -> dict[str, Any] | None: assert request.is_finished() - delay_free_blocks, kv_xfer_params = self._connector_finished(request) + connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) request_id = request.request_id self.finished_req_ids.add(request_id) if self.finished_req_ids_dict is not None: self.finished_req_ids_dict[request.client_index].add(request_id) + delay_free_blocks |= connector_delay_free_blocks if not delay_free_blocks: self._free_blocks(request) @@ -1953,7 +1964,13 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput): # KV Connector:: update recv and send status from last step. for req_id in kv_connector_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) - self.finished_recving_kv_req_ids.add(req_id) + assert req_id in self.requests + req = self.requests[req_id] + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + self.finished_recving_kv_req_ids.add(req_id) + else: + assert RequestStatus.is_finished(req.status) + self._free_blocks(self.requests[req_id]) for req_id in kv_connector_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) assert req_id in self.requests