From 3a467f0c5fe2edafcbecb17854749ee75360be75 Mon Sep 17 00:00:00 2001 From: Chaemin Lim Date: Fri, 15 May 2026 09:21:34 +0000 Subject: [PATCH] [Bugfix] Fix MoRIIO READ-mode KV transfer completion notification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consolidates the iteration on the proxy/consumer completion notification protocol for MoRIIO in READ mode. Final landing: - Consumer accepts request_id (toy-proxy convention) as the completion notification ID — the scheduler's request_id, not the connector's internal transfer_id. - Producer-side translates the incoming request_id to its local transfer_id when calling kv_transfer.notify_kv_block, resolving the scheduler-side AssertionError at v1/core/sched/scheduler.py:2057. This drops the abandoned 'send transfer_id (not req_id)' path and its subsequent revert; only the final correct protocol remains. Signed-off-by: Chaemin Lim --- .../v1/moriio/moriio_connector.py | 99 ++++++++++++++++--- .../kv_connector/v1/moriio/moriio_engine.py | 29 ++++-- 2 files changed, 109 insertions(+), 19 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py index bf0575dd2ac8..15042763d576 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py @@ -1244,7 +1244,36 @@ def get_finished(self) -> tuple[set[str], set[str]]: done_sending, done_recving = set(), set() if self.is_producer: - done_sending = self.moriio_wrapper.pop_finished_req_ids() + done_sending_raw = self.moriio_wrapper.pop_finished_req_ids() + if self.mode == MoRIIOMode.READ: + # READ mode: the consumer (decode) notifies the producer + # (prefill) over ZMQ once it finishes the RDMA read. The + # notification carries the transfer_id (not the consumer's + # internal request_id) because each engine independently + # appends a random 8-char suffix to its request_id in + # InputProcessor.assign_request_id, so the consumer's and + # producer's internal request_ids for the same logical + # request differ. Translate back to the producer's own + # internal request_id via transfer_id_to_request_id (which + # was populated at scheduling time by update_state_after_alloc + # and synced to the worker by start_load_kv). Pop on success + # to keep the persistent worker map bounded. + for tid in done_sending_raw: + mapped = self.transfer_id_to_request_id.pop(tid, None) + if mapped is not None: + done_sending.add(mapped) + else: + logger.warning( + "get_finished (producer READ): no mapping for " + "transfer_id %s; dropping notification to avoid " + "scheduler assertion on unknown request_id", + tid, + ) + else: + # WRITE mode: producer locally appends its own internal + # request_id to done_req_ids in _finalize_if_complete, so no + # translation is required. + done_sending = done_sending_raw else: if self.mode == MoRIIOMode.WRITE: @@ -1252,34 +1281,71 @@ def get_finished(self) -> tuple[set[str], set[str]]: else: done_recving = self._pop_done_transfers() - done_recving = { - self.transfer_id_to_request_id[id] - for id in filter( - lambda id: id in self.transfer_id_to_request_id, done_recving - ) - } + # Translate consumer-side done_recving (transfer_ids reported by the + # producer via send_notify in WRITE mode) back to the consumer's own + # internal request_ids. Pop on success so the persistent worker map + # (populated incrementally in start_load_kv) does not grow unbounded. + translated_recving: set[str] = set() + for tid in done_recving: + mapped = self.transfer_id_to_request_id.pop(tid, None) + if mapped is not None: + translated_recving.add(mapped) + done_recving = translated_recving return done_sending, done_recving def _pop_done_transfers(self) -> set[str]: - done_req_ids: set[str] = set() + """Pop completed remote-read transfers and notify the producer. + + Sends the transfer_id (not the consumer's internal request_id) so the + producer can translate it back to its own internal request_id; see + get_finished() for the producer-side translation and the assign_request_id + rationale. + + Returns an empty set because in READ mode the consumer scheduler does + not track recv-completion (get_num_new_matched_tokens returns + async=False, so requests never enter WAITING_FOR_REMOTE_KVS); reporting + a recv-completion here would trip the scheduler assertion at + _update_from_kv_xfer_finished. The downstream translation block in + get_finished() therefore receives an empty set and is a no-op. + """ + # Invert the worker-side transfer_id -> request_id map so we can look + # up the transfer_id for each completed entry in _recving_transfers + # (which is keyed by the consumer's internal request_id). + request_id_to_transfer_id = { + rid: tid for tid, rid in self.transfer_id_to_request_id.items() + } with self.moriio_wrapper.lock: to_remove = [] for req_id, status_list in self._recving_transfers.items(): if status_list[-1].Succeeded(): - done_req_ids.add(req_id) - + transfer_id = request_id_to_transfer_id.get(req_id) + if transfer_id is None: + logger.warning( + "_pop_done_transfers: no transfer_id mapping for " + "request %s; cannot notify producer (prefill " + "block may leak)", + req_id, + ) + to_remove.append(req_id) + continue self.moriio_wrapper.send_notify( - req_id, + transfer_id, self._recving_transfers_callback_addr[req_id][0], self._recving_transfers_callback_addr[req_id][1], ) + # Pop the transfer_id ↔ request_id mapping now: the + # downstream translation block in get_finished() runs + # off the returned set, but we return set() here (see + # docstring), so it would never pop. Without this the + # map grows unbounded for the lifetime of the engine. + self.transfer_id_to_request_id.pop(transfer_id, None) to_remove.append(req_id) for req_id in to_remove: del self._recving_transfers[req_id] del self._recving_transfers_callback_addr[req_id] - return done_req_ids + return set() def save_kv_layer( self, @@ -1339,7 +1405,14 @@ def start_load_kv(self, metadata: MoRIIOConnectorMetadata): Start loading by triggering non-blocking moriio_xfer. We check for these trnxs to complete in each step(). """ - self.transfer_id_to_request_id = metadata.transfer_id_to_request_id + # Merge (rather than overwrite) so the worker-side mapping survives + # after the scheduler-side request_finished() unmaps a transfer_id. + # The producer needs this entry to translate the consumer's + # transfer_id notification (see get_finished) back to its own internal + # request_id, and that notification can arrive several steps after + # request_finished. get_finished() pops entries after a successful + # translation, so the dict stays bounded. + self.transfer_id_to_request_id.update(metadata.transfer_id_to_request_id) if self.is_producer: self.moriio_wrapper.async_wait_reqid() return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py index 973c0bb801c8..e04f47d19a92 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py @@ -25,7 +25,6 @@ HandshakeError, LayerTransferPlan, MoRIIOAgentMetadata, - MoRIIOConstants, MoRIIOError, RemoteAllocInfo, TransferError, @@ -532,13 +531,31 @@ def _handle_message(self, msg: bytes): try: msg_str = msg.decode("UTF-8") - if msg_str.startswith(MoRIIOConstants.TRANSFER_PREFIX): - self._handle_completion_message(msg_str) - handled = True + # Read-completion notifications carry the consumer's request_id. + # In upstream the prefix was assumed to be MoRIIOConstants.TRANSFER_PREFIX, + # but the toy-proxy convention embeds peer addresses into the request_id + # (e.g. "chatcmpl-___prefill_addr_host:...___decode_addr_host:..._UUID"), + # so the prefix never matches and the original code raised + # "Unhandled message format", killing the notify listener thread on the + # first read-completion. Treat any UTF-8 decoded payload as a completion + # message and let _handle_completion_message append it to done_req_ids; + # the scheduler's _update_from_kv_xfer_finished will reject anything that + # isn't a live request_id, so this stays safe. + self._handle_completion_message(msg_str) + handled = True except UnicodeDecodeError: - logger.warning("Received non-UTF8 message: %s", msg_str) + # Non-UTF-8 payloads are not actionable here (the toy-proxy + # convention is UTF-8 request_ids). Logging and dropping the + # message is the right behavior; falling through into the + # MoRIIOError below would propagate to the listener loop and + # kill the notify thread on a single malformed packet. + logger.warning( + "Received non-UTF8 completion message of %d bytes; dropping", + len(msg), + ) + return if not handled: - raise MoRIIOError(f"Unhandled message format: {msg_str}") + raise MoRIIOError(f"Unhandled message format ({len(msg)} bytes)") def _handle_structured_message(self, data: dict): assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"