Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/v1/core/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3413,3 +3413,52 @@ def test_prepend_skipped_requests_order():

# verify waiting order is preserved
assert list(scheduler.waiting) == expected_waiting_reqs


def test_abort_request_waiting_for_remote_kvs():
    """Abort a request that is still waiting on a remote KV transfer.

    The scheduler must keep the request alive (so its blocks are not reused
    while the transfer may still write into them) and only free it once the
    connector reports the receive as finished.
    """
    scheduler = create_scheduler(use_kv_connector=True)

    # add a single request
    request = create_requests(num_requests=1)[0]
    scheduler.add_request(request)

    # set request to waiting for remote KVs, and abort it
    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.finish_requests((request.request_id,), RequestStatus.FINISHED_ABORTED)
    assert request.status == RequestStatus.FINISHED_ABORTED

    # verify request is not deleted: block freeing is delayed because the
    # in-flight remote KV transfer has not completed yet
    assert request.request_id in scheduler.requests

    # finish recving request: simulate the connector reporting that the
    # KV transfer for this (already-aborted) request has completed
    scheduler_output = scheduler.schedule()
    model_runner_output = ModelRunnerOutput(
        req_ids=[],
        req_id_to_index={},
        kv_connector_output=KVConnectorOutput(finished_recving={request.request_id}),
    )
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # assert request is deleted now that the transfer is done
    assert request.request_id not in scheduler.requests
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered whether it would be useful to check that the blocks were freed, as the NIXL tests do:

```python
req_to_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[0].req_to_blocks
assert req0_id not in req_to_blocks
```

but in either case we're verifying that `_free_blocks()` was called, so I guess not

(just noting for reference)

assert not scheduler.finished_recving_kv_req_ids


def test_abort_request_finished_recving():
    """Abort a request whose remote KV transfer already finished.

    When the connector has already reported the receive as done (the request
    id sits in ``finished_recving_kv_req_ids`` but the status update has not
    run yet), aborting must free the request immediately and drop the stale
    entry from the finished-recving set.
    """
    scheduler = create_scheduler(use_kv_connector=True)

    # enqueue exactly one request
    req = create_requests(num_requests=1)[0]
    scheduler.add_request(req)
    req_id = req.request_id

    # request is waiting for remote KVs; transfer finished but not yet applied
    req.status = RequestStatus.WAITING_FOR_REMOTE_KVS
    scheduler.finished_recving_kv_req_ids.add(req_id)

    # abort it
    scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED)

    assert req.status == RequestStatus.FINISHED_ABORTED
    # request must be gone, and no stale finished-recving entry may remain
    assert req_id not in scheduler.requests
    assert not scheduler.finished_recving_kv_req_ids
55 changes: 53 additions & 2 deletions tests/v1/kv_connector/unit/test_offloading_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
TransferSpec,
)
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
from vllm.v1.request import Request
from vllm.v1.request import Request, RequestStatus

from .utils import (
EOS_TOKEN_ID,
Expand Down Expand Up @@ -355,7 +355,7 @@ def _run(self, decoded_tokens: list[int], complete_transfers: bool):
self.scheduler.update_from_output(scheduler_output, model_runner_output)

if (
prev_token_id is EOS_TOKEN_ID
prev_token_id == EOS_TOKEN_ID
and prev_token_id != token_id
and self.scheduler.requests
):
Expand Down Expand Up @@ -730,6 +730,57 @@ def test_concurrent_lookups_of_the_same_prefix(request_runner):
assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs)


def test_abort_loading_requests(request_runner):
    """Abort a request while its offloaded KV blocks are still being loaded.

    The scheduler must keep the aborted request around until the in-flight
    load completes (its GPU blocks are still a transfer destination) and
    only then delete it.
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    # store one offloaded block (= 3 GPU blocks of size 4)
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.prepare_store.side_effect = lambda block_hashes: (
        generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # issue a second request hitting the stored block, without finishing
    # the resulting transfer
    runner.scheduler.reset_prefix_cache()
    runner.new_request(token_ids=[0] * offloaded_block_size)
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # a load must have been triggered
    assert list(runner.offloading_spec.handler.transfer_specs)

    # abort while the load is in flight
    req_id = str(runner.req_id)
    runner.scheduler.finish_requests((req_id,), RequestStatus.FINISHED_ABORTED)

    # the request may not be deleted yet — its blocks are still being written
    assert req_id in runner.scheduler.requests

    # let the load finish
    runner.run(
        decoded_tokens=[],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # now the aborted request must be gone
    assert req_id not in runner.scheduler.requests


class TestOffloadingConnectorStats:
"""Tests for OffloadingConnector stats reconstruction and operations."""

Expand Down
25 changes: 21 additions & 4 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1669,19 +1669,30 @@ def finish_requests(

# Second pass: set status and free requests
for request in valid_requests:
delay_free_blocks = False
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
delay_free_blocks = (
request.request_id not in self.finished_recving_kv_req_ids
)
self.finished_recving_kv_req_ids.discard(request.request_id)
self.failed_recving_kv_req_ids.discard(request.request_id)

request.status = finished_status
self._free_request(request)
self._free_request(request, delay_free_blocks=delay_free_blocks)

def _free_request(self, request: Request) -> dict[str, Any] | None:
def _free_request(
self, request: Request, delay_free_blocks: bool = False
) -> dict[str, Any] | None:
assert request.is_finished()

delay_free_blocks, kv_xfer_params = self._connector_finished(request)
connector_delay_free_blocks, kv_xfer_params = self._connector_finished(request)
self.encoder_cache_manager.free(request)
request_id = request.request_id
self.finished_req_ids.add(request_id)
if self.finished_req_ids_dict is not None:
self.finished_req_ids_dict[request.client_index].add(request_id)

delay_free_blocks |= connector_delay_free_blocks
if not delay_free_blocks:
self._free_blocks(request)

Expand Down Expand Up @@ -1953,7 +1964,13 @@ def _update_from_kv_xfer_finished(self, kv_connector_output: KVConnectorOutput):
# KV Connector:: update recv and send status from last step.
for req_id in kv_connector_output.finished_recving or ():
logger.debug("Finished recving KV transfer for request %s", req_id)
self.finished_recving_kv_req_ids.add(req_id)
assert req_id in self.requests
req = self.requests[req_id]
if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
self.finished_recving_kv_req_ids.add(req_id)
else:
assert RequestStatus.is_finished(req.status)
self._free_blocks(self.requests[req_id])
for req_id in kv_connector_output.finished_sending or ():
logger.debug("Finished sending KV transfer for request %s", req_id)
assert req_id in self.requests
Expand Down