Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions python/sglang/srt/managers/scheduler_pp_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import math
import time
from collections import deque
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -135,6 +135,7 @@ def event_loop_pp(self: Scheduler):
self.send_proxy_work = self._pp_send_dict_to_next_stage(
result.pp_hidden_states_proxy_tensors.tensors,
async_send=True,
msg_type="proxy",
)

self.pp_outputs = next_pp_outputs
Expand Down Expand Up @@ -306,6 +307,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler):
self.send_proxy_work = self._pp_send_dict_to_next_stage(
result.pp_hidden_states_proxy_tensors.tensors,
async_send=True,
msg_type="proxy",
)

self.pp_outputs = next_pp_outputs
Expand Down Expand Up @@ -486,6 +488,7 @@ def event_loop_pp_disagg_decode(self: Scheduler):
self.send_proxy_work = self._pp_send_dict_to_next_stage(
result.pp_hidden_states_proxy_tensors.tensors,
async_send=True,
msg_type="proxy",
)

self.pp_outputs = next_pp_outputs
Expand Down Expand Up @@ -529,6 +532,9 @@ def init_pp_loop_state(self: Scheduler):
self.send_proxy_work = []
self.send_output_work = []
self.launch_event = None
self._pp_tensor_dict_inbox: Dict[str, deque[Dict[str, torch.Tensor]]] = (
defaultdict(deque)
)

def profile_and_init_predictor(self: Scheduler):
"""
Expand Down Expand Up @@ -913,7 +919,15 @@ def _pp_send_dict_to_next_stage(
self: Scheduler,
tensor_dict: Dict[str, torch.Tensor],
async_send: bool = True,
msg_type: str = "default",
):
# Warn once if using default untyped messages
if msg_type == "default":
logger.warning_once(
"PP send: using default untyped message. "
"Consider adding msg_type='proxy' or 'output' to avoid recv conflicts."
)
tensor_dict["__msg_type__"] = msg_type
p2p_work = []
p2p_work.extend(
self.pp_group.send_tensor_dict(
Expand All @@ -926,27 +940,61 @@ def _pp_send_dict_to_next_stage(
)
return p2p_work

def _pp_recv_typed_dict(
self: Scheduler,
expected_kind: str = "default",
all_gather_group: Optional = None,
) -> Dict[str, torch.Tensor]:
"""Receive a typed tensor dict, demultiplexing by msg_type.

If a message of the wrong kind is received, it's stashed in the queue
and we continue receiving until we get the expected kind.
"""
if expected_kind in self._pp_tensor_dict_inbox:
inbox_queue = self._pp_tensor_dict_inbox[expected_kind]
if inbox_queue:
return inbox_queue.popleft()

while True:
tensor_dict = self.pp_group.recv_tensor_dict(
all_gather_group=all_gather_group
)
received_kind = tensor_dict.get("__msg_type__", "default")
if received_kind == expected_kind:
if received_kind == "default":
logger.warning_once(
f"PP recv: got default untyped message. Content keys: {tensor_dict.keys()}"
"Consider adding msg_type='proxy' or 'output' to avoid recv conflicts."
)
return tensor_dict
else:
logger.debug(
f"PP recv: expected {expected_kind}, got {received_kind}, stashing"
)
self._pp_tensor_dict_inbox[received_kind].append(tensor_dict)
Comment thread
ShangmingCai marked this conversation as resolved.

def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]:
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
self.pp_group.recv_tensor_dict(
self._pp_recv_typed_dict(
expected_kind="proxy",
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
)
),
)
)
return pp_proxy_tensors

def _pp_recv_dict_from_prev_stage(
self: Scheduler,
) -> Dict[str, torch.Tensor]:
res = self.pp_group.recv_tensor_dict(
return self._pp_recv_typed_dict(
expected_kind="output",
all_gather_group=(
self.attn_tp_group if self.require_attn_tp_allgather else None
),
)
return res

def _pp_prep_batch_result(
self: Scheduler,
Expand Down Expand Up @@ -1000,6 +1048,7 @@ def _pp_send_output_to_next_stage(
send_output_work = self._pp_send_dict_to_next_stage(
pp_outputs_to_send.tensors,
async_send=True,
msg_type="output",
)
# send the outputs from the last round to let the next stage worker run post processing
if not self.pp_group.is_last_rank:
Expand All @@ -1008,6 +1057,7 @@ def _pp_send_output_to_next_stage(
send_output_work = self._pp_send_dict_to_next_stage(
pp_outputs.tensors,
async_send=True,
msg_type="output",
)
return send_output_work

Expand Down
Loading