From d172c8233841a8302b7d1014a0351e21985e19f9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 18 Oct 2025 20:05:06 -0700 Subject: [PATCH 01/28] add pp mixin --- .../sglang/srt/managers/scheduler_pp_mixin.py | 688 ++++++++++++++++++ 1 file changed, 688 insertions(+) create mode 100644 python/sglang/srt/managers/scheduler_pp_mixin.py diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py new file mode 100644 index 000000000000..ec6247b422d5 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -0,0 +1,688 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pipeline parallelism mixin for scheduler - contains PP event loop and utilities.""" + +from __future__ import annotations + +import logging +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +import torch + +from sglang.srt.disaggregation.base.conn import KVPoll +from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.distributed.parallel_state import P2PWork +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.utils import ( + GenerationBatchResult, + get_logprob_dict_from_result, + get_logprob_from_pp_outputs, +) +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors +from sglang.srt.utils import DynamicGradMode, broadcast_pyobj, point_to_point_pyobj + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import Scheduler + + +@dataclass +class PPBatchMetadata: + bid: int + can_run_cuda_graph: bool + + +class SchedulerPPMixin: + def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None: + for p2p_work in work: + p2p_work.work.wait() + work.clear() + + def send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False): + p2p_work = [] + if self.attn_tp_rank == 0: + dp_offset = self.dp_rank * self.attn_tp_size + p2p_work = point_to_point_pyobj( + data, + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + self.pp_rank * self.tp_size + dp_offset, + ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset, + async_send=async_send, + ) + return p2p_work + + def recv_pyobj_from_prev_stage(self: Scheduler): + if self.attn_tp_rank == 0: + dp_offset = self.dp_rank * self.attn_tp_size + data = point_to_point_pyobj( + [], + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset, + self.pp_rank * self.tp_size + dp_offset, + ) + else: + data = None + + if self.tp_size != 1: + data = broadcast_pyobj( + data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0] + ) + + return data + + def _pp_prepare_tensor_dict( + self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch + ) -> Dict[str, torch.Tensor]: + 
tensor_dict = { + "next_token_ids": result.next_token_ids, + } + + if batch.return_logprob: + logprob_dict = get_logprob_dict_from_result(result) + tensor_dict = { + **tensor_dict, + **logprob_dict, + } + return tensor_dict + + def _pp_send_dict_to_next_stage( + self: Scheduler, + tensor_dict: Dict[str, torch.Tensor], + async_send: bool = True, + ): + p2p_work = [] + + p2p_work.extend( + self.pp_group.send_tensor_dict( + tensor_dict, + all_gather_group=self.attn_tp_group, + async_send=async_send, + ) + ) + return p2p_work + + def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]: + pp_proxy_tensors = None + if not self.pp_group.is_first_rank: + pp_proxy_tensors = PPProxyTensors( + self.pp_group.recv_tensor_dict(all_gather_group=self.attn_tp_group) + ) + return pp_proxy_tensors + + def _pp_recv_dict_from_prev_stage( + self: Scheduler, + ) -> Dict[str, torch.Tensor]: + res = self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group, + ) + return res + + def _pp_prep_batch_result( + self: Scheduler, + batch: ScheduleBatch, + mb_metadata: PPBatchMetadata, + pp_outputs: PPProxyTensors, + ): + logits_output = None + extend_input_len_per_req = None + extend_logprob_start_len_per_req = None + + if batch.return_logprob: + ( + logits_output, + extend_input_len_per_req, + extend_logprob_start_len_per_req, + ) = get_logprob_from_pp_outputs(pp_outputs) + batch.output_ids = pp_outputs["next_token_ids"] + output_result = GenerationBatchResult( + logits_output=logits_output, + pp_hidden_states_proxy_tensors=None, + next_token_ids=pp_outputs["next_token_ids"], + extend_input_len_per_req=extend_input_len_per_req, + extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, + bid=mb_metadata.bid, + can_run_cuda_graph=mb_metadata.can_run_cuda_graph, + ) + return output_result + + def _pp_process_batch_result( + self: Scheduler, batch: ScheduleBatch, output_result: GenerationBatchResult + ): + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.process_batch_result_disagg_prefill(batch, output_result) + else: + self.process_batch_result(batch, output_result) + + def _pp_send_output_to_next_stage( + self: Scheduler, + next_first_rank_mb_id: int, + mbs: List[ScheduleBatch], + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], + pp_outputs: PPProxyTensors | None, + ) -> List[P2PWork]: + send_output_work = [] + if self.pp_group.is_last_rank: + # send ready PP output to rank 0 + if mbs[next_first_rank_mb_id] is not None: + q_event, pp_outputs_to_send = last_rank_comm_queue.popleft() + torch.cuda.current_stream().wait_event(q_event) + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs_to_send.tensors, + async_send=True, + ) + # send the outputs from the last round to let the next stage worker run post processing + if not self.pp_group.is_last_rank: + if pp_outputs: + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs.tensors, + async_send=True, + ) + return send_output_work + + def _pp_send_recv_and_preprocess_output_tensors( + self: Scheduler, + next_first_rank_mb_id: int, + next_mb_id: int, + mbs: List[ScheduleBatch], + mb_metadata: List[PPBatchMetadata], + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], + pp_outputs: PPProxyTensors | None, + ) -> Tuple[PPProxyTensors, List[P2PWork], torch.cuda.Event]: + next_pp_outputs = None + d2h_event = None + 
batch_result = None
+        send_output_work = self._pp_send_output_to_next_stage(
+            next_first_rank_mb_id,
+            mbs,
+            last_rank_comm_queue,
+            pp_outputs,
+        )
+
+        if mbs[next_mb_id] is not None:
+            with torch.profiler.record_function("recv_res_dict_from_prev_stage"):
+                next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage())
+        self._pp_commit_comm_work(work=send_output_work)
+        if mbs[next_mb_id] is not None:
+            batch_result = self._pp_prep_batch_result(
+                mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs
+            )
+            d2h_event = self.process_batch_result_d2h(mbs[next_mb_id], batch_result)
+
+        return next_pp_outputs, batch_result, d2h_event
+
+    def _pp_launch_batch(
+        self: Scheduler,
+        mb_id: int,
+        pp_proxy_tensors: PPProxyTensors,
+        mb_metadata: PPBatchMetadata,
+        last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]],
+    ):
+        with torch.profiler.record_function("run_batch"):
+            result = self.run_batch(self.cur_batch, pp_proxy_tensors)
+            mb_metadata[mb_id] = PPBatchMetadata(
+                bid=result.bid,
+                can_run_cuda_graph=result.can_run_cuda_graph,
+            )
+            event = torch.cuda.Event()
+            event.record(torch.cuda.current_stream())
+            if self.pp_group.is_last_rank:
+                # (last rank) buffer the outputs for async batch depth
+                last_rank_comm_queue.append(
+                    (
+                        event,
+                        PPProxyTensors(
+                            self._pp_prepare_tensor_dict(result, self.cur_batch)
+                        ),
+                    )
+                )
+        return result, event
+
+    @DynamicGradMode()
+    def event_loop_pp(self: Scheduler):
+        """
+        A scheduler loop for pipeline parallelism.
+        Notes:
+        1. Each stage runs in the same order and is notified by the previous stage.
+        2. We use async send but sync recv to avoid desynchronization while minimizing the communication overhead.
+        3. We can use the async batch depth to buffer the outputs in the last stage, allowing the GPU computation to overlap with CPU processing and avoiding a last-PP-rank straggler.
+
+        Unified Schedule:
+        ====================================================================
+        Stage P
+        recv ith req from previous stage
+        recv ith proxy from previous stage
+        run ith batch
+        recv prev (i+1) % mb_size th outputs
+        process batch result of prev (i+1) % mb_size th batch (can be run in parallel with the curr batch GPU computation)
+        send ith req to next stage
+        send ith proxy to next stage
+        send current stage's outputs to next stage (can be stashed and delayed to send later)
+
+        the above order can be optimized and reordered to minimize communication-related CPU stalls and overhead bubbles.
+ + ==================================================================== + """ + self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + mbs = [None] * self.pp_loop_size + last_mbs = [None] * self.pp_loop_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) + for _ in range(self.pp_loop_size) + ] + mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size + pp_outputs: Optional[PPProxyTensors] = None + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() + send_req_work = [] + send_proxy_work = [] + event = None + while True: + server_is_idle = True + for mb_id in range(self.pp_loop_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size + next_mb_id = (mb_id + 1) % self.pp_loop_size + with torch.profiler.record_function("recv_requests"): + recv_reqs = self.recv_requests() + self._pp_commit_comm_work(send_req_work) + self.process_input_requests(recv_reqs) + with torch.profiler.record_function("get_next_batch_to_run"): + mbs[mb_id] = self.get_next_batch_to_run() + self.running_mbs[mb_id] = self.running_batch + self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + pp_proxy_tensors = self._pp_recv_proxy_tensors() + self._pp_commit_comm_work(send_proxy_work) + next_pp_outputs = None + next_batch_result = None + d2h_event = None + if self.server_args.pp_async_batch_depth > 0: + next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + + if self.cur_batch: + result, event = self._pp_launch_batch( + mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue + ) + if self.server_args.pp_async_batch_depth == 0: + next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + with torch.profiler.record_function("process_batch_result"): + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) + last_mbs[next_mb_id] = mbs[next_mb_id] + if not self.pp_group.is_last_rank: + with torch.profiler.record_function("send_reqs_to_next_stage"): + send_req_work = self.send_pyobj_to_next_stage( + recv_reqs, + async_send=True, + ) + if self.cur_batch: + torch.cuda.current_stream().wait_event(event) + with torch.profiler.record_function( + "send_proxy_dict_to_next_stage" + ): + send_proxy_work = self._pp_send_dict_to_next_stage( + result.pp_hidden_states_proxy_tensors, + async_send=True, + ) + + if self.delayed_weight_sync_fn: + self.delayed_weight_sync_fn() + self.delayed_weight_sync_fn = None + + pp_outputs = next_pp_outputs + + # When the server is idle, self-check and re-init some states + if server_is_idle: + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() + + def process_bootstrapped_queue( + self: Scheduler, bootstrapped_rids: Optional[List[str]] + ): + # finished consensus bootstrapped reqs and prepare the waiting queue + if bootstrapped_rids is not None: + ( + good_consensus_bootstrapped_rids, + bad_consensus_bootstrapped_rids, + ) = bootstrapped_rids + good_reqs, failed_reqs = ( + 
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
+                    return_failed_reqs=True,
+                    rids_to_check=good_consensus_bootstrapped_rids,
+                    bad_rids_to_check=bad_consensus_bootstrapped_rids,
+                )
+            )
+            self.waiting_queue.extend(good_reqs)
+            return [[req.rid for req in good_reqs], [req.rid for req in failed_reqs]]
+        return None
+
+    def _pp_pd_get_bootstrapped_ids(self: Scheduler):
+        # communicate pre-consensus bootstrap reqs
+        if self.pp_group.is_first_rank:
+            # First rank, pop the bootstrap reqs from the bootstrap queue
+            good_bootstrapped_rids, bad_bootstrapped_rids = self.get_rids(
+                self.disagg_prefill_bootstrap_queue.queue,
+                [KVPoll.WaitingForInput],
+                [KVPoll.Failed],
+            )
+        else:
+            # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the concensus
+            prev_bootstrapped_rids = self.recv_pyobj_from_prev_stage()
+            prev_good_bootstrapped_rids, prev_bad_bootstrapped_rids = (
+                prev_bootstrapped_rids
+            )
+            curr_good_bootstrapped_rids, curr_bad_bootstrapped_rids = self.get_rids(
+                self.disagg_prefill_bootstrap_queue.queue,
+                [KVPoll.WaitingForInput],
+                [KVPoll.Failed],
+            )
+            good_bootstrapped_rids = list(
+                set(prev_good_bootstrapped_rids) & set(curr_good_bootstrapped_rids)
+            )
+            bad_bootstrapped_rids = list(
+                set(prev_bad_bootstrapped_rids) | set(curr_bad_bootstrapped_rids)
+            )
+        return [good_bootstrapped_rids, bad_bootstrapped_rids]
+
+    def _pp_pd_get_transferred_ids(self: Scheduler):
+        # get the current stage transfer success
+        if self.pp_group.is_first_rank:
+            transferred_rids = self.get_rids(
+                self.disagg_prefill_inflight_queue,
+                [KVPoll.Success, KVPoll.Failed],
+            )
+        # if other ranks, do intersection with the previous rank's transferred rids
+        else:
+            # 2 (Release): Receive the transferred rids from the previous rank
+            # 1. recv previous stage's transferred reqs info
+            prev_transferred_rids = self.recv_pyobj_from_prev_stage()
+            # 2. get the current stage's transferred reqs info
+            curr_transferred_rids = self.get_rids(
+                self.disagg_prefill_inflight_queue,
+                [KVPoll.Success, KVPoll.Failed],
+            )
+            # 3.
new concensus rids = intersection(previous concensus rids, transfer finished rids) + transferred_rids = list( + set(prev_transferred_rids) & set(curr_transferred_rids) + ) + return transferred_rids + + def _pp_pd_send_consensus_bootstrapped_ids( + self: Scheduler, + bmbs: List[List[str]], + next_first_rank_mb_id: int, + consensus_bootstrapped_rids: List[str], + bootstrapped_rids: List[str], + ): + # 3 (Release): send the release rids from last stage to the first stage + send_consensus_bootstrapped_work = [] + if self.pp_group.is_last_rank: + if bmbs[next_first_rank_mb_id] is not None: + consensus_bootstrapped_rids = bootstrapped_rids + send_consensus_bootstrapped_work = self.send_pyobj_to_next_stage( + consensus_bootstrapped_rids, async_send=True + ) + # 4 (Release): send the release rids from non last rank to the next rank + else: + if consensus_bootstrapped_rids is not None: + send_consensus_bootstrapped_work = self.send_pyobj_to_next_stage( + consensus_bootstrapped_rids, async_send=True + ) + return send_consensus_bootstrapped_work, consensus_bootstrapped_rids + + def _pp_pd_send_consensus_release_ids( + self: Scheduler, + tmbs: List[List[str]], + next_first_rank_mb_id: int, + release_rids: List[str], + transferred_rids: List[str], + ): + send_release_work = [] + if self.pp_group.is_last_rank: + if tmbs[next_first_rank_mb_id] is not None: + release_rids = transferred_rids + send_release_work = self.send_pyobj_to_next_stage( + release_rids, async_send=True + ) + # 4 (Release): send the release rids from non last rank to the next rank + else: + if release_rids is not None: + send_release_work = self.send_pyobj_to_next_stage( + release_rids, async_send=True + ) + return send_release_work, release_rids + + @DynamicGradMode() + def event_loop_pp_disagg_prefill(self: Scheduler): + """ + This is the prefill server event loop for pipeline parallelism. + + Notes: + 1. Following the same rules as the event_loop_pp. + 2. Adds extra steps for KV transfer process: bootstrap + release. + + Prefill Server Schedule: + ==================================================================== + Stage P + recv ith req from previous stage + recv ith bootstrap req from previous stage + recv ith transferred req from previous stage + recv ith proxy from previous stage + run ith batch + recv prev (i+1) % mb_size th consensus bootstrapped req from previous stage + local consensus on bootstrapped req + recv prev (i+1) % mb_size th release req from previous stage + local consensus on release req + recv prev (i+1) % mb_size th outputs + process batch result of prev (i+1)% mb_size th batch (can be run in parallel with the curr batch GPU computation) + send ith req to next stage + send ith bootstrap req to next stage + send ith transferred req to next stage + send ith proxy to next stage + send current stage's outputs to next stage (can be stashed and delayed to send later) + + the above order can be optimized and reordered to minimize communication-related CPU stall and overhead bubbles. + ==================================================================== + + There are two additional elements compared to the regular schedule: + + Bootstrap Requests + Release Requests: + - Both can have local failure and need to be consensus on. PP needs to gurantee eventual consistency of local failure and flush malfunc requests out as soft error. 
+ + """ + self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + mbs = [None] * self.pp_loop_size + last_mbs = [None] * self.pp_loop_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) + for _ in range(self.pp_loop_size) + ] + mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size + pp_outputs: Optional[PPProxyTensors] = None + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() + + # PD additionals + consensus_bootstrapped_rids: Optional[List[str]] = None + transferred_rids: List[str] = [] + release_rids: Optional[List[str]] = None + tmbs = [None] * self.pp_loop_size + bmbs = [None] * self.pp_loop_size + + send_req_work = [] + send_bootstrapped_work = [] + send_consensus_bootstrapped_work = [] + send_proxy_work = [] + send_release_work = [] + send_transfer_work = [] + + while True: + server_is_idle = True + for mb_id in range(self.pp_loop_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size + next_mb_id = (mb_id + 1) % self.pp_loop_size + + next_pp_outputs = None + next_release_rids = None + next_consensus_bootstrapped_rids = None + d2h_event = None + next_batch_result = None + + recv_reqs = self.recv_requests() + self._pp_commit_comm_work(send_req_work) + self.process_input_requests(recv_reqs) + + bootstrapped_rids = self._pp_pd_get_bootstrapped_ids() + bmbs[mb_id] = bootstrapped_rids + self._pp_commit_comm_work(send_bootstrapped_work) + + transferred_rids = self._pp_pd_get_transferred_ids() + self._pp_commit_comm_work(send_transfer_work) + tmbs[mb_id] = transferred_rids + + self.process_prefill_chunk_pp() + mbs[mb_id] = self.get_new_batch_prefill() + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + pp_proxy_tensors = self._pp_recv_proxy_tensors() + self._pp_commit_comm_work(send_proxy_work) + if self.server_args.pp_async_batch_depth > 0: + next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + if self.cur_batch: + result, event = self._pp_launch_batch( + mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue + ) + if self.server_args.pp_async_batch_depth == 0: + next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, + ) + ) + send_consensus_bootstrapped_work, consensus_bootstrapped_rids = ( + self._pp_pd_send_consensus_bootstrapped_ids( + bmbs, + next_first_rank_mb_id, + consensus_bootstrapped_rids, + bootstrapped_rids, + ) + ) + send_release_work, release_rids = ( + self._pp_pd_send_consensus_release_ids( + tmbs, next_first_rank_mb_id, release_rids, transferred_rids + ) + ) + + if bmbs[next_mb_id] is not None: + next_consensus_bootstrapped_rids = self.recv_pyobj_from_prev_stage() + next_consensus_bootstrapped_rids = self.process_bootstrapped_queue( + next_consensus_bootstrapped_rids + ) + self._pp_commit_comm_work(send_consensus_bootstrapped_work) + if tmbs[next_mb_id] is not None: + next_release_rids = self.recv_pyobj_from_prev_stage() + self._pp_commit_comm_work(send_release_work) + # post-process the coming microbatch + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + 
self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) + last_mbs[next_mb_id] = mbs[next_mb_id] + + if tmbs[next_mb_id] is not None: + self.process_disagg_prefill_inflight_queue(next_release_rids) + if not self.pp_group.is_last_rank: + send_req_work = self.send_pyobj_to_next_stage( + recv_reqs, async_send=True + ) + send_bootstrapped_work = self.send_pyobj_to_next_stage( + bootstrapped_rids, async_send=True + ) + send_transfer_work = self.send_pyobj_to_next_stage( + transferred_rids, async_send=True + ) + if self.cur_batch: + torch.cuda.current_stream().wait_event(event) + send_proxy_work = self._pp_send_dict_to_next_stage( + result.pp_hidden_states_proxy_tensors, + async_send=True, + ) + + if self.delayed_weight_sync_fn: + self.delayed_weight_sync_fn() + self.delayed_weight_sync_fn = None + + pp_outputs = next_pp_outputs + release_rids = next_release_rids + consensus_bootstrapped_rids = next_consensus_bootstrapped_rids + + self.running_batch.batch_is_full = False + + # When the server is idle, self-check and re-init some states + if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0: + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() From 437f100b87390d218f37217eb8e4e71d564166c5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 18 Oct 2025 20:16:47 -0700 Subject: [PATCH 02/28] update point_to_point_pyobj --- .../sglang/srt/managers/scheduler_pp_mixin.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index ec6247b422d5..2c8b2a80eb21 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -686,3 +686,71 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() + + + +# Keep this function for xai's PP implementation +def point_to_point_pyobj( + data: List[Any], + rank: int, + group: Optional[torch.distributed.ProcessGroup] = None, + src: int = 0, + dst: int = 1, + async_send: bool = False, +): + """Send data from src to dst in group.""" + if async_send: + send_func = dist.isend + else: + send_func = dist.send + if rank == src: + p2p_works = [] + if len(data) == 0: + tensor_size = torch.tensor( + [0], + dtype=torch.long, + ) + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) + else: + serialized_data = pickle.dumps(data) + size = len(serialized_data) + tensor_data = torch.ByteTensor( + np.frombuffer(serialized_data, dtype=np.uint8) + ) + tensor_size = torch.tensor([size], dtype=torch.long) + + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) + work = send_func(tensor_data, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_data)) + return p2p_works + + elif rank == dst: + tensor_size = torch.tensor( + [0], + dtype=torch.long, + ) + work = dist.irecv(tensor_size, src=src, group=group) + work.wait() + size = tensor_size.item() + + if size == 0: + return [] + + tensor_data = torch.empty( + size, + dtype=torch.uint8, + ) + work = dist.irecv(tensor_data, src=src, group=group) + work.wait() + + serialized_data = bytes(tensor_data.cpu().numpy()) + data = pickle.loads(serialized_data) + return data + + # Other ranks in pp_group do nothing + return [] From 
27b78d503bfb71ac21f716f8f748fd9c0f7f7d8d Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Mon, 20 Oct 2025 15:20:17 +0800 Subject: [PATCH 03/28] [PP] support async PP Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler.py | 126 +---------------- .../sglang/srt/managers/scheduler_pp_mixin.py | 132 +++++------------- python/sglang/srt/server_args.py | 7 + python/sglang/srt/utils/common.py | 48 ++++--- 4 files changed, 73 insertions(+), 240 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 78457abc8802..98b15c35d745 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -136,6 +136,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import ( SchedulerOutputProcessorMixin, ) +from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper from sglang.srt.managers.scheduler_update_weights_mixin import ( @@ -281,6 +282,7 @@ class Scheduler( SchedulerMetricsMixin, SchedulerDisaggregationDecodeMixin, SchedulerDisaggregationPrefillMixin, + SchedulerPPMixin, ): """A scheduler that manages a tensor parallel GPU worker.""" @@ -1055,128 +1057,6 @@ def event_loop_overlap(self): self.launch_batch_sample_if_needed(batch_result) self.last_batch = batch - @DynamicGradMode() - def event_loop_pp(self): - """A non-overlap scheduler loop for pipeline parallelism.""" - mbs = [None] * self.pp_size - last_mbs = [None] * self.pp_size - self.running_mbs = [ - ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) - ] - pp_outputs: Optional[PPProxyTensors] = None - while True: - server_is_idle = True - for mb_id in range(self.pp_size): - self.running_batch = self.running_mbs[mb_id] - self.last_batch = last_mbs[mb_id] - - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - mbs[mb_id] = self.get_next_batch_to_run() - self.running_mbs[mb_id] = self.running_batch - - self.cur_batch = mbs[mb_id] - if self.cur_batch: - server_is_idle = False - result = self.run_batch(self.cur_batch) - - # (last rank) send the outputs to the next step - if self.pp_group.is_last_rank: - if self.cur_batch: - next_token_ids = result.next_token_ids - if self.cur_batch.return_logprob: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - "extend_input_len_per_req": result.extend_input_len_per_req, - "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req, - } - | ( - { - f"logits_output.{k}": v - for k, v in result.logits_output.__dict__.items() - } - if result.logits_output is not None - else {} - ) - ) - else: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - } - ) - # send the output from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - # receive outputs and post-process (filter finished reqs) the coming microbatch - next_mb_id = (mb_id + 1) % self.pp_size - next_pp_outputs = None - if mbs[next_mb_id] is not None: - next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group - ) - ) - mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] - logits_output_args = { - k[len("logits_output.") :]: v - for k, v in next_pp_outputs.tensors.items() - if 
k.startswith("logits_output.") - } - if len(logits_output_args) > 0: - logits_output = LogitsProcessorOutput(**logits_output_args) - else: - logits_output = None - - output_result = GenerationBatchResult.from_pp_proxy( - logits_output=logits_output, - next_pp_outputs=next_pp_outputs, - can_run_cuda_graph=result.can_run_cuda_graph, - ) - self.process_batch_result(mbs[next_mb_id], output_result) - last_mbs[next_mb_id] = mbs[next_mb_id] - - # (not last rank) - if not self.pp_group.is_last_rank: - # carry the outputs to the next stage - # send the outputs from the last round to let the next stage worker run post processing - if pp_outputs: - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - # send out reqs to the next stage - dp_offset = self.attn_dp_rank * self.attn_tp_size - if self.attn_tp_rank == 0: - point_to_point_pyobj( - recv_reqs, - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - self.pp_rank * self.tp_size + dp_offset, - (self.pp_rank + 1) * self.tp_size + dp_offset, - ) - - # send out proxy tensors to the next stage - if self.cur_batch: - # FIXME(lsyin): remove this assert - assert result.pp_hidden_states_proxy_tensors.tensors is not None - self.pp_group.send_tensor_dict( - result.pp_hidden_states_proxy_tensors.tensors, - all_gather_group=self.attn_tp_group, - ) - - pp_outputs = next_pp_outputs - - # When the server is idle, self-check and re-init some states - if server_is_idle: - # When the server is idle, do self-check and re-init some states - self.self_check_during_idle() - def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" @@ -1212,7 +1092,7 @@ def recv_requests(self) -> List[Req]: recv_reqs = point_to_point_pyobj( [], self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, + self.world_group.cpu_group, (self.pp_rank - 1) * self.tp_size + dp_offset, self.pp_rank * self.tp_size + dp_offset, ) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 2c8b2a80eb21..0acabd419ff8 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -26,8 +26,11 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.distributed.parallel_state import P2PWork from sglang.srt.managers.schedule_batch import ScheduleBatch + +if TYPE_CHECKING: + from sglang.srt.managers.scheduler import GenerationBatchResult + from sglang.srt.managers.utils import ( - GenerationBatchResult, get_logprob_dict_from_result, get_logprob_from_pp_outputs, ) @@ -42,7 +45,6 @@ @dataclass class PPBatchMetadata: - bid: int can_run_cuda_graph: bool @@ -52,10 +54,10 @@ def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None: p2p_work.work.wait() work.clear() - def send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False): + def _pp_send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False): p2p_work = [] if self.attn_tp_rank == 0: - dp_offset = self.dp_rank * self.attn_tp_size + dp_offset = self.attn_dp_rank * self.attn_tp_size p2p_work = point_to_point_pyobj( data, self.pp_rank * self.tp_size + dp_offset, @@ -107,10 +109,9 @@ def _pp_send_dict_to_next_stage( async_send: bool = True, ): p2p_work = [] - p2p_work.extend( self.pp_group.send_tensor_dict( - tensor_dict, + tensor_dict=tensor_dict, all_gather_group=self.attn_tp_group, async_send=async_send, ) @@ -139,6 +140,8 @@ def 
_pp_prep_batch_result( mb_metadata: PPBatchMetadata, pp_outputs: PPProxyTensors, ): + from sglang.srt.managers.scheduler import GenerationBatchResult + logits_output = None extend_input_len_per_req = None extend_logprob_start_len_per_req = None @@ -156,7 +159,6 @@ def _pp_prep_batch_result( next_token_ids=pp_outputs["next_token_ids"], extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, - bid=mb_metadata.bid, can_run_cuda_graph=mb_metadata.can_run_cuda_graph, ) return output_result @@ -224,7 +226,7 @@ def _pp_send_recv_and_preprocess_output_tensors( batch_result = self._pp_prep_batch_result( mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs ) - d2h_event = self.process_batch_result_d2h(mbs[next_mb_id], batch_result) + # d2h_event = self.process_batch_result_d2h(mbs[next_mb_id], batch_result) return next_pp_outputs, batch_result, d2h_event @@ -236,9 +238,8 @@ def _pp_launch_batch( last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], ): with torch.profiler.record_function("run_batch"): - result = self.run_batch(self.cur_batch, pp_proxy_tensors) + result = self.run_batch(self.cur_batch) mb_metadata[mb_id] = PPBatchMetadata( - bid=result.bid, can_run_cuda_graph=result.can_run_cuda_graph, ) event = torch.cuda.Event() @@ -310,8 +311,8 @@ def event_loop_pp(self: Scheduler): self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False - pp_proxy_tensors = self._pp_recv_proxy_tensors() - self._pp_commit_comm_work(send_proxy_work) + # pp_proxy_tensors = self._pp_recv_proxy_tensors() + # self._pp_commit_comm_work(send_proxy_work) next_pp_outputs = None next_batch_result = None d2h_event = None @@ -328,6 +329,7 @@ def event_loop_pp(self: Scheduler): ) if self.cur_batch: + pp_proxy_tensors = None result, event = self._pp_launch_batch( mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) @@ -343,7 +345,7 @@ def event_loop_pp(self: Scheduler): ) ) if mbs[next_mb_id] is not None: - d2h_event.synchronize() + # d2h_event.synchronize() with torch.profiler.record_function("process_batch_result"): self._pp_process_batch_result( mbs[next_mb_id], @@ -352,7 +354,7 @@ def event_loop_pp(self: Scheduler): last_mbs[next_mb_id] = mbs[next_mb_id] if not self.pp_group.is_last_rank: with torch.profiler.record_function("send_reqs_to_next_stage"): - send_req_work = self.send_pyobj_to_next_stage( + send_req_work = self._pp_send_pyobj_to_next_stage( recv_reqs, async_send=True, ) @@ -362,13 +364,13 @@ def event_loop_pp(self: Scheduler): "send_proxy_dict_to_next_stage" ): send_proxy_work = self._pp_send_dict_to_next_stage( - result.pp_hidden_states_proxy_tensors, + result.pp_hidden_states_proxy_tensors.tensors, async_send=True, ) - if self.delayed_weight_sync_fn: - self.delayed_weight_sync_fn() - self.delayed_weight_sync_fn = None + # if self.delayed_weight_sync_fn: + # self.delayed_weight_sync_fn() + # self.delayed_weight_sync_fn = None pp_outputs = next_pp_outputs @@ -409,7 +411,7 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): [KVPoll.Failed], ) else: - # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the concensus + # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus prev_bootstrapped_rids = self.recv_pyobj_from_prev_stage() prev_good_bootstrapped_rids, prev_bad_bootstrapped_rids = ( prev_bootstrapped_rids @@ -444,7 +446,7 @@ def _pp_pd_get_transferred_ids(self: Scheduler): self.disagg_prefill_inflight_queue, 
[KVPoll.Success, KVPoll.Failed], ) - # 3. new concensus rids = intersection(previous concensus rids, transfer finished rids) + # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) transferred_rids = list( set(prev_transferred_rids) & set(curr_transferred_rids) ) @@ -462,13 +464,13 @@ def _pp_pd_send_consensus_bootstrapped_ids( if self.pp_group.is_last_rank: if bmbs[next_first_rank_mb_id] is not None: consensus_bootstrapped_rids = bootstrapped_rids - send_consensus_bootstrapped_work = self.send_pyobj_to_next_stage( + send_consensus_bootstrapped_work = self._pp_send_pyobj_to_next_stage( consensus_bootstrapped_rids, async_send=True ) # 4 (Release): send the release rids from non last rank to the next rank else: if consensus_bootstrapped_rids is not None: - send_consensus_bootstrapped_work = self.send_pyobj_to_next_stage( + send_consensus_bootstrapped_work = self._pp_send_pyobj_to_next_stage( consensus_bootstrapped_rids, async_send=True ) return send_consensus_bootstrapped_work, consensus_bootstrapped_rids @@ -484,13 +486,13 @@ def _pp_pd_send_consensus_release_ids( if self.pp_group.is_last_rank: if tmbs[next_first_rank_mb_id] is not None: release_rids = transferred_rids - send_release_work = self.send_pyobj_to_next_stage( + send_release_work = self._pp_send_pyobj_to_next_stage( release_rids, async_send=True ) # 4 (Release): send the release rids from non last rank to the next rank else: if release_rids is not None: - send_release_work = self.send_pyobj_to_next_stage( + send_release_work = self._pp_send_pyobj_to_next_stage( release_rids, async_send=True ) return send_release_work, release_rids @@ -530,7 +532,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): There are two additional elements compared to the regular schedule: Bootstrap Requests + Release Requests: - - Both can have local failure and need to be consensus on. PP needs to gurantee eventual consistency of local failure and flush malfunc requests out as soft error. + - Both can have local failure and need to be consensus on. PP needs to guarantee eventual consistency of local failure and flush malfunc requests out as soft error. 
""" self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth @@ -544,7 +546,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): pp_outputs: Optional[PPProxyTensors] = None last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() - # PD additionals + # PD additional consensus_bootstrapped_rids: Optional[List[str]] = None transferred_rids: List[str] = [] release_rids: Optional[List[str]] = None @@ -644,7 +646,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self._pp_commit_comm_work(send_release_work) # post-process the coming microbatch if mbs[next_mb_id] is not None: - d2h_event.synchronize() + # d2h_event.synchronize() self._pp_process_batch_result( mbs[next_mb_id], next_batch_result, @@ -654,19 +656,19 @@ def event_loop_pp_disagg_prefill(self: Scheduler): if tmbs[next_mb_id] is not None: self.process_disagg_prefill_inflight_queue(next_release_rids) if not self.pp_group.is_last_rank: - send_req_work = self.send_pyobj_to_next_stage( + send_req_work = self._pp_send_pyobj_to_next_stage( recv_reqs, async_send=True ) - send_bootstrapped_work = self.send_pyobj_to_next_stage( + send_bootstrapped_work = self._pp_send_pyobj_to_next_stage( bootstrapped_rids, async_send=True ) - send_transfer_work = self.send_pyobj_to_next_stage( + send_transfer_work = self._pp_send_pyobj_to_next_stage( transferred_rids, async_send=True ) if self.cur_batch: torch.cuda.current_stream().wait_event(event) send_proxy_work = self._pp_send_dict_to_next_stage( - result.pp_hidden_states_proxy_tensors, + result.pp_hidden_states_proxy_tensors.tensors, async_send=True, ) @@ -686,71 +688,3 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio self.maybe_sleep_on_idle() - - - -# Keep this function for xai's PP implementation -def point_to_point_pyobj( - data: List[Any], - rank: int, - group: Optional[torch.distributed.ProcessGroup] = None, - src: int = 0, - dst: int = 1, - async_send: bool = False, -): - """Send data from src to dst in group.""" - if async_send: - send_func = dist.isend - else: - send_func = dist.send - if rank == src: - p2p_works = [] - if len(data) == 0: - tensor_size = torch.tensor( - [0], - dtype=torch.long, - ) - work = send_func(tensor_size, dst, group=group) - if async_send: - p2p_works.append(P2PWork(work, tensor_size)) - else: - serialized_data = pickle.dumps(data) - size = len(serialized_data) - tensor_data = torch.ByteTensor( - np.frombuffer(serialized_data, dtype=np.uint8) - ) - tensor_size = torch.tensor([size], dtype=torch.long) - - work = send_func(tensor_size, dst, group=group) - if async_send: - p2p_works.append(P2PWork(work, tensor_size)) - work = send_func(tensor_data, dst, group=group) - if async_send: - p2p_works.append(P2PWork(work, tensor_data)) - return p2p_works - - elif rank == dst: - tensor_size = torch.tensor( - [0], - dtype=torch.long, - ) - work = dist.irecv(tensor_size, src=src, group=group) - work.wait() - size = tensor_size.item() - - if size == 0: - return [] - - tensor_data = torch.empty( - size, - dtype=torch.uint8, - ) - work = dist.irecv(tensor_data, src=src, group=group) - work.wait() - - serialized_data = bytes(tensor_data.cpu().numpy()) - data = pickle.loads(serialized_data) - return data - - # Other ranks in pp_group do nothing - return [] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 939e502fc38f..a12e9fb88507 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -236,6 
+236,7 @@ class ServerArgs: tp_size: int = 1 pp_size: int = 1 pp_max_micro_batch_size: Optional[int] = None + pp_async_batch_depth: int = 0 stream_interval: int = 1 stream_output: bool = False random_seed: Optional[int] = None @@ -1842,6 +1843,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.pp_max_micro_batch_size, help="The maximum micro batch size in pipeline parallelism.", ) + parser.add_argument( + "--pp-async-batch-depth", + type=int, + default=ServerArgs.pp_async_batch_depth, + help="The async batch depth of pipeline parallelism.", + ) parser.add_argument( "--stream-interval", type=int, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 2264168e2c16..5e6e6095cb89 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1219,49 +1219,61 @@ def point_to_point_pyobj( group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, dst: int = 1, + async_send: bool = False, ): - """Send data from src to dst in group using DeviceToDevice communication.""" + """Send data from src to dst in group.""" + from sglang.srt.distributed.parallel_state import P2PWork + if async_send: + send_func = dist.isend + else: + send_func = dist.send if rank == src: + p2p_works = [] if len(data) == 0: tensor_size = torch.tensor( - [0], dtype=torch.long, device=torch.cuda.current_device() + [0], + dtype=torch.long, ) - dist.send(tensor_size, dst=dst, group=group) + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) - ).cuda( - device=torch.cuda.current_device() - ) # Move to GPU - tensor_size = torch.tensor( - [size], dtype=torch.long, device=torch.cuda.current_device() ) + tensor_size = torch.tensor([size], dtype=torch.long) - dist.send(tensor_size, dst=dst, group=group) - dist.send(tensor_data, dst=dst, group=group) - return data + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) + work = send_func(tensor_data, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_data)) + return p2p_works elif rank == dst: tensor_size = torch.tensor( - [0], dtype=torch.long, device=torch.cuda.current_device() + [0], + dtype=torch.long, ) - dist.recv(tensor_size, src=src, group=group) + work = dist.irecv(tensor_size, src=src, group=group) + work.wait() size = tensor_size.item() if size == 0: return [] tensor_data = torch.empty( - size, dtype=torch.uint8, device=torch.cuda.current_device() + size, + dtype=torch.uint8, ) - dist.recv(tensor_data, src=src, group=group) + work = dist.irecv(tensor_data, src=src, group=group) + work.wait() - serialized_data = bytes( - tensor_data.cpu().numpy() - ) # Move back to host for deserialization + serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) return data From 269e64b3f6dd9e4f8d7743fbdc61d7172721574a Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Mon, 20 Oct 2025 15:50:03 +0800 Subject: [PATCH 04/28] update Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler_pp_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 0acabd419ff8..ca2d939978f1 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ 
b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -312,7 +312,7 @@ def event_loop_pp(self: Scheduler): if self.cur_batch: server_is_idle = False # pp_proxy_tensors = self._pp_recv_proxy_tensors() - # self._pp_commit_comm_work(send_proxy_work) + self._pp_commit_comm_work(send_proxy_work) next_pp_outputs = None next_batch_result = None d2h_event = None From cad6210e4081f00574ce2a015e23c3fe2b2a6d34 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Tue, 21 Oct 2025 11:58:42 +0800 Subject: [PATCH 05/28] upd --- python/sglang/srt/managers/scheduler_pp_mixin.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index ca2d939978f1..4a6f07f2041b 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -1,18 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Pipeline parallelism mixin for scheduler - contains PP event loop and utilities.""" - from __future__ import annotations import logging From 53335ccc21bd220156df142085242fcaec7f4cff Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Tue, 21 Oct 2025 12:09:11 +0800 Subject: [PATCH 06/28] fix --- python/sglang/srt/managers/scheduler_pp_mixin.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 4a6f07f2041b..8dd3544e26e8 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -11,11 +11,8 @@ from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.distributed.parallel_state import P2PWork from sglang.srt.managers.schedule_batch import ScheduleBatch - -if TYPE_CHECKING: - from sglang.srt.managers.scheduler import GenerationBatchResult - from sglang.srt.managers.utils import ( + GenerationBatchResult, get_logprob_dict_from_result, get_logprob_from_pp_outputs, ) From c4932007eb4d8cb1500029678fe3c77c9b467bea Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Tue, 21 Oct 2025 13:46:10 +0800 Subject: [PATCH 07/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler.py | 8 ++++++-- python/sglang/srt/managers/scheduler_pp_mixin.py | 5 ++--- python/sglang/srt/managers/tp_worker.py | 9 +-------- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7a781fcb29b5..5b09142fb7ad 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1986,7 +1986,9 @@ def update_cache_from_scheduler( pass def run_batch( - self, batch: ScheduleBatch + self, + batch: ScheduleBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: """Run a batch.""" 
self.forward_ct += 1 @@ -2025,6 +2027,7 @@ def run_batch( self.future_map.resolve_future(model_worker_batch) batch_result = self.model_worker.forward_batch_generation( model_worker_batch + # here pp is not compatible with overlap ) # FIXME(lsyin): maybe move this to forward_batch_generation batch_result.copy_done = torch.get_device_module( @@ -2058,7 +2061,8 @@ def run_batch( batch.seq_lens = batch_result.next_draft_input.new_seq_lens else: batch_result = self.model_worker.forward_batch_generation( - batch_or_worker_batch + batch_or_worker_batch, + pp_proxy_tensors=pp_proxy_tensors, ) future_indices_or_next_token_ids = batch_result.next_token_ids self.update_cache_from_scheduler(batch, batch_result) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 8dd3544e26e8..679903eaeb10 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -220,7 +220,7 @@ def _pp_launch_batch( last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], ): with torch.profiler.record_function("run_batch"): - result = self.run_batch(self.cur_batch) + result = self.run_batch(self.cur_batch, pp_proxy_tensors) mb_metadata[mb_id] = PPBatchMetadata( can_run_cuda_graph=result.can_run_cuda_graph, ) @@ -293,7 +293,7 @@ def event_loop_pp(self: Scheduler): self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False - # pp_proxy_tensors = self._pp_recv_proxy_tensors() + pp_proxy_tensors = self._pp_recv_proxy_tensors() self._pp_commit_comm_work(send_proxy_work) next_pp_outputs = None next_batch_result = None @@ -311,7 +311,6 @@ def event_loop_pp(self: Scheduler): ) if self.cur_batch: - pp_proxy_tensors = None result, event = self._pp_launch_batch( mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 0a623d4a23ac..b1c9529ce735 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -335,6 +335,7 @@ def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, forward_batch: Optional[ForwardBatch] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, is_verify: bool = False, skip_attn_backend_init=False, ) -> GenerationBatchResult: @@ -350,14 +351,6 @@ def forward_batch_generation( # FIXME(lsyin): unify the interface of forward_batch assert forward_batch is not None - pp_proxy_tensors = None - if not self.pp_group.is_first_rank: - pp_proxy_tensors = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.get_attention_tp_group() - ) - ) - if self.pp_group.is_last_rank: logits_output, can_run_cuda_graph = self.model_runner.forward( forward_batch, From 44656b8ded172ae677b0f9b0e66bc4a8a2e1942f Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Tue, 21 Oct 2025 14:03:15 +0800 Subject: [PATCH 08/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5b09142fb7ad..d9d28fd305fa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -148,6 +148,7 @@ from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache +from 
sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args from sglang.srt.speculative.spec_info import SpeculativeAlgorithm From 21df2f25bcfdbf160e4ee67383925dbe20a7627b Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Tue, 21 Oct 2025 14:15:45 +0800 Subject: [PATCH 09/28] async run and process Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler.py | 5 ++- .../sglang/srt/managers/scheduler_pp_mixin.py | 45 ++++++++++--------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d9d28fd305fa..4a0c51157f48 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -909,7 +909,7 @@ def init_disaggregation(self): self.disagg_prefill_inflight_queue: List[Req] = [] def init_overlap(self): - if not self.enable_overlap: + if not self.enable_overlap and self.pp_size == 1: return self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream() @@ -921,6 +921,9 @@ def init_overlap(self): self.device ).stream(self.copy_stream) + if not self.enable_overlap: + return + self.future_map = FutureMap( self.max_running_requests, self.device, self.spec_algorithm ) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 679903eaeb10..eb97747dd2a6 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -205,10 +205,12 @@ def _pp_send_recv_and_preprocess_output_tensors( next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) self._pp_commit_comm_work(work=send_output_work) if mbs[next_mb_id] is not None: - batch_result = self._pp_prep_batch_result( - mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs - ) - # d2h_event = self.process_batch_result_d2h(mbs[next_mb_id], batch_result) + with self.forward_stream_ctx: + batch_result = self._pp_prep_batch_result( + mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs + ) + d2h_event = torch.cuda.Event() + d2h_event.record(torch.cuda.current_stream()) return next_pp_outputs, batch_result, d2h_event @@ -216,26 +218,27 @@ def _pp_launch_batch( self: Scheduler, mb_id: int, pp_proxy_tensors: PPProxyTensors, - mb_metadata: PPBatchMetadata, + mb_metadata: List[Optional[PPBatchMetadata]], last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], ): with torch.profiler.record_function("run_batch"): - result = self.run_batch(self.cur_batch, pp_proxy_tensors) - mb_metadata[mb_id] = PPBatchMetadata( - can_run_cuda_graph=result.can_run_cuda_graph, - ) - event = torch.cuda.Event() - event.record(torch.cuda.current_stream()) - if self.pp_group.is_last_rank: - # (last rank) buffer the outputs for async batch depth - last_rank_comm_queue.append( - ( - event, - PPProxyTensors( - self._pp_prepare_tensor_dict(result, self.cur_batch) - ), - ) + with self.forward_stream_ctx: + result = self.run_batch(self.cur_batch, pp_proxy_tensors) + mb_metadata[mb_id] = PPBatchMetadata( + can_run_cuda_graph=result.can_run_cuda_graph, ) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + if self.pp_group.is_last_rank: + # (last rank) buffer the outputs for async batch depth + last_rank_comm_queue.append( + ( + event, + PPProxyTensors( + self._pp_prepare_tensor_dict(result, self.cur_batch) + ), + ) + 
) return result, event @DynamicGradMode() @@ -326,7 +329,7 @@ def event_loop_pp(self: Scheduler): ) ) if mbs[next_mb_id] is not None: - # d2h_event.synchronize() + d2h_event.synchronize() with torch.profiler.record_function("process_batch_result"): self._pp_process_batch_result( mbs[next_mb_id], From 8fec3161659a6862ced79e4e7d2777019240b2e8 Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Wed, 22 Oct 2025 11:15:42 +0800 Subject: [PATCH 10/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler_pp_mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index eb97747dd2a6..691bab1edba2 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -205,7 +205,8 @@ def _pp_send_recv_and_preprocess_output_tensors( next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) self._pp_commit_comm_work(work=send_output_work) if mbs[next_mb_id] is not None: - with self.forward_stream_ctx: + with self.copy_stream_ctx: + self.copy_stream.wait_stream(self.default_stream) batch_result = self._pp_prep_batch_result( mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs ) @@ -223,6 +224,7 @@ def _pp_launch_batch( ): with torch.profiler.record_function("run_batch"): with self.forward_stream_ctx: + self.forward_stream.wait_stream(self.default_stream) result = self.run_batch(self.cur_batch, pp_proxy_tensors) mb_metadata[mb_id] = PPBatchMetadata( can_run_cuda_graph=result.can_run_cuda_graph, From 382ce86024d2ff8cf7ce28c00242b01288d943de Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Wed, 22 Oct 2025 13:22:00 +0800 Subject: [PATCH 11/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4a0c51157f48..1990d725a88f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2064,9 +2064,13 @@ def run_batch( # Current implementation strictly synchronizes the seq_lens batch.seq_lens = batch_result.next_draft_input.new_seq_lens else: + kwargs = ( + {"pp_proxy_tensors": pp_proxy_tensors} + if self.spec_algorithm.is_none() + else {} + ) batch_result = self.model_worker.forward_batch_generation( - batch_or_worker_batch, - pp_proxy_tensors=pp_proxy_tensors, + batch_or_worker_batch, **kwargs ) future_indices_or_next_token_ids = batch_result.next_token_ids self.update_cache_from_scheduler(batch, batch_result) From 770b28fd7010b3952377a9a8ca2abbe635480709 Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Wed, 22 Oct 2025 17:36:44 +0800 Subject: [PATCH 12/28] update Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler_pp_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 691bab1edba2..97b567150cf6 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -632,7 +632,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): self._pp_commit_comm_work(send_release_work) # post-process the coming microbatch if mbs[next_mb_id] is not None: - # d2h_event.synchronize() + d2h_event.synchronize() self._pp_process_batch_result( mbs[next_mb_id], next_batch_result, From 
49a81dc664c695e4202d94c1bbbb19ab9517f79c Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Wed, 22 Oct 2025 19:21:03 -0700 Subject: [PATCH 13/28] tiny improvement by delaying commit sync --- .../sglang/srt/managers/scheduler_pp_mixin.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 97b567150cf6..abac3f2c9e18 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -203,7 +203,7 @@ def _pp_send_recv_and_preprocess_output_tensors( if mbs[next_mb_id] is not None: with torch.profiler.record_function("recv_res_dict_from_prev_stage"): next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) - self._pp_commit_comm_work(work=send_output_work) + # self._pp_commit_comm_work(work=send_output_work) if mbs[next_mb_id] is not None: with self.copy_stream_ctx: self.copy_stream.wait_stream(self.default_stream) @@ -213,7 +213,7 @@ def _pp_send_recv_and_preprocess_output_tensors( d2h_event = torch.cuda.Event() d2h_event.record(torch.cuda.current_stream()) - return next_pp_outputs, batch_result, d2h_event + return next_pp_outputs, batch_result, d2h_event, send_output_work def _pp_launch_batch( self: Scheduler, @@ -280,6 +280,7 @@ def event_loop_pp(self: Scheduler): last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() send_req_work = [] send_proxy_work = [] + send_output_work = [] event = None while True: server_is_idle = True @@ -290,7 +291,7 @@ def event_loop_pp(self: Scheduler): next_mb_id = (mb_id + 1) % self.pp_loop_size with torch.profiler.record_function("recv_requests"): recv_reqs = self.recv_requests() - self._pp_commit_comm_work(send_req_work) + # self._pp_commit_comm_work(send_req_work) self.process_input_requests(recv_reqs) with torch.profiler.record_function("get_next_batch_to_run"): mbs[mb_id] = self.get_next_batch_to_run() @@ -299,12 +300,13 @@ def event_loop_pp(self: Scheduler): if self.cur_batch: server_is_idle = False pp_proxy_tensors = self._pp_recv_proxy_tensors() - self._pp_commit_comm_work(send_proxy_work) + # self._pp_commit_comm_work(send_proxy_work) next_pp_outputs = None next_batch_result = None d2h_event = None if self.server_args.pp_async_batch_depth > 0: - next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( self._pp_send_recv_and_preprocess_output_tensors( next_first_rank_mb_id, next_mb_id, @@ -320,7 +322,8 @@ def event_loop_pp(self: Scheduler): mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) if self.server_args.pp_async_batch_depth == 0: - next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( self._pp_send_recv_and_preprocess_output_tensors( next_first_rank_mb_id, next_mb_id, @@ -339,6 +342,7 @@ def event_loop_pp(self: Scheduler): ) last_mbs[next_mb_id] = mbs[next_mb_id] if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) with torch.profiler.record_function("send_reqs_to_next_stage"): send_req_work = self._pp_send_pyobj_to_next_stage( recv_reqs, @@ -346,6 +350,7 @@ def event_loop_pp(self: Scheduler): ) if self.cur_batch: torch.cuda.current_stream().wait_event(event) + self._pp_commit_comm_work(send_proxy_work) with torch.profiler.record_function( 
"send_proxy_dict_to_next_stage" ): From ac347c522a63d5da1901c9f3ed0519c6bb8d6833 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 15:27:23 -0700 Subject: [PATCH 14/28] send req asap --- .../sglang/srt/managers/scheduler_pp_mixin.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index abac3f2c9e18..65b9ef5668b0 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -293,6 +293,13 @@ def event_loop_pp(self: Scheduler): recv_reqs = self.recv_requests() # self._pp_commit_comm_work(send_req_work) self.process_input_requests(recv_reqs) + if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) + with torch.profiler.record_function("send_reqs_to_next_stage"): + send_req_work = self._pp_send_pyobj_to_next_stage( + recv_reqs, + async_send=True, + ) with torch.profiler.record_function("get_next_batch_to_run"): mbs[mb_id] = self.get_next_batch_to_run() self.running_mbs[mb_id] = self.running_batch @@ -342,12 +349,12 @@ def event_loop_pp(self: Scheduler): ) last_mbs[next_mb_id] = mbs[next_mb_id] if not self.pp_group.is_last_rank: - self._pp_commit_comm_work(send_req_work) - with torch.profiler.record_function("send_reqs_to_next_stage"): - send_req_work = self._pp_send_pyobj_to_next_stage( - recv_reqs, - async_send=True, - ) + # self._pp_commit_comm_work(send_req_work) + # with torch.profiler.record_function("send_reqs_to_next_stage"): + # send_req_work = self._pp_send_pyobj_to_next_stage( + # recv_reqs, + # async_send=True, + # ) if self.cur_batch: torch.cuda.current_stream().wait_event(event) self._pp_commit_comm_work(send_proxy_work) From 2efd976958bd255d95fc83f3a292c07b810c2055 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 15:35:14 -0700 Subject: [PATCH 15/28] cleanup PR --- python/sglang/srt/managers/scheduler_pp_mixin.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 65b9ef5668b0..758e2d79ec13 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -203,7 +203,6 @@ def _pp_send_recv_and_preprocess_output_tensors( if mbs[next_mb_id] is not None: with torch.profiler.record_function("recv_res_dict_from_prev_stage"): next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) - # self._pp_commit_comm_work(work=send_output_work) if mbs[next_mb_id] is not None: with self.copy_stream_ctx: self.copy_stream.wait_stream(self.default_stream) @@ -291,7 +290,6 @@ def event_loop_pp(self: Scheduler): next_mb_id = (mb_id + 1) % self.pp_loop_size with torch.profiler.record_function("recv_requests"): recv_reqs = self.recv_requests() - # self._pp_commit_comm_work(send_req_work) self.process_input_requests(recv_reqs) if not self.pp_group.is_last_rank: self._pp_commit_comm_work(send_req_work) @@ -307,7 +305,6 @@ def event_loop_pp(self: Scheduler): if self.cur_batch: server_is_idle = False pp_proxy_tensors = self._pp_recv_proxy_tensors() - # self._pp_commit_comm_work(send_proxy_work) next_pp_outputs = None next_batch_result = None d2h_event = None @@ -349,12 +346,6 @@ def event_loop_pp(self: Scheduler): ) last_mbs[next_mb_id] = mbs[next_mb_id] if not self.pp_group.is_last_rank: - # self._pp_commit_comm_work(send_req_work) - # with 
torch.profiler.record_function("send_reqs_to_next_stage"): - # send_req_work = self._pp_send_pyobj_to_next_stage( - # recv_reqs, - # async_send=True, - # ) if self.cur_batch: torch.cuda.current_stream().wait_event(event) self._pp_commit_comm_work(send_proxy_work) From 7dbc256d0a5fcb58f2ce43446f07d7c70291d73d Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 17:41:44 -0700 Subject: [PATCH 16/28] add recv tensor to cuda stream --- python/sglang/srt/managers/scheduler_pp_mixin.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 758e2d79ec13..783966a72356 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -200,11 +200,12 @@ def _pp_send_recv_and_preprocess_output_tensors( pp_outputs, ) - if mbs[next_mb_id] is not None: - with torch.profiler.record_function("recv_res_dict_from_prev_stage"): - next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) if mbs[next_mb_id] is not None: with self.copy_stream_ctx: + with torch.profiler.record_function("recv_res_dict_from_prev_stage"): + next_pp_outputs = PPProxyTensors( + self._pp_recv_dict_from_prev_stage() + ) self.copy_stream.wait_stream(self.default_stream) batch_result = self._pp_prep_batch_result( mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs From 8b03c9473f975bc00e46101415e30022c4f3296f Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 18:47:48 -0700 Subject: [PATCH 17/28] move sync step forward --- python/sglang/srt/managers/scheduler_pp_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 783966a72356..1feaecdf7df7 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -202,11 +202,11 @@ def _pp_send_recv_and_preprocess_output_tensors( if mbs[next_mb_id] is not None: with self.copy_stream_ctx: + self.copy_stream.wait_stream(self.default_stream) with torch.profiler.record_function("recv_res_dict_from_prev_stage"): next_pp_outputs = PPProxyTensors( self._pp_recv_dict_from_prev_stage() ) - self.copy_stream.wait_stream(self.default_stream) batch_result = self._pp_prep_batch_result( mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs ) From b2e95f02e56f42e51ec72808d24c5c5fc4cb6ead Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 18:49:38 -0700 Subject: [PATCH 18/28] move out of copy stream --- python/sglang/srt/managers/scheduler_pp_mixin.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 1feaecdf7df7..87979d3b967e 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -201,12 +201,10 @@ def _pp_send_recv_and_preprocess_output_tensors( ) if mbs[next_mb_id] is not None: + with torch.profiler.record_function("recv_res_dict_from_prev_stage"): + next_pp_outputs = PPProxyTensors(self._pp_recv_dict_from_prev_stage()) with self.copy_stream_ctx: self.copy_stream.wait_stream(self.default_stream) - with torch.profiler.record_function("recv_res_dict_from_prev_stage"): - next_pp_outputs = PPProxyTensors( - self._pp_recv_dict_from_prev_stage() - ) batch_result = self._pp_prep_batch_result( mbs[next_mb_id], 
mb_metadata[next_mb_id], next_pp_outputs ) From 49ba293164ff5b4ac362654a54fa946cdbe332ad Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 23 Oct 2025 20:41:50 -0700 Subject: [PATCH 19/28] commit send proxy work before launch batch --- python/sglang/srt/managers/scheduler_pp_mixin.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 87979d3b967e..5b972575d896 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -319,7 +319,7 @@ def event_loop_pp(self: Scheduler): pp_outputs, ) ) - + self._pp_commit_comm_work(send_proxy_work) if self.cur_batch: result, event = self._pp_launch_batch( mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue @@ -347,7 +347,6 @@ def event_loop_pp(self: Scheduler): if not self.pp_group.is_last_rank: if self.cur_batch: torch.cuda.current_stream().wait_event(event) - self._pp_commit_comm_work(send_proxy_work) with torch.profiler.record_function( "send_proxy_dict_to_next_stage" ): From 60ce68484dbee0e8962fc1528206e143e24cc29f Mon Sep 17 00:00:00 2001 From: zhangxiaolei123456 Date: Thu, 13 Nov 2025 15:01:02 +0800 Subject: [PATCH 20/28] Update scheduler_pp_mixin.py --- .../sglang/srt/managers/scheduler_pp_mixin.py | 72 ++++++++----------- 1 file changed, 30 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index e70536e01995..c53b9f8227cf 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -373,15 +373,10 @@ def process_bootstrapped_queue( ): # finished consensus bootstrapped reqs and prepare the waiting queue if bootstrapped_rids is not None: - ( - good_consensus_bootstrapped_rids, - bad_consensus_bootstrapped_rids, - ) = bootstrapped_rids good_reqs, failed_reqs = ( self.disagg_prefill_bootstrap_queue.pop_bootstrapped( return_failed_reqs=True, - rids_to_check=good_consensus_bootstrapped_rids, - bad_rids_to_check=bad_consensus_bootstrapped_rids, + rids_to_check=bootstrapped_rids, ) ) self.waiting_queue.extend(good_reqs) @@ -392,47 +387,36 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): # communicate pre-consensus bootstrapp reqs if self.pp_group.is_first_rank: # First rank, pop the bootstrap reqs from the bootstrap queue - good_bootstrapped_rids, bad_bootstrapped_rids = self.get_rids( - self.disagg_prefill_bootstrap_queue.queue, - [KVPoll.WaitingForInput], - [KVPoll.Failed], + bootstrapped_reqs, failed_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + return_failed_reqs=True + ) ) + bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [ + req.rid for req in failed_reqs + ] else: # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus - prev_bootstrapped_rids = self.recv_pyobj_from_prev_stage() - prev_good_bootstrapped_rids, prev_bad_bootstrapped_rids = ( - prev_bootstrapped_rids - ) - curr_good_bootstrapped_rids, curr_bad_bootstrapped_rids = self.get_rids( - self.disagg_prefill_bootstrap_queue.queue, - [KVPoll.WaitingForInput], - [KVPoll.Failed], - ) - good_bootstrapped_rids = list( - set(prev_good_bootstrapped_rids) & set(curr_good_bootstrapped_rids) - ) - bad_bootstrapped_rids = list( - set(prev_bad_bootstrapped_rids) | set(curr_bad_bootstrapped_rids) + bootstrapped_rids = self.recv_pyobj_from_prev_stage() + bootstrapped_reqs = ( + 
self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + rids_to_check=bootstrapped_rids + ) ) - return [good_bootstrapped_rids, bad_bootstrapped_rids] + self.waiting_queue.extend(bootstrapped_reqs) + return bootstrapped_rids def _pp_pd_get_transferred_ids(self: Scheduler): # get the current stage transfer success if self.pp_group.is_first_rank: - transferred_rids = self.get_rids( - self.disagg_prefill_inflight_queue, - [KVPoll.Success, KVPoll.Failed], - ) + transferred_rids = self.get_transferred_rids() # if other ranks, do intersection with the previous rank's transferred rids else: # 2 (Release): Receive the transferred rids from the previous rank # 1. recv previous stage's transferred reqs info prev_transferred_rids = self.recv_pyobj_from_prev_stage() # 2. get the current stage's transferred reqs info - curr_transferred_rids = self.get_rids( - self.disagg_prefill_inflight_queue, - [KVPoll.Success, KVPoll.Failed], - ) + curr_transferred_rids = self.get_transferred_rids() # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) transferred_rids = list( set(prev_transferred_rids) & set(curr_transferred_rids) @@ -544,6 +528,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): send_bootstrapped_work = [] send_consensus_bootstrapped_work = [] send_proxy_work = [] + send_output_work = [] send_release_work = [] send_transfer_work = [] @@ -562,9 +547,11 @@ def event_loop_pp_disagg_prefill(self: Scheduler): next_batch_result = None recv_reqs = self.recv_requests() - self._pp_commit_comm_work(send_req_work) self.process_input_requests(recv_reqs) + if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) + bootstrapped_rids = self._pp_pd_get_bootstrapped_ids() bmbs[mb_id] = bootstrapped_rids self._pp_commit_comm_work(send_bootstrapped_work) @@ -574,21 +561,20 @@ def event_loop_pp_disagg_prefill(self: Scheduler): tmbs[mb_id] = transferred_rids self.process_prefill_chunk() - batch = self.get_new_batch_prefill() if require_mlp_sync(self.server_args): batch = self.prepare_mlp_sync_batch(batch) mbs[mb_id] = batch - self.running_mbs[mb_id] = self.running_batch self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False pp_proxy_tensors = self._pp_recv_proxy_tensors() - self._pp_commit_comm_work(send_proxy_work) + if self.server_args.pp_async_batch_depth > 0: - next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( self._pp_send_recv_and_preprocess_output_tensors( next_first_rank_mb_id, next_mb_id, @@ -598,12 +584,14 @@ def event_loop_pp_disagg_prefill(self: Scheduler): pp_outputs, ) ) + self._pp_commit_comm_work(send_proxy_work) if self.cur_batch: result, event = self._pp_launch_batch( mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) if self.server_args.pp_async_batch_depth == 0: - next_pp_outputs, next_batch_result, d2h_event = ( + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( self._pp_send_recv_and_preprocess_output_tensors( next_first_rank_mb_id, next_mb_id, @@ -664,9 +652,9 @@ def event_loop_pp_disagg_prefill(self: Scheduler): async_send=True, ) - if self.delayed_weight_sync_fn: - self.delayed_weight_sync_fn() - self.delayed_weight_sync_fn = None + #if self.delayed_weight_sync_fn: + #self.delayed_weight_sync_fn() + #self.delayed_weight_sync_fn = None pp_outputs = next_pp_outputs release_rids = next_release_rids 
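Aside on the bootstrap consensus that the surrounding revisions iterate on: successfully bootstrapped request ids are intersected stage by stage (a request only enters the waiting queue once every pipeline stage has it ready), while failed ids are unioned (a failure on any stage fails the request everywhere). A minimal standalone sketch of that merge, with illustrative names and no SGLang dependencies:

    # Standalone sketch of the per-stage rid consensus used in the PP prefill loop.
    # `prev_*` come from the previous pipeline stage, `curr_*` from the local
    # bootstrap queue; names and types are illustrative, not the SGLang API.
    from typing import List, Tuple


    def merge_bootstrap_consensus(
        prev_good: List[str],
        prev_bad: List[str],
        curr_good: List[str],
        curr_bad: List[str],
    ) -> Tuple[List[str], List[str]]:
        # A request is usable only if *every* stage so far has finished bootstrapping it.
        good = list(set(prev_good) & set(curr_good))
        # A failure on any stage marks the request as failed everywhere.
        bad = list(set(prev_bad) | set(curr_bad))
        return good, bad


    # Example: the current stage has not finished "r2" yet, and an earlier stage
    # already reported "r3" as failed.
    good, bad = merge_bootstrap_consensus(
        prev_good=["r1", "r2"], prev_bad=["r3"],
        curr_good=["r1"], curr_bad=[],
    )
    assert sorted(good) == ["r1"] and bad == ["r3"]

Intersection for success and union for failure keeps the per-stage queues consistent without an extra collective; each stage only forwards its merged lists to the next one.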
From d3b26d028cde1f69a7a1188fee38cb58920ec65d Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 14 Nov 2025 16:29:04 +0800 Subject: [PATCH 21/28] fix Co-authored-by: bluecoffee8 Signed-off-by: Shangming Cai --- python/sglang/srt/disaggregation/prefill.py | 34 -------- .../sglang/srt/managers/scheduler_pp_mixin.py | 87 +++++++++++++------ 2 files changed, 60 insertions(+), 61 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 69eb22f30124..e11411b09f6d 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -55,7 +55,6 @@ SWAKVPool, ) from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end -from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -699,36 +698,3 @@ def send_kv_chunk( ) return req.disagg_kv_sender.send(page_indices, state_indices) - - def send_pyobj_to_next_stage(self, data): - if self.attn_tp_rank == 0: - dp_offset = self.attn_dp_rank * self.attn_tp_size - point_to_point_pyobj( - data, - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - self.pp_rank * self.tp_size + dp_offset, - ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset, - ) - - def recv_pyobj_from_prev_stage(self): - if self.attn_tp_rank == 0: - dp_offset = self.attn_dp_rank * self.attn_tp_size - data = point_to_point_pyobj( - [], - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset, - self.pp_rank * self.tp_size + dp_offset, - ) - else: - data = None - - if self.attn_tp_size != 1: - data = broadcast_pyobj( - data, - self.attn_tp_group.rank, - self.attn_tp_cpu_group, - src=self.attn_tp_group.ranks[0], - ) - return data diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 03df5cd16c0e..7c9a13b4929c 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -8,9 +8,9 @@ import torch from sglang.srt.disaggregation.base.conn import KVPoll -from sglang.srt.disaggregation.utils import DisaggregationMode +from sglang.srt.disaggregation.utils import DisaggregationMode, poll_and_all_reduce from sglang.srt.distributed.parallel_state import P2PWork -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.utils import ( GenerationBatchResult, get_logprob_dict_from_result, @@ -50,9 +50,9 @@ def _pp_send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False ) return p2p_work - def recv_pyobj_from_prev_stage(self: Scheduler): + def _pp_recv_pyobj_from_prev_stage(self: Scheduler): if self.attn_tp_rank == 0: - dp_offset = self.dp_rank * self.attn_tp_size + dp_offset = self.attn_dp_rank * self.attn_tp_size data = point_to_point_pyobj( [], self.pp_rank * self.tp_size + dp_offset, @@ -241,6 +241,21 @@ def _pp_launch_batch( ) return result, event + def get_rids(self: Scheduler, req_queue: List[Req], *poll_statuses): + """ + Used by PP, get the transferred rids but **do not pop** + """ + polls = poll_and_all_reduce( + [req.disagg_kv_sender for req in req_queue], + self.tp_worker.get_attention_tp_cpu_group(), + ) + rids: List = [] + for poll_status in poll_statuses: + rids.append( + [req.rid for req, poll in zip(req_queue, polls) if poll in 
poll_status] + ) + return tuple(rids) if len(rids) > 1 else rids[0] + @DynamicGradMode() def event_loop_pp(self: Scheduler): """ @@ -373,10 +388,15 @@ def process_bootstrapped_queue( ): # finished consensus bootstrapped reqs and prepare the waiting queue if bootstrapped_rids is not None: + ( + good_consensus_bootstrapped_rids, + bad_consensus_bootstrapped_rids, + ) = bootstrapped_rids good_reqs, failed_reqs = ( self.disagg_prefill_bootstrap_queue.pop_bootstrapped( return_failed_reqs=True, - rids_to_check=bootstrapped_rids, + rids_to_check=good_consensus_bootstrapped_rids + + bad_consensus_bootstrapped_rids, ) ) self.waiting_queue.extend(good_reqs) @@ -387,36 +407,47 @@ def _pp_pd_get_bootstrapped_ids(self: Scheduler): # communicate pre-consensus bootstrapp reqs if self.pp_group.is_first_rank: # First rank, pop the bootstrap reqs from the bootstrap queue - bootstrapped_reqs, failed_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - return_failed_reqs=True - ) + good_bootstrapped_rids, bad_bootstrapped_rids = self.get_rids( + self.disagg_prefill_bootstrap_queue.queue, + [KVPoll.WaitingForInput], + [KVPoll.Failed], ) - bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [ - req.rid for req in failed_reqs - ] else: # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus - bootstrapped_rids = self.recv_pyobj_from_prev_stage() - bootstrapped_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - rids_to_check=bootstrapped_rids - ) + prev_bootstrapped_rids = self._pp_recv_pyobj_from_prev_stage() + prev_good_bootstrapped_rids, prev_bad_bootstrapped_rids = ( + prev_bootstrapped_rids + ) + curr_good_bootstrapped_rids, curr_bad_bootstrapped_rids = self.get_rids( + self.disagg_prefill_bootstrap_queue.queue, + [KVPoll.WaitingForInput], + [KVPoll.Failed], + ) + good_bootstrapped_rids = list( + set(prev_good_bootstrapped_rids) & set(curr_good_bootstrapped_rids) ) - self.waiting_queue.extend(bootstrapped_reqs) - return bootstrapped_rids + bad_bootstrapped_rids = list( + set(prev_bad_bootstrapped_rids) | set(curr_bad_bootstrapped_rids) + ) + return [good_bootstrapped_rids, bad_bootstrapped_rids] def _pp_pd_get_transferred_ids(self: Scheduler): # get the current stage transfer success if self.pp_group.is_first_rank: - transferred_rids = self.get_transferred_rids() + transferred_rids = self.get_rids( + self.disagg_prefill_inflight_queue, + [KVPoll.Success, KVPoll.Failed], + ) # if other ranks, do intersection with the previous rank's transferred rids else: # 2 (Release): Receive the transferred rids from the previous rank # 1. recv previous stage's transferred reqs info - prev_transferred_rids = self.recv_pyobj_from_prev_stage() + prev_transferred_rids = self._pp_recv_pyobj_from_prev_stage() # 2. get the current stage's transferred reqs info - curr_transferred_rids = self.get_transferred_rids() + curr_transferred_rids = self.get_rids( + self.disagg_prefill_inflight_queue, + [KVPoll.Success, KVPoll.Failed], + ) # 3. 
new consensus rids = intersection(previous consensus rids, transfer finished rids) transferred_rids = list( set(prev_transferred_rids) & set(curr_transferred_rids) @@ -616,13 +647,15 @@ def event_loop_pp_disagg_prefill(self: Scheduler): ) if bmbs[next_mb_id] is not None: - next_consensus_bootstrapped_rids = self.recv_pyobj_from_prev_stage() + next_consensus_bootstrapped_rids = ( + self._pp_recv_pyobj_from_prev_stage() + ) next_consensus_bootstrapped_rids = self.process_bootstrapped_queue( next_consensus_bootstrapped_rids ) self._pp_commit_comm_work(send_consensus_bootstrapped_work) if tmbs[next_mb_id] is not None: - next_release_rids = self.recv_pyobj_from_prev_stage() + next_release_rids = self._pp_recv_pyobj_from_prev_stage() self._pp_commit_comm_work(send_release_work) # post-process the coming microbatch if mbs[next_mb_id] is not None: @@ -652,9 +685,9 @@ def event_loop_pp_disagg_prefill(self: Scheduler): async_send=True, ) - #if self.delayed_weight_sync_fn: - #self.delayed_weight_sync_fn() - #self.delayed_weight_sync_fn = None + if hasattr(self, "delayed_weight_sync_fn"): + self.delayed_weight_sync_fn() + self.delayed_weight_sync_fn = None pp_outputs = next_pp_outputs release_rids = next_release_rids From 30427d67390df7321b67224cfefa0ceaa0985cab Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Fri, 14 Nov 2025 16:39:28 +0800 Subject: [PATCH 22/28] fix Signed-off-by: Shangming Cai --- python/sglang/srt/managers/scheduler_pp_mixin.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 7c9a13b4929c..109e94940b27 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -241,18 +241,22 @@ def _pp_launch_batch( ) return result, event - def get_rids(self: Scheduler, req_queue: List[Req], *poll_statuses): + def get_rids(self: Scheduler, req_queue: List[Req], *poll_statuses_group): """ - Used by PP, get the transferred rids but **do not pop** + Used by PP, get the required rids with the given poll statuses. 
""" polls = poll_and_all_reduce( [req.disagg_kv_sender for req in req_queue], self.tp_worker.get_attention_tp_cpu_group(), ) rids: List = [] - for poll_status in poll_statuses: + for poll_statuses in poll_statuses_group: rids.append( - [req.rid for req, poll in zip(req_queue, polls) if poll in poll_status] + [ + req.rid + for req, poll in zip(req_queue, polls) + if poll in poll_statuses + ] ) return tuple(rids) if len(rids) > 1 else rids[0] From 026b942cb2e6d8c6616cc73f4055ef34a0764972 Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Mon, 24 Nov 2025 16:06:59 +0800 Subject: [PATCH 23/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler_pp_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 109e94940b27..21c022584fbd 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -63,7 +63,7 @@ def _pp_recv_pyobj_from_prev_stage(self: Scheduler): else: data = None - if self.tp_size != 1: + if self.attn_tp_size != 1: data = broadcast_pyobj( data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0] ) From 59c8ef30e718429b712d3ecdfc1af81d9fd72ee2 Mon Sep 17 00:00:00 2001 From: Xuchun Shang Date: Mon, 24 Nov 2025 16:40:54 +0800 Subject: [PATCH 24/28] fix Signed-off-by: Xuchun Shang --- python/sglang/srt/managers/scheduler_pp_mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 21c022584fbd..ac045d0d9163 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -65,7 +65,10 @@ def _pp_recv_pyobj_from_prev_stage(self: Scheduler): if self.attn_tp_size != 1: data = broadcast_pyobj( - data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0] + data, + self.attn_tp_group.rank, + self.attn_tp_cpu_group, + src=self.attn_tp_group.ranks[0], ) return data From 401560853a57ada58d45cbd40ec381a326e716f8 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Tue, 25 Nov 2025 21:16:38 +0800 Subject: [PATCH 25/28] add dynamic chunk support Co-authored-by: bluecoffee8 Co-authored-by: Xuchun Shang Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com> Signed-off-by: Shangming Cai --- python/sglang/srt/managers/schedule_batch.py | 4 + python/sglang/srt/managers/scheduler.py | 24 +- .../sglang/srt/managers/scheduler_pp_mixin.py | 370 ++++++++++++++++++ python/sglang/srt/server_args.py | 7 + 4 files changed, 404 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6a7f44a310e4..6ed360941044 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -463,6 +463,7 @@ def __init__( extra_key: Optional[str] = None, dimensions: Optional[int] = None, http_worker_ipc: Optional[str] = None, + chunked_prefill_sizes: Optional[List[int]] = None, ): # Input and output info self.rid = rid @@ -683,6 +684,9 @@ def __init__( # For Matryoshka embeddings self.dimensions = dimensions + # For dynamic chunked prefill + self.chunked_prefill_sizes = None + @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index c1ad0423fc78..969dbc2aa9f1 
100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -453,6 +453,18 @@ def __init__( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) + # Init the dynamic chunking predictor for PP + if self.pp_size > 1 and server_args.enable_dynamic_chunking: + self.init_pp_dynamic_chunk_size(server_args) + try: + self.profile_pp_prefill_latency() + except Exception as e: + logger.warning( + f"[PP Dynamic Chunk] Failed to profile prefill latency: {e}. " + "Dynamic chunking will be disabled." + ) + self.enable_dynamic_chunking = False + # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: @@ -1702,6 +1714,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # in the waiting queue. return None + # Determine chunked_prefill_size for this batch + chunked_prefill_size = self.chunked_prefill_size + if self.chunked_req is not None: + self.chunked_req.init_next_round_input() + if self.enable_dynamic_chunking: + history_len = len(self.chunked_req.prefix_indices) + dynamic_size = self.predict_next_chunk_size(history_len) + if dynamic_size is not None: + chunked_prefill_size = dynamic_size + # Prefill policy adder = PrefillAdder( self.page_size, @@ -1710,7 +1732,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch, self.new_token_ratio, self.max_prefill_tokens, - self.chunked_prefill_size, + chunked_prefill_size, running_bs if self.is_mixed_chunk else 0, self.priority_scheduling_preemption_threshold, ) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index ac045d0d9163..b70f73cfce21 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -1,11 +1,15 @@ from __future__ import annotations import logging +import math +import time from collections import deque from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +import numpy as np import torch +import torch.distributed from sglang.srt.disaggregation.base.conn import KVPoll from sglang.srt.disaggregation.utils import DisaggregationMode, poll_and_all_reduce @@ -17,6 +21,7 @@ get_logprob_from_pp_outputs, ) from sglang.srt.model_executor.forward_batch_info import PPProxyTensors +from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.utils import DynamicGradMode, broadcast_pyobj, point_to_point_pyobj logger = logging.getLogger(__name__) @@ -25,12 +30,377 @@ from sglang.srt.managers.scheduler import Scheduler +class ChunkSizePredictor: + """ + Predictor for dynamic chunk size based on quadratic latency model. + + Models latency as: f(l) = a*l^2 + b*l + c + Predicts next chunk size x such that: f(L+x) - f(L) = target_latency + """ + + def __init__(self): + self.quadratic_coeff_a = 0.0 + self.linear_coeff_b = 0.0 + self.constant_coeff_c = 0.0 + self.target_latency: Optional[float] = None + self.is_ready = False + + def fit(self, seq_lens: List[int], latencies: List[float]): + """Fit quadratic coefficients f(l) = al^2 + bl + c from data points.""" + L = np.array(seq_lens, dtype=np.float64) + T = np.array(latencies, dtype=np.float64) + + if len(L) < 8: + raise ValueError( + f"Not enough data points for quadratic fitting ({len(L)} < 8). " + "Need at least 8 samples with different sequence lengths." 
+ ) + + # Build design matrix for f(l) = al^2 + bl + c + X = np.column_stack([L * L, L, np.ones_like(L)]) # [l^2, l, 1] + + try: + coeffs, residuals, rank, s = np.linalg.lstsq(X, T, rcond=None) + if len(coeffs) >= 3: + fitted_a = float(coeffs[0]) # quadratic coefficient + fitted_b = float(coeffs[1]) # linear coefficient + fitted_c = float(coeffs[2]) # constant coefficient + else: + raise ValueError("Failed to fit coefficients: insufficient rank") + except np.linalg.LinAlgError as e: + raise ValueError(f"Failed to fit f(l) = al^2 + bl + c: {e}") + + # Validate coefficients + if fitted_a <= 0: + raise ValueError( + f"Fitted quadratic coefficient a={fitted_a:.2e} is not positive. " + "Attention has O(n^2) complexity, so a must be positive. " + "Check warmup data quality." + ) + + if fitted_b < 0: + logger.warning( + f"Fitted linear coefficient b={fitted_b:.2e} is negative. Setting b=0." + ) + fitted_b = 0.0 + + self.quadratic_coeff_a = fitted_a + self.linear_coeff_b = fitted_b + self.constant_coeff_c = fitted_c + + logger.info( + f"[ChunkSizePredictor] Fitted coefficients: a={fitted_a:.2e}, " + f"b={fitted_b:.2e}, c={fitted_c:.2e}" + ) + + def set_target_latency(self, base_chunk_size: int): + """Set target latency based on base chunk size: target = f(base_chunk_size) - f(0).""" + + def f(l: float) -> float: + """Total latency function: f(l) = al^2 + bl + c (or bl + c for linear)""" + return ( + self.quadratic_coeff_a * l * l + + self.linear_coeff_b * l + + self.constant_coeff_c + ) + + self.target_latency = f(float(base_chunk_size)) - f(0.0) + + if self.target_latency <= 0: + raise ValueError( + f"Calculated target_latency={self.target_latency:.2f}ms is not positive. " + "Check warmup data quality." + ) + + logger.info( + f"[ChunkSizePredictor] Target latency: {self.target_latency:.2f}ms " + f"(base_chunk_size={base_chunk_size})" + ) + + def predict_next_chunk_size( + self, + history_len: int, + page_size: int, + context_len: int, + max_chunk_size: Optional[int] = None, + ) -> Optional[int]: + """ + Predict next chunk size x such that f(history_len + x) - f(history_len) = target_latency. + + Args: + history_len: Current sequence length (L) + page_size: Page size for alignment + context_len: Maximum context length + max_chunk_size: Maximum allowed chunk size (optional) + + Returns: + Predicted chunk size, or None if prediction fails + """ + if not self.is_ready or self.target_latency is None: + return None + + # Handle quadratic model: f(l) = al^2 + bl + c + if self.quadratic_coeff_a <= 0: + return None + + # Solve f(L+x) - f(L) = T + # where f(L) = a*L^2 + b*L + c + # This expands to: ax^2 + (2aL+b)x - T = 0 + # A = a, B = 2aL + b, C = -T + A = self.quadratic_coeff_a + B = 2 * self.quadratic_coeff_a * history_len + self.linear_coeff_b + C = -self.target_latency + + discriminant = B * B - 4 * A * C + + if discriminant < 0: + logger.warning( + f"Discriminant is negative ({discriminant:.2e}). " + f"No real solution for chunk size. L={history_len}, T={self.target_latency:.2f}ms." + ) + return None + + sqrt_discriminant = math.sqrt(discriminant) + calculated_chunk_size_float = (-B + sqrt_discriminant) / (2 * A) + + if calculated_chunk_size_float <= 0: + logger.warning( + f"Calculated chunk size is non-positive ({calculated_chunk_size_float:.2f}). " + f"L={history_len}, T={self.target_latency:.2f}ms." 
+ ) + return None + + calculated_chunk_size = int(calculated_chunk_size_float) + + # Align to page_size (round down to nearest multiple) + alignment_size = max(page_size, 1) + dynamic_chunk_size = (calculated_chunk_size // alignment_size) * alignment_size + + # Ensure aligned size is at least alignment_size + if dynamic_chunk_size < alignment_size: + dynamic_chunk_size = alignment_size + + # Apply constraints + max_allowed = context_len - history_len - 100 # Leave 100 tokens margin + if max_chunk_size is not None: + max_allowed = min(max_allowed, max_chunk_size) + dynamic_chunk_size = min(dynamic_chunk_size, max_allowed) + + # Align again after min operation + dynamic_chunk_size = (dynamic_chunk_size // alignment_size) * alignment_size + + if dynamic_chunk_size < alignment_size: + return None + + return dynamic_chunk_size + + @dataclass class PPBatchMetadata: can_run_cuda_graph: bool class SchedulerPPMixin: + def init_pp_dynamic_chunk_size(self: "Scheduler", server_args): + """Initialize PP dynamic chunk size predictor.""" + # Initialize attributes to default values + # This ensures the attributes exist even when pp_size <= 1 + self.enable_dynamic_chunking = False + self.length_predictor = None + + if self.pp_size <= 1: + return + + self.length_predictor = ChunkSizePredictor() + # Enable dynamic chunking only if explicitly enabled via server_args + # and chunked_prefill_size is set + self.enable_dynamic_chunking = ( + server_args.enable_dynamic_chunking + and self.chunked_prefill_size is not None + and self.chunked_prefill_size > 0 + ) + + def profile_pp_prefill_latency(self: "Scheduler"): + """ + Profile prefill latency for dynamic chunk sizing. + + Only runs on PP0 (first rank), then broadcasts data to all ranks. + All ranks fit coefficients using the same data. 
+ """ + # Early return if PP is not enabled or dynamic chunking is disabled + if self.pp_size <= 1: + return + if not self.enable_dynamic_chunking: + return + + seq_lens: List[int] = [] + latencies: List[float] = [] + + if self.pp_group.is_first_rank: + logger.info("Profiling prefill latency for dynamic chunk sizing...") + + # Create requests with different lengths: base_chunk_size // (2**i) for i in range(10) + input_ids_list = [] + for i in range(32): + chunk_size = self.chunked_prefill_size - i * ( + self.chunked_prefill_size // 32 + ) + if chunk_size <= 0: + break + input_ids = np.random.randint( + 0, 10000, size=chunk_size, dtype=np.int64 + ).tolist() + input_ids_list.append(input_ids) + + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + + # Create and profile requests + for i, input_ids in enumerate(input_ids_list): + req = Req( + rid=str(i), + origin_input_text="", + origin_input_ids=input_ids, + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + + # Prepare batch + batch = ScheduleBatch.init_new( + [req], + self.req_to_token_pool, + self.token_to_kv_pool_allocator, + self.tree_cache, + self.model_config, + False, + self.spec_algorithm, + ) + + current_seq_len = len(req.fill_ids) + proxy_tensors = { + "hidden_states": torch.zeros( + ( + current_seq_len, + self.tp_worker.model_runner.model_config.hidden_size, + ), + dtype=self.tp_worker.model_runner.model_config.dtype, + device="cuda", + ), + "residual": torch.zeros( + ( + current_seq_len, + self.tp_worker.model_runner.model_config.hidden_size, + ), + dtype=self.tp_worker.model_runner.model_config.dtype, + device="cuda", + ), + } + from sglang.srt.managers.scheduler_pp_mixin import PPProxyTensors + + pp_proxy = PPProxyTensors(proxy_tensors) + + # Measure latency with CUDA synchronization for accurate timing + # Synchronize before starting timing to ensure clean measurement + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start = time.perf_counter() + batch.prepare_for_extend() + model_worker_batch = batch.get_model_worker_batch() + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + forward_batch = ForwardBatch.init_new( + model_worker_batch, self.tp_worker.model_runner + ) + _, _ = self.tp_worker.model_runner.forward( + forward_batch=forward_batch, pp_proxy_tensors=pp_proxy + ) + + # Synchronize after forward to ensure GPU operations complete + if torch.cuda.is_available(): + torch.cuda.synchronize() + + latency_seconds = time.perf_counter() - start + latency_ms = latency_seconds * 1e3 # Convert to milliseconds + seq_lens.append(len(input_ids)) + latencies.append(latency_ms) + + # Release KV cache + if req.req_pool_idx is not None: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + + logger.info( + f"[PP Dynamic Chunk] [PP0] Profiled {len(seq_lens)} samples: " + f"seq_lens={seq_lens}, latencies_ms={latencies}" + ) + + # Broadcast data to all ranks + if torch.distributed.is_available() and torch.distributed.is_initialized(): + data_to_sync = [seq_lens, latencies] + self.pp_group.broadcast_object_list(data_to_sync, src=0) + seq_lens, latencies = data_to_sync + + # All ranks fit coefficients using the same data + # Use model type specified by server_args + # Both models require at 
least 8 data points + if len(seq_lens) < 8: + logger.warning( + f"[PP Dynamic Chunk] [PP{self.pp_rank}] Not enough profiling data " + f"({len(seq_lens)} < 8). Both quadratic and linear models require at least 8 samples. " + f"Dynamic chunking disabled." + ) + return + + # Quadratic model: f(l) = al^2 + bl + c + self.length_predictor.fit(seq_lens, latencies) + self.length_predictor.set_target_latency(self.chunked_prefill_size) + self.length_predictor.is_ready = True + logger.info( + f"[PP Dynamic Chunk] [PP{self.pp_rank}] Predictor ready (quadratic). " + f"Target latency: {self.length_predictor.target_latency:.2f}ms" + ) + + def predict_next_chunk_size(self: "Scheduler", history_len: int) -> Optional[int]: + """ + Predict next chunk size dynamically based on current history length. + + Args: + history_len: Current sequence length + + Returns: + Predicted chunk size, or None to use default chunked_prefill_size + """ + if ( + not self.enable_dynamic_chunking + or self.length_predictor is None + or not self.length_predictor.is_ready + ): + return None + + max_chunk_size = getattr(self, "max_prefill_tokens", None) + predicted_size = self.length_predictor.predict_next_chunk_size( + history_len=history_len, + page_size=self.page_size, + context_len=self.model_config.context_len, + max_chunk_size=max_chunk_size, + ) + + if predicted_size is not None: + logger.debug( + f"[PP Dynamic Chunk] [PP{self.pp_rank}] Predicted chunk size: " + f"{predicted_size} (history_len={history_len})" + ) + + return predicted_size + def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None: for p2p_work in work: p2p_work.work.wait() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 16aaf07dbf0d..2a355f393b96 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -268,6 +268,7 @@ class ServerArgs: max_queued_requests: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: Optional[int] = None + enable_dynamic_chunking: bool = False max_prefill_tokens: int = 16384 schedule_policy: str = "fcfs" enable_priority_scheduling: bool = False @@ -2234,6 +2235,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.chunked_prefill_size, help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.", ) + parser.add_argument( + "--enable-dynamic-chunking", + action="store_true", + default=ServerArgs.enable_dynamic_chunking, + help="Enable dynamic chunk size adjustment for pipeline parallelism. 
When enabled, chunk sizes are dynamically calculated based on fitted function to maintain consistent execution time across chunks.", + ) parser.add_argument( "--max-prefill-tokens", type=int, From 368382a00ace16ac753eb3ee366365fdd9254dad Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Tue, 25 Nov 2025 21:29:12 +0800 Subject: [PATCH 26/28] clean Signed-off-by: Shangming Cai --- python/sglang/srt/managers/schedule_batch.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6ed360941044..6a7f44a310e4 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -463,7 +463,6 @@ def __init__( extra_key: Optional[str] = None, dimensions: Optional[int] = None, http_worker_ipc: Optional[str] = None, - chunked_prefill_sizes: Optional[List[int]] = None, ): # Input and output info self.rid = rid @@ -684,9 +683,6 @@ def __init__( # For Matryoshka embeddings self.dimensions = dimensions - # For dynamic chunked prefill - self.chunked_prefill_sizes = None - @property def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) From d13b07c8346242766a7d456385c7c8eb03859260 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Wed, 26 Nov 2025 14:52:09 +0800 Subject: [PATCH 27/28] remove redundant code Signed-off-by: Shangming Cai --- python/sglang/srt/managers/scheduler.py | 9 +++-- .../sglang/srt/managers/scheduler_pp_mixin.py | 39 +------------------ 2 files changed, 8 insertions(+), 40 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 969dbc2aa9f1..3421cdeaf640 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -453,11 +453,14 @@ def __init__( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) + self.enable_dynamic_chunking = ( + server_args.enable_dynamic_chunking and self.pp_size > 1 + ) + # Init the dynamic chunking predictor for PP - if self.pp_size > 1 and server_args.enable_dynamic_chunking: - self.init_pp_dynamic_chunk_size(server_args) + if self.enable_dynamic_chunking: try: - self.profile_pp_prefill_latency() + self.profile_and_init_predictor() except Exception as e: logger.warning( f"[PP Dynamic Chunk] Failed to profile prefill latency: {e}. " diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index b70f73cfce21..bef6c5972bd1 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -201,38 +201,13 @@ class PPBatchMetadata: class SchedulerPPMixin: - def init_pp_dynamic_chunk_size(self: "Scheduler", server_args): - """Initialize PP dynamic chunk size predictor.""" - # Initialize attributes to default values - # This ensures the attributes exist even when pp_size <= 1 - self.enable_dynamic_chunking = False - self.length_predictor = None - - if self.pp_size <= 1: - return - - self.length_predictor = ChunkSizePredictor() - # Enable dynamic chunking only if explicitly enabled via server_args - # and chunked_prefill_size is set - self.enable_dynamic_chunking = ( - server_args.enable_dynamic_chunking - and self.chunked_prefill_size is not None - and self.chunked_prefill_size > 0 - ) - - def profile_pp_prefill_latency(self: "Scheduler"): + def profile_and_init_predictor(self: Scheduler): """ Profile prefill latency for dynamic chunk sizing. 
Only runs on PP0 (first rank), then broadcasts data to all ranks. All ranks fit coefficients using the same data. """ - # Early return if PP is not enabled or dynamic chunking is disabled - if self.pp_size <= 1: - return - if not self.enable_dynamic_chunking: - return - seq_lens: List[int] = [] latencies: List[float] = [] @@ -348,18 +323,8 @@ def profile_pp_prefill_latency(self: "Scheduler"): self.pp_group.broadcast_object_list(data_to_sync, src=0) seq_lens, latencies = data_to_sync - # All ranks fit coefficients using the same data - # Use model type specified by server_args - # Both models require at least 8 data points - if len(seq_lens) < 8: - logger.warning( - f"[PP Dynamic Chunk] [PP{self.pp_rank}] Not enough profiling data " - f"({len(seq_lens)} < 8). Both quadratic and linear models require at least 8 samples. " - f"Dynamic chunking disabled." - ) - return - # Quadratic model: f(l) = al^2 + bl + c + self.length_predictor = ChunkSizePredictor() self.length_predictor.fit(seq_lens, latencies) self.length_predictor.set_target_latency(self.chunked_prefill_size) self.length_predictor.is_ready = True From 34faf31cd6d7350b7cbccbc55c0c444612bdc5d0 Mon Sep 17 00:00:00 2001 From: Shangming Cai Date: Thu, 4 Dec 2025 19:01:08 +0800 Subject: [PATCH 28/28] add smooth coeff Signed-off-by: Shangming Cai --- python/sglang/srt/disaggregation/prefill.py | 2 -- python/sglang/srt/environ.py | 1 + python/sglang/srt/managers/scheduler_pp_mixin.py | 11 ++++++++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 4da1b63947ab..95bab1240369 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -251,8 +251,6 @@ def pop_bootstrapped( # if req not in reqs_info_to_check, skip if req.rid not in rids_to_check: continue - # Either waiting for input or failed - assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed if poll == KVPoll.Bootstrapping: continue diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index db9774f31c68..36672dc92bed 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -192,6 +192,7 @@ class Envs: SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP = EnvBool(False) SGLANG_SCHEDULER_MAX_RECV_PER_POLL = EnvInt(-1) SGLANG_EXPERIMENTAL_CPP_RADIX_TREE = EnvBool(False) + SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR = EnvFloat(0.75) # Test: pd-disaggregation SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake") diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index bef6c5972bd1..e1f30ceade01 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -14,6 +14,7 @@ from sglang.srt.disaggregation.base.conn import KVPoll from sglang.srt.disaggregation.utils import DisaggregationMode, poll_and_all_reduce from sglang.srt.distributed.parallel_state import P2PWork +from sglang.srt.environ import envs from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.utils import ( GenerationBatchResult, @@ -120,6 +121,7 @@ def f(l: float) -> float: def predict_next_chunk_size( self, history_len: int, + base_chunk_size: int, page_size: int, context_len: int, max_chunk_size: Optional[int] = None, @@ -129,6 +131,7 @@ def predict_next_chunk_size( Args: history_len: Current sequence length (L) + base_chunk_size: Base chunk size page_size: Page size for alignment 
context_len: Maximum context length max_chunk_size: Maximum allowed chunk size (optional) @@ -170,7 +173,12 @@ def predict_next_chunk_size( ) return None - calculated_chunk_size = int(calculated_chunk_size_float) + # Use a smooth coefficient to reduce the abrupt decrease in chunk size + smooth_coeff = envs.SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR.get() + smoothed_chunk_size = base_chunk_size + smooth_coeff * ( + calculated_chunk_size_float - base_chunk_size + ) + calculated_chunk_size = int(smoothed_chunk_size) # Align to page_size (round down to nearest multiple) alignment_size = max(page_size, 1) @@ -353,6 +361,7 @@ def predict_next_chunk_size(self: "Scheduler", history_len: int) -> Optional[int max_chunk_size = getattr(self, "max_prefill_tokens", None) predicted_size = self.length_predictor.predict_next_chunk_size( history_len=history_len, + base_chunk_size=self.chunked_prefill_size, page_size=self.page_size, context_len=self.model_config.context_len, max_chunk_size=max_chunk_size,
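Taken together, the dynamic chunking patches reduce chunk-size prediction to one quadratic solve: with the fitted latency f(l) = a*l^2 + b*l + c, a prefix of length L, and target latency T, solve a*x^2 + (2aL + b)*x - T = 0 for the positive root x, blend it toward the base chunk size with the smoothing factor, and round down to a page multiple. A minimal standalone sketch under those assumptions; the coefficients and helper below are illustrative values, not numbers from the patches:

    import math

    # Illustrative coefficients for f(l) = a*l^2 + b*l + c and a base chunk size;
    # in the patches these come from profiling and --chunked-prefill-size.
    a, b = 2.0e-7, 3.0e-4
    base_chunk = 8192
    target_latency = a * base_chunk**2 + b * base_chunk  # T = f(base) - f(0)


    def next_chunk_size(history_len: int, smooth: float = 0.75, page_size: int = 64) -> int:
        # Solve a*x^2 + (2aL + b)*x - T = 0 for the positive root.
        A = a
        B = 2 * a * history_len + b
        C = -target_latency
        x = (-B + math.sqrt(B * B - 4 * A * C)) / (2 * A)
        # Smooth toward the base chunk size, then round down to a page multiple.
        x = base_chunk + smooth * (x - base_chunk)
        return max(int(x) // page_size * page_size, page_size)


    # Later chunks shrink as attention over the growing prefix gets more expensive.
    print(next_chunk_size(0), next_chunk_size(16384), next_chunk_size(65536))

With these numbers the first chunk comes out near the base size and later chunks shrink monotonically, which is the behavior the smoothing factor (SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR) is meant to temper.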