31 changes: 28 additions & 3 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -105,6 +105,7 @@ def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
         self.reqs_to_send: dict[ReqId, float] = {}
+        self.reqs_in_batch: set[ReqId] = set()

     def add_new_req(
         self,
@@ -278,6 +279,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
         # Reqs to send and their expiration time
         self._reqs_need_send: dict[ReqId, float] = {}
+        self._reqs_in_batch: set[ReqId] = set()

     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -324,6 +326,9 @@ def update_state_after_alloc(self, request: "Request",

         if not params:
             return
+
+        if params.get("do_remote_decode"):
+            self._reqs_in_batch.add(request.request_id)
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
@@ -373,6 +378,8 @@ def build_connector_meta(
                 request_id=req_id,
                 local_block_ids=block_ids,
                 kv_transfer_params=req.kv_transfer_params,
+                load_remote_cache=True,
+                save_to_host=False,
             )

         for req_id, (req, block_ids) in self._reqs_need_save.items():
@@ -386,10 +393,12 @@ def build_connector_meta(
             )

         meta.reqs_to_send = self._reqs_need_send
+        meta.reqs_in_batch = self._reqs_in_batch

         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
+        self._reqs_in_batch = set()
         self._reqs_need_send = {}

         return meta
@@ -546,6 +555,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._recving_transfers = defaultdict[ReqId, list[Transfer]](list)
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
+        # Set of requests that have been part of a batch, regardless of status.
+        self._reqs_to_process: set[ReqId] = set()

         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
@@ -1082,6 +1093,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
                 "Releasing expired KV blocks for request %s which were "
                 "retrieved by %d decode worker(s) within %d seconds.", req_id,
                 count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
+            self._reqs_to_process.remove(req_id)
            del self._reqs_to_send[req_id]
            done_sending.add(req_id)
Contributor review comment (severity: high):

Using discard() instead of remove() would be more robust. In a complex distributed system with asynchronous operations, it is safer to use discard() to avoid a potential KeyError if the req_id is unexpectedly not in the set due to a race condition. This would make the worker more resilient.

A similar change would be beneficial on line 1126.

Suggested change:
-            self._reqs_to_process.remove(req_id)
+            self._reqs_to_process.discard(req_id)
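To make the reviewer's point concrete, here is a minimal standalone Python sketch (an illustration only, not vLLM code) of the behavioral difference:

# remove() vs. discard() on a set: only remove() raises for missing members.
reqs_to_process = {"req-a"}

reqs_to_process.discard("req-b")      # no-op: absent members are ignored
try:
    reqs_to_process.remove("req-b")   # raises KeyError for absent members
except KeyError:
    print("remove() raised KeyError for a missing request id")

With discard(), a duplicate or late consumer notification for an already-cleaned-up request would be a no-op instead of raising in the worker's cleanup path.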

@@ -1097,7 +1109,8 @@ def _get_new_notifs(self) -> set[str]:
         for notifs in self.nixl_wrapper.get_new_notifs().values():
             for notif in notifs:
                 req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
-                if req_id not in self._reqs_to_send:
+                if (req_id not in self._reqs_to_send
+                        and req_id not in self._reqs_to_process):
                     logger.error(
                         "Potentially invalid KV blocks for "
                         "unrecognized request %s were retrieved by "
@@ -1110,7 +1123,8 @@ def _get_new_notifs(self) -> set[str]:
                         tp_ratio):
                     notified_req_ids.add(req_id)
                     del self.consumer_notification_counts_by_req[req_id]
-                    del self._reqs_to_send[req_id]
+                    self._reqs_to_process.remove(req_id)
+                    self._reqs_to_send.pop(req_id, None)
         return notified_req_ids

     def _pop_done_transfers(
@@ -1171,8 +1185,19 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         while not self._ready_requests.empty():
             self._read_blocks_for_req(*self._ready_requests.get_nowait())

+        # Keep around the requests that have been part of a batch. This is
+        # needed because async scheduling pushes the misalignment between the
+        # moment in which requests expiration is set (P side) and the moment in
+        # which blocks are read from D. As P can now more easily lag behind D
+        # while processing the next batch, we make sure to only set an
+        # expiration for requests that have not been read from D yet.
+        for req_id in metadata.reqs_in_batch:
+            self._reqs_to_process.add(req_id)
+
         # Add to requests that are waiting to be read and track expiration.
-        self._reqs_to_send.update(metadata.reqs_to_send)
+        for req_id, expiration_time in metadata.reqs_to_send.items():
+            if req_id in self._reqs_to_process:
+                self._reqs_to_send[req_id] = expiration_time

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
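Taken together, the diff introduces a small lifecycle gate on the worker side: a request enters _reqs_to_process when it appears in a batch, leaves it when the decode side finishes reading, and only requests still in the set ever get an expiration deadline. The sketch below is an illustration written for this page, not code from the PR; WorkerSketch and its methods are simplified, hypothetical stand-ins whose names merely mirror the diff:

import time


class WorkerSketch:
    """Simplified stand-in for the NIXL worker-side bookkeeping (hypothetical)."""

    def __init__(self):
        self._reqs_to_send: dict[str, float] = {}  # req_id -> expiration time
        self._reqs_to_process: set[str] = set()    # batched, not yet read by D

    def start_load_kv(self, reqs_in_batch: set[str],
                      reqs_to_send: dict[str, float]) -> None:
        # Every request that was part of a batch becomes eligible for tracking.
        self._reqs_to_process.update(reqs_in_batch)
        # Only requests D has not already read get an expiration deadline.
        for req_id, expiration in reqs_to_send.items():
            if req_id in self._reqs_to_process:
                self._reqs_to_send[req_id] = expiration

    def on_consumer_done(self, req_id: str) -> None:
        # D finished reading: stop tracking and drop any pending deadline.
        self._reqs_to_process.discard(req_id)  # discard(), per the review above
        self._reqs_to_send.pop(req_id, None)


# Async-scheduling race: D reads the blocks before P's expiration metadata
# reaches the worker. Because the notification already removed the request
# from _reqs_to_process, the late expiration entry is ignored rather than
# re-arming a timeout for a finished request.
w = WorkerSketch()
w.start_load_kv(reqs_in_batch={"req-1"}, reqs_to_send={})
w.on_consumer_done("req-1")
w.start_load_kv(reqs_in_batch=set(),
                reqs_to_send={"req-1": time.monotonic() + 120})
assert "req-1" not in w._reqs_to_send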