diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index dc13fb840021..08d06a1f63c4 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -637,6 +637,16 @@ def profile_and_init_predictor(self: Scheduler): f"seq_lens={seq_lens}, latencies_ms={latencies}" ) + if self.attn_tp_size > 1: + data_to_sync_tp = [seq_lens, latencies] + data_to_sync_tp = broadcast_pyobj( + data_to_sync_tp, + self.attn_tp_group.rank, + self.attn_tp_cpu_group, + src=self.attn_tp_group.ranks[0], + ) + seq_lens, latencies = data_to_sync_tp + # Broadcast data to all ranks if torch.distributed.is_available() and torch.distributed.is_initialized(): data_to_sync = [seq_lens, latencies] @@ -866,7 +876,7 @@ def _pp_recv_pyobj_from_prev_stage(self: Scheduler): else: data = None - if self.attn_tp_size != 1: + if self.attn_tp_size > 1: data = broadcast_pyobj( data, self.attn_tp_group.rank,