diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py
index c39a8fc139ef..37cfd20b0197 100644
--- a/python/sglang/srt/managers/scheduler_pp_mixin.py
+++ b/python/sglang/srt/managers/scheduler_pp_mixin.py
@@ -3,7 +3,7 @@
 import logging
 import math
 import time
-from collections import deque
+from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
@@ -135,6 +135,7 @@ def event_loop_pp(self: Scheduler):
                     self.send_proxy_work = self._pp_send_dict_to_next_stage(
                         result.pp_hidden_states_proxy_tensors.tensors,
                         async_send=True,
+                        msg_type="proxy",
                     )
 
                 self.pp_outputs = next_pp_outputs
@@ -306,6 +307,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler):
                     self.send_proxy_work = self._pp_send_dict_to_next_stage(
                         result.pp_hidden_states_proxy_tensors.tensors,
                         async_send=True,
+                        msg_type="proxy",
                     )
 
                 self.pp_outputs = next_pp_outputs
@@ -486,6 +488,7 @@ def event_loop_pp_disagg_decode(self: Scheduler):
                     self.send_proxy_work = self._pp_send_dict_to_next_stage(
                         result.pp_hidden_states_proxy_tensors.tensors,
                         async_send=True,
+                        msg_type="proxy",
                     )
 
                 self.pp_outputs = next_pp_outputs
@@ -529,6 +532,9 @@ def init_pp_loop_state(self: Scheduler):
         self.send_proxy_work = []
         self.send_output_work = []
         self.launch_event = None
+        self._pp_tensor_dict_inbox: Dict[str, deque[Dict[str, torch.Tensor]]] = (
+            defaultdict(deque)
+        )
 
     def profile_and_init_predictor(self: Scheduler):
         """
@@ -913,7 +919,15 @@ def _pp_send_dict_to_next_stage(
         self: Scheduler,
         tensor_dict: Dict[str, torch.Tensor],
         async_send: bool = True,
+        msg_type: str = "default",
     ):
+        # Warn once if using default untyped messages
+        if msg_type == "default":
+            logger.warning_once(
+                "PP send: using default untyped message. "
+                "Consider adding msg_type='proxy' or 'output' to avoid recv conflicts."
+            )
+        tensor_dict["__msg_type__"] = msg_type
         p2p_work = []
         p2p_work.extend(
             self.pp_group.send_tensor_dict(
@@ -926,14 +940,48 @@ def _pp_send_dict_to_next_stage(
         )
         return p2p_work
 
+    def _pp_recv_typed_dict(
+        self: Scheduler,
+        expected_kind: str = "default",
+        all_gather_group: Optional = None,
+    ) -> Dict[str, torch.Tensor]:
+        """Receive a typed tensor dict, demultiplexing by msg_type.
+
+        If a message of the wrong kind is received, it's stashed in the queue
+        and we continue receiving until we get the expected kind.
+        """
+        if expected_kind in self._pp_tensor_dict_inbox:
+            inbox_queue = self._pp_tensor_dict_inbox[expected_kind]
+            if inbox_queue:
+                return inbox_queue.popleft()
+
+        while True:
+            tensor_dict = self.pp_group.recv_tensor_dict(
+                all_gather_group=all_gather_group
+            )
+            received_kind = tensor_dict.get("__msg_type__", "default")
+            if received_kind == expected_kind:
+                if received_kind == "default":
+                    logger.warning_once(
+                        f"PP recv: got default untyped message. Content keys: {tensor_dict.keys()}. "
+                        "Consider adding msg_type='proxy' or 'output' to avoid recv conflicts."
+                    )
+                return tensor_dict
+            else:
+                logger.debug(
+                    f"PP recv: expected {expected_kind}, got {received_kind}, stashing"
+                )
+                self._pp_tensor_dict_inbox[received_kind].append(tensor_dict)
+
     def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]:
         pp_proxy_tensors = None
         if not self.pp_group.is_first_rank:
             pp_proxy_tensors = PPProxyTensors(
-                self.pp_group.recv_tensor_dict(
+                self._pp_recv_typed_dict(
+                    expected_kind="proxy",
                     all_gather_group=(
                         self.attn_tp_group if self.require_attn_tp_allgather else None
-                    )
+                    ),
                 )
             )
         return pp_proxy_tensors
@@ -941,12 +989,12 @@ def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]:
     def _pp_recv_dict_from_prev_stage(
         self: Scheduler,
     ) -> Dict[str, torch.Tensor]:
-        res = self.pp_group.recv_tensor_dict(
+        return self._pp_recv_typed_dict(
+            expected_kind="output",
             all_gather_group=(
                 self.attn_tp_group if self.require_attn_tp_allgather else None
             ),
         )
-        return res
 
     def _pp_prep_batch_result(
         self: Scheduler,
@@ -1000,6 +1048,7 @@ def _pp_send_output_to_next_stage(
                 send_output_work = self._pp_send_dict_to_next_stage(
                     pp_outputs_to_send.tensors,
                     async_send=True,
+                    msg_type="output",
                 )
         # send the outputs from the last round to let the next stage worker run post processing
         if not self.pp_group.is_last_rank:
@@ -1008,6 +1057,7 @@ def _pp_send_output_to_next_stage(
                 send_output_work = self._pp_send_dict_to_next_stage(
                     pp_outputs.tensors,
                     async_send=True,
+                    msg_type="output",
                 )
 
         return send_output_work