1010import traceback
1111import weakref
1212from collections import deque
13- from collections .abc import Callable
13+ from collections .abc import Callable , Sequence
1414from concurrent .futures import Future , InvalidStateError
1515from contextlib import suppress
1616from dataclasses import dataclass
@@ -172,6 +172,7 @@ def _init_executor(self) -> None:
172172 # Start background thread to monitor worker health if not in headless mode.
173173 if self .monitor_workers :
174174 self .start_worker_monitor ()
175+
175176 self .response_mqs = []
176177 # Only leader node have remote response mqs
177178 if self .parallel_config .node_rank_within_dp == 0 :
@@ -186,6 +187,7 @@ def _init_executor(self) -> None:
186187 ]
187188 assert remote_message_queue is not None
188189 self .response_mqs .append (remote_message_queue )
190+
189191 # Ensure message queues are ready. Will deadlock if re-ordered
190192 # Must be kept consistent with the WorkerProc.
191193
@@ -280,22 +282,6 @@ def take_draft_token_ids(self) -> DraftTokenIds | None:
280282 "take_draft_token_ids" , unique_reply_rank = self .output_rank
281283 )
282284
283- def get_response_mqs (
284- self , unique_reply_rank : int | None = None
285- ) -> list [MessageQueue ]:
286- message_queues = []
287- for rank in range (self .world_size ):
288- if rank < self .local_world_size :
289- local_message_queue = self .workers [rank ].worker_response_mq
290- message_queues .append (local_message_queue )
291- else :
292- remote_message_queue = self .workers [0 ].peer_worker_response_mqs [rank ]
293- assert remote_message_queue is not None
294- message_queues .append (remote_message_queue )
295- if unique_reply_rank is not None :
296- message_queues = [message_queues [unique_reply_rank ]]
297- return message_queues
298-
299285 def collective_rpc ( # type: ignore[override]
300286 self ,
301287 method : str | Callable ,
@@ -332,9 +318,9 @@ def collective_rpc( # type: ignore[override]
332318 send_method = cloudpickle .dumps (method , protocol = pickle .HIGHEST_PROTOCOL )
333319 self .rpc_broadcast_mq .enqueue ((send_method , args , kwargs , output_rank ))
334320
335- response_mqs = self .response_mqs
336- if unique_reply_rank is not None :
337- response_mqs = [ self . response_mqs [unique_reply_rank ]]
321+ response_mqs : Sequence [ MessageQueue ] = self .response_mqs
322+ if output_rank is not None :
323+ response_mqs = ( response_mqs [output_rank ],)
338324
339325 shutdown_event = self .shutdown_event
340326
0 commit comments