diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 20ef566416b8..abc84e9c3000 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -34,6 +34,7 @@ NixlConnectorWorker, NixlHandshakePayload, NixlKVConnectorStats, + ReqState, compute_nixl_compatibility_hash, ) from vllm.distributed.kv_transfer.kv_transfer_state import ( @@ -1618,7 +1619,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init): kv_meta = sched_out.kv_connector_metadata assert kv_meta is not None assert isinstance(kv_meta, NixlConnectorMetadata) - assert req.request_id in kv_meta.reqs_in_batch + assert req.request_id in kv_meta.reqs_to_send + assert kv_meta.reqs_to_send[req.request_id] == ReqState.SCHEDULED #### Model Runner start #### # Bind scheduler-produced metadata and start worker processing. @@ -1643,7 +1645,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init): kv_meta2 = sched_out2.kv_connector_metadata assert kv_meta2 is not None assert isinstance(kv_meta2, NixlConnectorMetadata) - assert req.request_id not in kv_meta2.reqs_in_batch + assert req.request_id in kv_meta2.reqs_to_send + assert kv_meta2.reqs_to_send[req.request_id] == ReqState.ABORTED # Bind empty/abort metadata and run worker step #### Model Runner start #### diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index be56eb4e93c1..4bb43b01e2f1 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import copy +import enum import logging import math import queue @@ -219,13 +220,17 @@ class ReqMeta: remote: RemoteMeta | None = None +class ReqState(enum.Enum): + SCHEDULED = 1 + FINISHED = 2 + ABORTED = 3 + + class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} - self.reqs_to_send: dict[ReqId, float] = {} - self.reqs_in_batch: set[ReqId] = set() - self.reqs_not_processed: set[ReqId] = set() + self.reqs_to_send: dict[ReqId, ReqState] = {} def _add_new_req( self, @@ -482,12 +487,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} - # Reqs to send and their expiration time - self._reqs_need_send: dict[ReqId, float] = {} - self._reqs_in_batch: set[ReqId] = set() - # Reqs to remove from processed set because they're not to send after - # remote prefill or aborted. - self._reqs_not_processed: set[ReqId] = set() + # Reqs to send state updates + self._reqs_need_send: dict[ReqId, ReqState] = {} def shutdown(self): self._stop_event.set() @@ -623,7 +624,7 @@ def update_state_after_alloc( return if params.get("do_remote_decode"): - self._reqs_in_batch.add(request.request_id) + self._reqs_need_send[request.request_id] = ReqState.SCHEDULED if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. @@ -697,14 +698,10 @@ def build_connector_meta( ) meta.reqs_to_send = self._reqs_need_send - meta.reqs_in_batch = self._reqs_in_batch - meta.reqs_not_processed = self._reqs_not_processed # Clear the list once workers start the transfers self._reqs_need_recv.clear() self._reqs_need_save.clear() - self._reqs_in_batch = set() - self._reqs_not_processed = set() self._reqs_need_send = {} return meta @@ -747,24 +744,15 @@ def request_finished( if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. - self._reqs_not_processed.add(request.request_id) + self._reqs_need_send[request.request_id] = ReqState.ABORTED return False, None # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = len(block_ids) > 0 - if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - logger.debug( - "NIXLConnector request_finished(%s) waiting for %d seconds " - "for remote decode to fetch blocks", - request.request_id, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, - ) - self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT - ) + self._reqs_need_send[request.request_id] = ReqState.FINISHED return delay_free_blocks, dict( do_remote_prefill=True, @@ -1874,7 +1862,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.permute_device_kv(block_ids_to_permute) # Handle timeout to avoid stranding blocks on remote. - now = time.perf_counter() + now = time.monotonic() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) # Sorted dict, oldest requests are put first so we can exit early. @@ -2043,19 +2031,23 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): # which blocks are read from D. As P can now more easily lag behind D # while processing the next batch, we make sure to only set an # expiration for requests that have not been read from D yet. - for req_id in metadata.reqs_in_batch: - self._reqs_to_process.add(req_id) - - # Remove all requests that are not to be processed (eg aborted). - for req_id in metadata.reqs_not_processed: - self._reqs_to_process.discard(req_id) - # We should never get an abort after setting an expiry timer - assert req_id not in self._reqs_to_send - - # Add to requests that are waiting to be read and track expiration. - for req_id, expiration_time in metadata.reqs_to_send.items(): - if req_id in self._reqs_to_process: - self._reqs_to_send[req_id] = expiration_time + for req_id, req_state in metadata.reqs_to_send.items(): + if req_state == ReqState.SCHEDULED: + self._reqs_to_process.add(req_id) + elif req_state == ReqState.ABORTED: + # Remove all requests that are not to be processed (eg aborted). + self._reqs_to_process.discard(req_id) + # We should never get an abort after setting an expiry timer + assert req_id not in self._reqs_to_send + elif req_state == ReqState.FINISHED and req_id in self._reqs_to_process: + # Add to requests that are waiting to be read and track expiration. + abort_timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + logger.debug( + "req %s : waiting %d seconds for remote decode to fetch blocks", + req_id, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + self._reqs_to_send[req_id] = time.monotonic() + abort_timeout def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): assert meta.remote is not None