Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,15 +248,22 @@ def test_abort_during_kv_transfer():
# Request removed from PB but blocks should not be freed.
assert len(scheduler.requests) == 1

# Abort the request, and check the blocks are still not freed
# Abort the request. Since the request is already finished
# (FINISHED_LENGTH_CAPPED), this becomes an "abort after finished" scenario.
# Blocks will NOT be freed immediately; instead they wait for the connector
# to report finished_sending.
scheduler.finish_requests([request.request_id], RequestStatus.FINISHED_ABORTED)

# After abort, the request should still exist (waiting for finished_sending).
# This is the new behavior for "abort after finished" scenario.
assert len(scheduler.requests) == 1
assert request.status == RequestStatus.FINISHED_ABORTED

# Simulate a finished sending notification
# Simulate a finished sending notification - now blocks will be freed
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request.request_id]
finished_sending=set([request.request_id])
)
scheduler.update_from_output(scheduler_output, model_runner_output)
assert_scheduler_empty(scheduler)
33 changes: 33 additions & 0 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def __init__(self):
self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()
self.reqs_abort_done: set[ReqId] = set()

def _add_new_req(
self,
Expand Down Expand Up @@ -553,6 +554,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# Reqs to remove from processed set because they're not to send after
# remote prefill or aborted.
self._reqs_not_processed: set[ReqId] = set()
# Reqs that were aborted after finished and need cleanup.
self._reqs_abort_done: set[ReqId] = set()

def shutdown(self):
self._stop_event.set()
Expand Down Expand Up @@ -774,12 +777,14 @@ def build_connector_meta(
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
meta.reqs_not_processed = self._reqs_not_processed
meta.reqs_abort_done = self._reqs_abort_done

# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
self._reqs_abort_done = set()

return meta

Expand All @@ -794,6 +799,24 @@ def request_finished(
"""
from vllm.v1.request import RequestStatus

# Check if this is an abort after finished case.
if request.status == RequestStatus.FINISHED_ABORTED:
# Request was already finished and is now being aborted.
# Clean up state and mark for immediate reporting via
# finished_sending to unblock block freeing.
req_id = request.request_id
logger.debug(
"NIXLConnector request_finished(%s): abort after finished, "
"marking for cleanup via finished_sending",
req_id,
)
self._reqs_not_processed.add(req_id)
self._reqs_need_send.pop(req_id, None)
self._reqs_abort_done.add(req_id)
# Don't delay free blocks - will be freed when finished_sending
# is reported from worker.
return False, None

params = request.kv_transfer_params
logger.debug(
"NIXLConnector request_finished(%s), request_status=%s, "
Expand Down Expand Up @@ -999,6 +1022,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._invalid_block_ids: set[int] = set()
# requests that skipped transfer (handshake or transfer failures)
self._failed_recv_reqs: set[ReqId] = set()
# requests that were aborted after finished
self._aborted_reqs: set[ReqId] = set()

# Handshake metadata of this worker for NIXL transfers.
self.xfer_handshake_metadata: NixlHandshakePayload | None = None
Expand Down Expand Up @@ -2002,6 +2027,10 @@ def get_finished(self) -> tuple[set[str], set[str]]:
del self._reqs_to_send[req_id]
done_sending.add(req_id)

# Add aborted requests (abort after finished) to done_sending.
done_sending.update(self._aborted_reqs)
self._aborted_reqs.clear()

return done_sending, done_recving

def _get_new_notifs(self) -> set[str]:
Expand Down Expand Up @@ -2169,6 +2198,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
if req_id in self._reqs_to_process:
self._reqs_to_send[req_id] = expiration_time

# Handle aborted requests (abort after finished).
# These will be reported as done_sending immediately.
self._aborted_reqs.update(metadata.reqs_abort_done)

def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
Expand Down
21 changes: 20 additions & 1 deletion vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1709,10 +1709,29 @@ def finish_requests(
# First pass: collect requests to remove from queues
for req_id in request_ids:
request = self.requests.get(req_id)
if request is None or request.is_finished():
if request is None:
# Invalid request ID.
continue

if request.is_finished():
# If the request is already finished, only FINISHED_ABORTED is
# allowed, which is used to force resource cleanup.
assert finished_status == RequestStatus.FINISHED_ABORTED, (
"Only FINISHED_ABORTED is allowed for requests that are "
"already finished."
)
logger.info("Aborting finished request %s.", req_id)
# Set status to FINISHED_ABORTED so connector can detect this
# case and participate in cleanup.
request.status = RequestStatus.FINISHED_ABORTED
# Notify connector to participate in cleanup. Blocks will be
# freed when connector reports finished_sending.
# A finished request can only exist in self.requests when
# connector delays block freeing (P/D scenario).
assert self.connector is not None
self._connector_finished(request)
Comment on lines +1727 to +1732
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to remove this.
Connectors may assume that request_finished is called only once.
See #33377 for example.

Instead, NixlConnector will have to poll each of its "being sent" requests to see if their status was changed.

continue

valid_requests.append(request)
if request.status == RequestStatus.RUNNING:
running_requests_to_remove.add(request)
Expand Down