Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1244,42 +1244,108 @@ def get_finished(self) -> tuple[set[str], set[str]]:
done_sending, done_recving = set(), set()

if self.is_producer:
done_sending = self.moriio_wrapper.pop_finished_req_ids()
done_sending_raw = self.moriio_wrapper.pop_finished_req_ids()
if self.mode == MoRIIOMode.READ:
# READ mode: the consumer (decode) notifies the producer
# (prefill) over ZMQ once it finishes the RDMA read. The
# notification carries the transfer_id (not the consumer's
# internal request_id) because each engine independently
# appends a random 8-char suffix to its request_id in
# InputProcessor.assign_request_id, so the consumer's and
# producer's internal request_ids for the same logical
# request differ. Translate back to the producer's own
# internal request_id via transfer_id_to_request_id (which
# was populated at scheduling time by update_state_after_alloc
# and synced to the worker by start_load_kv). Pop on success
# to keep the persistent worker map bounded.
for tid in done_sending_raw:
mapped = self.transfer_id_to_request_id.pop(tid, None)
if mapped is not None:
done_sending.add(mapped)
else:
logger.warning(
"get_finished (producer READ): no mapping for "
"transfer_id %s; dropping notification to avoid "
"scheduler assertion on unknown request_id",
tid,
)
else:
# WRITE mode: producer locally appends its own internal
# request_id to done_req_ids in _finalize_if_complete, so no
# translation is required.
done_sending = done_sending_raw

else:
if self.mode == MoRIIOMode.WRITE:
done_recving = self.moriio_wrapper.pop_finished_write_req_ids()
else:
done_recving = self._pop_done_transfers()

done_recving = {
self.transfer_id_to_request_id[id]
for id in filter(
lambda id: id in self.transfer_id_to_request_id, done_recving
)
}
# Translate consumer-side done_recving (transfer_ids reported by the
# producer via send_notify in WRITE mode) back to the consumer's own
# internal request_ids. Pop on success so the persistent worker map
# (populated incrementally in start_load_kv) does not grow unbounded.
translated_recving: set[str] = set()
for tid in done_recving:
mapped = self.transfer_id_to_request_id.pop(tid, None)
if mapped is not None:
translated_recving.add(mapped)
done_recving = translated_recving

return done_sending, done_recving

def _pop_done_transfers(self) -> set[str]:
done_req_ids: set[str] = set()
"""Pop completed remote-read transfers and notify the producer.

Sends the transfer_id (not the consumer's internal request_id) so the
producer can translate it back to its own internal request_id; see
get_finished() for the producer-side translation and the assign_request_id
rationale.

Returns an empty set because in READ mode the consumer scheduler does
not track recv-completion (get_num_new_matched_tokens returns
async=False, so requests never enter WAITING_FOR_REMOTE_KVS); reporting
a recv-completion here would trip the scheduler assertion at
_update_from_kv_xfer_finished. The downstream translation block in
get_finished() therefore receives an empty set and is a no-op.
"""
# Invert the worker-side transfer_id -> request_id map so we can look
# up the transfer_id for each completed entry in _recving_transfers
# (which is keyed by the consumer's internal request_id).
request_id_to_transfer_id = {
rid: tid for tid, rid in self.transfer_id_to_request_id.items()
}
with self.moriio_wrapper.lock:
to_remove = []
for req_id, status_list in self._recving_transfers.items():
if status_list[-1].Succeeded():
done_req_ids.add(req_id)

transfer_id = request_id_to_transfer_id.get(req_id)
if transfer_id is None:
logger.warning(
"_pop_done_transfers: no transfer_id mapping for "
"request %s; cannot notify producer (prefill "
"block may leak)",
req_id,
)
to_remove.append(req_id)
continue
self.moriio_wrapper.send_notify(
req_id,
transfer_id,
self._recving_transfers_callback_addr[req_id][0],
self._recving_transfers_callback_addr[req_id][1],
)
Comment on lines 1332 to 1336
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In READ mode, the consumer-side entries in transfer_id_to_request_id are never removed: _pop_done_transfers returns an empty set, so the translation and cleanup loop in get_finished (lines 1288-1293) iterates over nothing and never pops. As a result the mapping grows with every request, which is a memory leak. The entry should be popped here, right after the notification has been sent.

