diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index c6c4a5085bff..effd31f5ecd3 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest import torch @@ -62,6 +62,43 @@ def test_finish_request(): assert len(scheduler.waiting) == 9 - i +def test_abort_request_with_kv_connector(): + # `use_kv_connector=True` will expose a kv_connector to the scheduler, but + # we will need to mimick the delay_freed since the default kv_connector is + # too simple + scheduler = create_scheduler(use_kv_connector=True) + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + with patch.object( + scheduler, + "_connector_finished", + side_effect=lambda req: ( + req.status == RequestStatus.FINISHED_LENGTH_CAPPED, + {"fake_kv_params": False}, + ), + ): + for i, request in enumerate(requests): + scheduler.finish_requests( + request.request_id, RequestStatus.FINISHED_LENGTH_CAPPED + ) + assert request.request_id in scheduler.requests # since delayed + assert len(scheduler.waiting) == 9 - i + + assert not scheduler.waiting and not scheduler.running + assert len(scheduler.requests) == 10 + + for i, request in enumerate(requests): + scheduler.finish_requests( + request.request_id, RequestStatus.FINISHED_ABORTED + ) + assert request.request_id not in scheduler.requests # since aborted + + assert not scheduler.waiting and not scheduler.running + assert not scheduler.requests + + def test_get_num_unfinished_requests(): scheduler = create_scheduler() requests = create_requests(num_requests=10) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 52b98ef65459..4dabdb7fedf9 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1326,11 +1326,31 @@ def finish_requests( waiting_requests_to_remove = [] valid_requests = [] + # this is only required only if we have a kv connector + should_force_abort = ( + finished_status == RequestStatus.FINISHED_ABORTED + and self.get_kv_connector() is not None + ) + forced_aborted_requests = [] + # First pass: collect requests to remove from queues for req_id in request_ids: request = self.requests.get(req_id) - if request is None or request.is_finished(): - # Invalid request ID. + if request is None: + continue # Invalid request ID. + elif request.is_finished(): + if ( + should_force_abort + and request.status == RequestStatus.FINISHED_LENGTH_CAPPED + ): + # we need to force the status to FINISHED_ABORTED to avoid + # the request being delayed freed. The kv_connector will + # delay the free if it the status is FINISHED_LENGTH_CAPPED + logger.info( + "Request %s is finished but will get forced aborted.", + req_id, + ) + forced_aborted_requests.append(request) continue valid_requests.append(request) @@ -1350,6 +1370,11 @@ def finish_requests( request.status = finished_status self._free_request(request) + # Free the requests that are being delayed + for request in forced_aborted_requests: + request.status = finished_status + self._free_request(request) + def _free_request(self, request: Request) -> dict[str, Any] | None: assert request.is_finished()