Suggested change
self.moriio_wrapper.send_notify(
req_id,
transfer_id,
self._recving_transfers_callback_addr[req_id][0],
self._recving_transfers_callback_addr[req_id][1],
)
self.moriio_wrapper.send_notify(
transfer_id,
self._recving_transfers_callback_addr[req_id][0],
self._recving_transfers_callback_addr[req_id][1],
)
self.transfer_id_to_request_id.pop(transfer_id, None)

# Pop the transfer_id ↔ request_id mapping now: the
# downstream translation block in get_finished() runs
# off the returned set, but we return set() here (see
# docstring), so it would never pop. Without this the
# map grows unbounded for the lifetime of the engine.
self.transfer_id_to_request_id.pop(transfer_id, None)
to_remove.append(req_id)
for req_id in to_remove:
del self._recving_transfers[req_id]
del self._recving_transfers_callback_addr[req_id]

return done_req_ids
return set()

def save_kv_layer(
self,
Expand Down Expand Up @@ -1339,7 +1405,14 @@ def start_load_kv(self, metadata: MoRIIOConnectorMetadata):
Start loading by triggering non-blocking moriio_xfer.
We check for these transfers to complete in each step().
"""
self.transfer_id_to_request_id = metadata.transfer_id_to_request_id
# Merge (rather than overwrite) so the worker-side mapping survives
# after the scheduler-side request_finished() unmaps a transfer_id.
# The producer needs this entry to translate the consumer's
# transfer_id notification (see get_finished) back to its own internal
# request_id, and that notification can arrive several steps after
# request_finished. get_finished() pops entries after a successful
# translation, so the dict stays bounded.
self.transfer_id_to_request_id.update(metadata.transfer_id_to_request_id)
if self.is_producer:
self.moriio_wrapper.async_wait_reqid()
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
HandshakeError,
LayerTransferPlan,
MoRIIOAgentMetadata,
MoRIIOConstants,
MoRIIOError,
RemoteAllocInfo,
TransferError,
Expand Down Expand Up @@ -532,13 +531,31 @@ def _handle_message(self, msg: bytes):

try:
msg_str = msg.decode("UTF-8")
if msg_str.startswith(MoRIIOConstants.TRANSFER_PREFIX):
self._handle_completion_message(msg_str)
handled = True
# Read-completion notifications carry the consumer's request_id.
# In upstream the prefix was assumed to be MoRIIOConstants.TRANSFER_PREFIX,
# but the toy-proxy convention embeds peer addresses into the request_id
# (e.g. "chatcmpl-___prefill_addr_host:...___decode_addr_host:..._UUID"),
# so the prefix never matches and the original code raised
# "Unhandled message format", killing the notify listener thread on the
# first read-completion. Treat any UTF-8 decoded payload as a completion
# message and let _handle_completion_message append it to done_req_ids;
# the scheduler's _update_from_kv_xfer_finished will reject anything that
# isn't a live request_id, so this stays safe.
self._handle_completion_message(msg_str)
handled = True
except UnicodeDecodeError:
logger.warning("Received non-UTF8 message: %s", msg_str)
# Non-UTF-8 payloads are not actionable here (the toy-proxy
# convention is UTF-8 request_ids). Logging and dropping the
# message is the right behavior; falling through into the
# MoRIIOError below would propagate to the listener loop and
# kill the notify thread on a single malformed packet.
logger.warning(
"Received non-UTF8 completion message of %d bytes; dropping",
len(msg),
)
return
if not handled:
raise MoRIIOError(f"Unhandled message format: {msg_str}")
raise MoRIIOError(f"Unhandled message format ({len(msg)} bytes)")

def _handle_structured_message(self, data: dict):
assert get_role() == ROLE.PRODUCER, "Only prefill can get block messages"
Expand Down