diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 952374ed5868..95bab1240369 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -55,7 +55,6 @@ SWAKVPool, ) from sglang.srt.tracing.trace import trace_event_batch, trace_slice, trace_slice_end -from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj if TYPE_CHECKING: from torch.distributed import ProcessGroup @@ -252,8 +251,6 @@ def pop_bootstrapped( # if req not in reqs_info_to_check, skip if req.rid not in rids_to_check: continue - # Either waiting for input or failed - assert poll == KVPoll.WaitingForInput or poll == KVPoll.Failed if poll == KVPoll.Bootstrapping: continue @@ -710,36 +707,3 @@ def send_kv_chunk( ) return req.disagg_kv_sender.send(page_indices, state_indices) - - def send_pyobj_to_next_stage(self, data): - if self.attn_tp_rank == 0: - dp_offset = self.attn_dp_rank * self.attn_tp_size - point_to_point_pyobj( - data, - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - self.pp_rank * self.tp_size + dp_offset, - ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset, - ) - - def recv_pyobj_from_prev_stage(self): - if self.attn_tp_rank == 0: - dp_offset = self.attn_dp_rank * self.attn_tp_size - data = point_to_point_pyobj( - [], - self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, - ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset, - self.pp_rank * self.tp_size + dp_offset, - ) - else: - data = None - - if self.attn_tp_size != 1: - data = broadcast_pyobj( - data, - self.attn_tp_group.rank, - self.attn_tp_cpu_group, - src=self.attn_tp_group.ranks[0], - ) - return data diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 3afc41ca54b2..98d4fc33ac7d 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -192,6 +192,7 @@ class Envs: 
SGLANG_DISABLE_CONSECUTIVE_PREFILL_OVERLAP = EnvBool(False) SGLANG_SCHEDULER_MAX_RECV_PER_POLL = EnvInt(-1) SGLANG_EXPERIMENTAL_CPP_RADIX_TREE = EnvBool(False) + SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR = EnvFloat(0.75) # Test: pd-disaggregation SGLANG_TEST_PD_DISAGG_BACKEND = EnvStr("mooncake") diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 152d9717b221..94ca26391565 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -155,7 +155,7 @@ from sglang.srt.mem_cache.cache_init_params import CacheInitParams from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.mem_cache.radix_cache import RadixCache -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin from sglang.srt.parser.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args @@ -472,6 +472,21 @@ def __init__( self.chunked_prefill_size is not None and server_args.enable_mixed_chunk ) + self.enable_dynamic_chunking = ( + server_args.enable_dynamic_chunking and self.pp_size > 1 + ) + + # Init the dynamic chunking predictor for PP + if self.enable_dynamic_chunking: + try: + self.profile_and_init_predictor() + except Exception as e: + logger.warning( + f"[PP Dynamic Chunk] Failed to profile prefill latency: {e}. " + "Dynamic chunking will be disabled." 
+ ) + self.enable_dynamic_chunking = False + # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: @@ -934,8 +949,7 @@ def init_disaggregation(self): def init_overlap(self): self.future_map = None - - if not self.enable_overlap: + if not self.enable_overlap and self.pp_size == 1: return self.forward_stream: CudaStream = torch.get_device_module(self.device).Stream() @@ -947,6 +961,9 @@ def init_overlap(self): self.device ).stream(self.copy_stream) + if not self.enable_overlap: + return + self.future_map = FutureMap( self.max_running_requests, self.chunked_prefill_size, @@ -1108,7 +1125,7 @@ def recv_requests( recv_reqs = point_to_point_pyobj( [], self.pp_rank * self.tp_size + dp_offset, - self.world_group.device_group, + self.world_group.cpu_group, (self.pp_rank - 1) * self.tp_size + dp_offset, self.pp_rank * self.tp_size + dp_offset, ) @@ -1766,6 +1783,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # in the waiting queue. 
return None + # Determine chunked_prefill_size for this batch + chunked_prefill_size = self.chunked_prefill_size + if self.chunked_req is not None: + self.chunked_req.init_next_round_input() + if self.enable_dynamic_chunking: + history_len = len(self.chunked_req.prefix_indices) + dynamic_size = self.predict_next_chunk_size(history_len) + if dynamic_size is not None: + chunked_prefill_size = dynamic_size + # Prefill policy adder = PrefillAdder( self.page_size, @@ -1774,7 +1801,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.running_batch, self.new_token_ratio, self.max_prefill_tokens, - self.chunked_prefill_size, + chunked_prefill_size, running_bs if self.is_mixed_chunk else 0, self.priority_scheduling_preemption_threshold, ) @@ -1966,7 +1993,9 @@ def update_cache_from_scheduler( pass def run_batch( - self, batch: ScheduleBatch + self, + batch: ScheduleBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: """Run a batch.""" self.forward_ct += 1 @@ -2014,6 +2043,7 @@ def run_batch( self.future_map.resolve_future(model_worker_batch) batch_result = self.model_worker.forward_batch_generation( model_worker_batch + # here pp is not compatible with overlap ) # FIXME(lsyin): maybe move this to forward_batch_generation batch_result.copy_done = torch.get_device_module( @@ -2047,8 +2077,13 @@ def run_batch( batch_result = self.tp_worker.forward_batch_split_prefill(batch) future_indices_or_next_token_ids = batch_result.next_token_ids else: + kwargs = ( + {"pp_proxy_tensors": pp_proxy_tensors} + if self.spec_algorithm.is_none() + else {} + ) batch_result = self.model_worker.forward_batch_generation( - batch_or_worker_batch + batch_or_worker_batch, **kwargs ) future_indices_or_next_token_ids = batch_result.next_token_ids self.update_cache_from_scheduler(batch, batch_result) diff --git a/python/sglang/srt/managers/scheduler_pp_mixin.py b/python/sglang/srt/managers/scheduler_pp_mixin.py index 
cfa4e2369bd6..e1f30ceade01 100644 --- a/python/sglang/srt/managers/scheduler_pp_mixin.py +++ b/python/sglang/srt/managers/scheduler_pp_mixin.py @@ -1,349 +1,1054 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional +import logging +import math +import time +from collections import deque +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.managers.schedule_batch import ScheduleBatch -from sglang.srt.managers.utils import GenerationBatchResult +import numpy as np +import torch +import torch.distributed + +from sglang.srt.disaggregation.base.conn import KVPoll +from sglang.srt.disaggregation.utils import DisaggregationMode, poll_and_all_reduce +from sglang.srt.distributed.parallel_state import P2PWork +from sglang.srt.environ import envs +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.utils import ( + GenerationBatchResult, + get_logprob_dict_from_result, + get_logprob_from_pp_outputs, +) from sglang.srt.model_executor.forward_batch_info import PPProxyTensors -from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj +from sglang.srt.sampling.sampling_params import SamplingParams +from sglang.srt.utils import DynamicGradMode, broadcast_pyobj, point_to_point_pyobj + +logger = logging.getLogger(__name__) if TYPE_CHECKING: from sglang.srt.managers.scheduler import Scheduler +class ChunkSizePredictor: + """ + Predictor for dynamic chunk size based on quadratic latency model. 
+ + Models latency as: f(l) = a*l^2 + b*l + c + Predicts next chunk size x such that: f(L+x) - f(L) = target_latency + """ + + def __init__(self): + self.quadratic_coeff_a = 0.0 + self.linear_coeff_b = 0.0 + self.constant_coeff_c = 0.0 + self.target_latency: Optional[float] = None + self.is_ready = False + + def fit(self, seq_lens: List[int], latencies: List[float]): + """Fit quadratic coefficients f(l) = al^2 + bl + c from data points.""" + L = np.array(seq_lens, dtype=np.float64) + T = np.array(latencies, dtype=np.float64) + + if len(L) < 8: + raise ValueError( + f"Not enough data points for quadratic fitting ({len(L)} < 8). " + "Need at least 8 samples with different sequence lengths." + ) + + # Build design matrix for f(l) = al^2 + bl + c + X = np.column_stack([L * L, L, np.ones_like(L)]) # [l^2, l, 1] + + try: + coeffs, residuals, rank, s = np.linalg.lstsq(X, T, rcond=None) + if len(coeffs) >= 3: + fitted_a = float(coeffs[0]) # quadratic coefficient + fitted_b = float(coeffs[1]) # linear coefficient + fitted_c = float(coeffs[2]) # constant coefficient + else: + raise ValueError("Failed to fit coefficients: insufficient rank") + except np.linalg.LinAlgError as e: + raise ValueError(f"Failed to fit f(l) = al^2 + bl + c: {e}") + + # Validate coefficients + if fitted_a <= 0: + raise ValueError( + f"Fitted quadratic coefficient a={fitted_a:.2e} is not positive. " + "Attention has O(n^2) complexity, so a must be positive. " + "Check warmup data quality." + ) + + if fitted_b < 0: + logger.warning( + f"Fitted linear coefficient b={fitted_b:.2e} is negative. Setting b=0." 
+ ) + fitted_b = 0.0 + + self.quadratic_coeff_a = fitted_a + self.linear_coeff_b = fitted_b + self.constant_coeff_c = fitted_c + + logger.info( + f"[ChunkSizePredictor] Fitted coefficients: a={fitted_a:.2e}, " + f"b={fitted_b:.2e}, c={fitted_c:.2e}" + ) + + def set_target_latency(self, base_chunk_size: int): + """Set target latency based on base chunk size: target = f(base_chunk_size) - f(0).""" + + def f(l: float) -> float: + """Total latency function: f(l) = al^2 + bl + c (or bl + c for linear)""" + return ( + self.quadratic_coeff_a * l * l + + self.linear_coeff_b * l + + self.constant_coeff_c + ) + + self.target_latency = f(float(base_chunk_size)) - f(0.0) + + if self.target_latency <= 0: + raise ValueError( + f"Calculated target_latency={self.target_latency:.2f}ms is not positive. " + "Check warmup data quality." + ) + + logger.info( + f"[ChunkSizePredictor] Target latency: {self.target_latency:.2f}ms " + f"(base_chunk_size={base_chunk_size})" + ) + + def predict_next_chunk_size( + self, + history_len: int, + base_chunk_size: int, + page_size: int, + context_len: int, + max_chunk_size: Optional[int] = None, + ) -> Optional[int]: + """ + Predict next chunk size x such that f(history_len + x) - f(history_len) = target_latency. 
+ + Args: + history_len: Current sequence length (L) + base_chunk_size: Base chunk size + page_size: Page size for alignment + context_len: Maximum context length + max_chunk_size: Maximum allowed chunk size (optional) + + Returns: + Predicted chunk size, or None if prediction fails + """ + if not self.is_ready or self.target_latency is None: + return None + + # Handle quadratic model: f(l) = al^2 + bl + c + if self.quadratic_coeff_a <= 0: + return None + + # Solve f(L+x) - f(L) = T + # where f(L) = a*L^2 + b*L + c + # This expands to: ax^2 + (2aL+b)x - T = 0 + # A = a, B = 2aL + b, C = -T + A = self.quadratic_coeff_a + B = 2 * self.quadratic_coeff_a * history_len + self.linear_coeff_b + C = -self.target_latency + + discriminant = B * B - 4 * A * C + + if discriminant < 0: + logger.warning( + f"Discriminant is negative ({discriminant:.2e}). " + f"No real solution for chunk size. L={history_len}, T={self.target_latency:.2f}ms." + ) + return None + + sqrt_discriminant = math.sqrt(discriminant) + calculated_chunk_size_float = (-B + sqrt_discriminant) / (2 * A) + + if calculated_chunk_size_float <= 0: + logger.warning( + f"Calculated chunk size is non-positive ({calculated_chunk_size_float:.2f}). " + f"L={history_len}, T={self.target_latency:.2f}ms." 
+ ) + return None + + # Use a smooth coefficient to reduce the abrupt decrease in chunk size + smooth_coeff = envs.SGLANG_DYNAMIC_CHUNKING_SMOOTH_FACTOR.get() + smoothed_chunk_size = base_chunk_size + smooth_coeff * ( + calculated_chunk_size_float - base_chunk_size + ) + calculated_chunk_size = int(smoothed_chunk_size) + + # Align to page_size (round down to nearest multiple) + alignment_size = max(page_size, 1) + dynamic_chunk_size = (calculated_chunk_size // alignment_size) * alignment_size + + # Ensure aligned size is at least alignment_size + if dynamic_chunk_size < alignment_size: + dynamic_chunk_size = alignment_size + + # Apply constraints + max_allowed = context_len - history_len - 100 # Leave 100 tokens margin + if max_chunk_size is not None: + max_allowed = min(max_allowed, max_chunk_size) + dynamic_chunk_size = min(dynamic_chunk_size, max_allowed) + + # Align again after min operation + dynamic_chunk_size = (dynamic_chunk_size // alignment_size) * alignment_size + + if dynamic_chunk_size < alignment_size: + return None + + return dynamic_chunk_size + + +@dataclass +class PPBatchMetadata: + can_run_cuda_graph: bool + + class SchedulerPPMixin: + def profile_and_init_predictor(self: Scheduler): + """ + Profile prefill latency for dynamic chunk sizing. + + Only runs on PP0 (first rank), then broadcasts data to all ranks. + All ranks fit coefficients using the same data. 
+ """ + seq_lens: List[int] = [] + latencies: List[float] = [] + + if self.pp_group.is_first_rank: + logger.info("Profiling prefill latency for dynamic chunk sizing...") + + # Create requests with different lengths: base_chunk_size // (2**i) for i in range(10) + input_ids_list = [] + for i in range(32): + chunk_size = self.chunked_prefill_size - i * ( + self.chunked_prefill_size // 32 + ) + if chunk_size <= 0: + break + input_ids = np.random.randint( + 0, 10000, size=chunk_size, dtype=np.int64 + ).tolist() + input_ids_list.append(input_ids) + + sampling_params = SamplingParams( + temperature=0, + max_new_tokens=1, + ) + + # Create and profile requests + for i, input_ids in enumerate(input_ids_list): + req = Req( + rid=str(i), + origin_input_text="", + origin_input_ids=input_ids, + sampling_params=sampling_params, + ) + req.fill_ids = req.origin_input_ids + req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) + req.logprob_start_len = len(req.origin_input_ids) - 1 + + # Prepare batch + batch = ScheduleBatch.init_new( + [req], + self.req_to_token_pool, + self.token_to_kv_pool_allocator, + self.tree_cache, + self.model_config, + False, + self.spec_algorithm, + ) + + current_seq_len = len(req.fill_ids) + proxy_tensors = { + "hidden_states": torch.zeros( + ( + current_seq_len, + self.tp_worker.model_runner.model_config.hidden_size, + ), + dtype=self.tp_worker.model_runner.model_config.dtype, + device="cuda", + ), + "residual": torch.zeros( + ( + current_seq_len, + self.tp_worker.model_runner.model_config.hidden_size, + ), + dtype=self.tp_worker.model_runner.model_config.dtype, + device="cuda", + ), + } + from sglang.srt.managers.scheduler_pp_mixin import PPProxyTensors + + pp_proxy = PPProxyTensors(proxy_tensors) + + # Measure latency with CUDA synchronization for accurate timing + # Synchronize before starting timing to ensure clean measurement + if torch.cuda.is_available(): + torch.cuda.synchronize() + + start = time.perf_counter() + 
batch.prepare_for_extend() + model_worker_batch = batch.get_model_worker_batch() + from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + forward_batch = ForwardBatch.init_new( + model_worker_batch, self.tp_worker.model_runner + ) + _, _ = self.tp_worker.model_runner.forward( + forward_batch=forward_batch, pp_proxy_tensors=pp_proxy + ) + + # Synchronize after forward to ensure GPU operations complete + if torch.cuda.is_available(): + torch.cuda.synchronize() + + latency_seconds = time.perf_counter() - start + latency_ms = latency_seconds * 1e3 # Convert to milliseconds + seq_lens.append(len(input_ids)) + latencies.append(latency_ms) + + # Release KV cache + if req.req_pool_idx is not None: + kv_indices = self.req_to_token_pool.req_to_token[ + req.req_pool_idx, : len(req.fill_ids) + ] + self.token_to_kv_pool_allocator.free(kv_indices) + self.req_to_token_pool.free(req.req_pool_idx) + + logger.info( + f"[PP Dynamic Chunk] [PP0] Profiled {len(seq_lens)} samples: " + f"seq_lens={seq_lens}, latencies_ms={latencies}" + ) + + # Broadcast data to all ranks + if torch.distributed.is_available() and torch.distributed.is_initialized(): + data_to_sync = [seq_lens, latencies] + self.pp_group.broadcast_object_list(data_to_sync, src=0) + seq_lens, latencies = data_to_sync + + # Quadratic model: f(l) = al^2 + bl + c + self.length_predictor = ChunkSizePredictor() + self.length_predictor.fit(seq_lens, latencies) + self.length_predictor.set_target_latency(self.chunked_prefill_size) + self.length_predictor.is_ready = True + logger.info( + f"[PP Dynamic Chunk] [PP{self.pp_rank}] Predictor ready (quadratic). " + f"Target latency: {self.length_predictor.target_latency:.2f}ms" + ) + + def predict_next_chunk_size(self: "Scheduler", history_len: int) -> Optional[int]: + """ + Predict next chunk size dynamically based on current history length. 
+ + Args: + history_len: Current sequence length + + Returns: + Predicted chunk size, or None to use default chunked_prefill_size + """ + if ( + not self.enable_dynamic_chunking + or self.length_predictor is None + or not self.length_predictor.is_ready + ): + return None + + max_chunk_size = getattr(self, "max_prefill_tokens", None) + predicted_size = self.length_predictor.predict_next_chunk_size( + history_len=history_len, + base_chunk_size=self.chunked_prefill_size, + page_size=self.page_size, + context_len=self.model_config.context_len, + max_chunk_size=max_chunk_size, + ) + + if predicted_size is not None: + logger.debug( + f"[PP Dynamic Chunk] [PP{self.pp_rank}] Predicted chunk size: " + f"{predicted_size} (history_len={history_len})" + ) + + return predicted_size + + def _pp_commit_comm_work(self: Scheduler, work: List[P2PWork]) -> None: + for p2p_work in work: + p2p_work.work.wait() + work.clear() + + def _pp_send_pyobj_to_next_stage(self: Scheduler, data, async_send: bool = False): + p2p_work = [] + if self.attn_tp_rank == 0: + dp_offset = self.attn_dp_rank * self.attn_tp_size + p2p_work = point_to_point_pyobj( + data, + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + self.pp_rank * self.tp_size + dp_offset, + ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset, + async_send=async_send, + ) + return p2p_work + + def _pp_recv_pyobj_from_prev_stage(self: Scheduler): + if self.attn_tp_rank == 0: + dp_offset = self.attn_dp_rank * self.attn_tp_size + data = point_to_point_pyobj( + [], + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset, + self.pp_rank * self.tp_size + dp_offset, + ) + else: + data = None + + if self.attn_tp_size != 1: + data = broadcast_pyobj( + data, + self.attn_tp_group.rank, + self.attn_tp_cpu_group, + src=self.attn_tp_group.ranks[0], + ) + + return data + + def _pp_prepare_tensor_dict( + self: Scheduler, result: 
GenerationBatchResult, batch: ScheduleBatch + ) -> Dict[str, torch.Tensor]: + tensor_dict = { + "next_token_ids": result.next_token_ids, + } + + if batch.return_logprob: + logprob_dict = get_logprob_dict_from_result(result) + tensor_dict = { + **tensor_dict, + **logprob_dict, + } + return tensor_dict + + def _pp_send_dict_to_next_stage( + self: Scheduler, + tensor_dict: Dict[str, torch.Tensor], + async_send: bool = True, + ): + p2p_work = [] + p2p_work.extend( + self.pp_group.send_tensor_dict( + tensor_dict=tensor_dict, + all_gather_group=self.attn_tp_group, + async_send=async_send, + ) + ) + return p2p_work + + def _pp_recv_proxy_tensors(self: Scheduler) -> Optional[PPProxyTensors]: + pp_proxy_tensors = None + if not self.pp_group.is_first_rank: + pp_proxy_tensors = PPProxyTensors( + self.pp_group.recv_tensor_dict(all_gather_group=self.attn_tp_group) + ) + return pp_proxy_tensors + + def _pp_recv_dict_from_prev_stage( + self: Scheduler, + ) -> Dict[str, torch.Tensor]: + res = self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group, + ) + return res + + def _pp_prep_batch_result( + self: Scheduler, + batch: ScheduleBatch, + mb_metadata: PPBatchMetadata, + pp_outputs: PPProxyTensors, + ): + from sglang.srt.managers.scheduler import GenerationBatchResult + + logits_output = None + extend_input_len_per_req = None + extend_logprob_start_len_per_req = None + + if batch.return_logprob: + ( + logits_output, + extend_input_len_per_req, + extend_logprob_start_len_per_req, + ) = get_logprob_from_pp_outputs(pp_outputs) + batch.output_ids = pp_outputs["next_token_ids"] + output_result = GenerationBatchResult( + logits_output=logits_output, + pp_hidden_states_proxy_tensors=None, + next_token_ids=pp_outputs["next_token_ids"], + extend_input_len_per_req=extend_input_len_per_req, + extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, + can_run_cuda_graph=mb_metadata.can_run_cuda_graph, + ) + return output_result + + def _pp_process_batch_result( + self: 
Scheduler, batch: ScheduleBatch, output_result: GenerationBatchResult + ): + if self.disaggregation_mode == DisaggregationMode.PREFILL: + self.process_batch_result_disagg_prefill(batch, output_result) + else: + self.process_batch_result(batch, output_result) + + def _pp_send_output_to_next_stage( + self: Scheduler, + next_first_rank_mb_id: int, + mbs: List[ScheduleBatch], + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], + pp_outputs: PPProxyTensors | None, + ) -> List[P2PWork]: + send_output_work = [] + if self.pp_group.is_last_rank: + # send ready PP output to rank 0 + if mbs[next_first_rank_mb_id] is not None: + q_event, pp_outputs_to_send = last_rank_comm_queue.popleft() + torch.cuda.current_stream().wait_event(q_event) + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs_to_send.tensors, + async_send=True, + ) + # send the outputs from the last round to let the next stage worker run post processing + if not self.pp_group.is_last_rank: + if pp_outputs: + with torch.profiler.record_function("send_res_dict_to_next_stage"): + send_output_work = self._pp_send_dict_to_next_stage( + pp_outputs.tensors, + async_send=True, + ) + return send_output_work + + def _pp_send_recv_and_preprocess_output_tensors( + self: Scheduler, + next_first_rank_mb_id: int, + next_mb_id: int, + mbs: List[ScheduleBatch], + mb_metadata: List[PPBatchMetadata], + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], + pp_outputs: PPProxyTensors | None, + ) -> Tuple[PPProxyTensors, List[P2PWork], torch.cuda.Event]: + next_pp_outputs = None + d2h_event = None + batch_result = None + send_output_work = self._pp_send_output_to_next_stage( + next_first_rank_mb_id, + mbs, + last_rank_comm_queue, + pp_outputs, + ) + + if mbs[next_mb_id] is not None: + with torch.profiler.record_function("recv_res_dict_from_prev_stage"): + next_pp_outputs = 
PPProxyTensors(self._pp_recv_dict_from_prev_stage()) + with self.copy_stream_ctx: + self.copy_stream.wait_stream(self.default_stream) + batch_result = self._pp_prep_batch_result( + mbs[next_mb_id], mb_metadata[next_mb_id], next_pp_outputs + ) + d2h_event = torch.cuda.Event() + d2h_event.record(torch.cuda.current_stream()) + + return next_pp_outputs, batch_result, d2h_event, send_output_work + + def _pp_launch_batch( + self: Scheduler, + mb_id: int, + pp_proxy_tensors: PPProxyTensors, + mb_metadata: List[Optional[PPBatchMetadata]], + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]], + ): + with torch.profiler.record_function("run_batch"): + with self.forward_stream_ctx: + self.forward_stream.wait_stream(self.default_stream) + result = self.run_batch(self.cur_batch, pp_proxy_tensors) + mb_metadata[mb_id] = PPBatchMetadata( + can_run_cuda_graph=result.can_run_cuda_graph, + ) + event = torch.cuda.Event() + event.record(torch.cuda.current_stream()) + if self.pp_group.is_last_rank: + # (last rank) buffer the outputs for async batch depth + last_rank_comm_queue.append( + ( + event, + PPProxyTensors( + self._pp_prepare_tensor_dict(result, self.cur_batch) + ), + ) + ) + return result, event + + def get_rids(self: Scheduler, req_queue: List[Req], *poll_statuses_group): + """ + Used by PP, get the required rids with the given poll statuses. + """ + polls = poll_and_all_reduce( + [req.disagg_kv_sender for req in req_queue], + self.tp_worker.get_attention_tp_cpu_group(), + ) + rids: List = [] + for poll_statuses in poll_statuses_group: + rids.append( + [ + req.rid + for req, poll in zip(req_queue, polls) + if poll in poll_statuses + ] + ) + return tuple(rids) if len(rids) > 1 else rids[0] @DynamicGradMode() - def event_loop_pp(self): - """A non-overlap scheduler loop for pipeline parallelism.""" - mbs = [None] * self.pp_size - last_mbs = [None] * self.pp_size + def event_loop_pp(self: Scheduler): + """ + A scheduler loop for pipeline parallelism. 
+ Notes: + 1. Each stage runs in the same order and is notified by the previous stage. + 2. We use async send but sync recv to avoid desynchronization while minimizing the communication overhead. + 3. We can use async batch depth to buffer the outputs in the last stage to allow overlapping the GPU computation and CPU processing and avoid last PP rank straggler. + + Unified Schedule: + ==================================================================== + Stage P + recv ith req from previous stage + recv ith proxy from previous stage + run ith batch + recv prev (i+1)% mb_size th outputs + process batch result of prev (i+1)% mb_size th batch (can be run in parallel with the curr batch GPU computation) + send ith req to next stage + send ith proxy to next stage + send current stage's outputs to next stage (can be stashed and delayed to send later) + + the above order can be optimized and reordered to minimize communication-related CPU stall and overhead bubbles. + + ==================================================================== + """ + self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + mbs = [None] * self.pp_loop_size + last_mbs = [None] * self.pp_loop_size self.running_mbs = [ - ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ScheduleBatch(reqs=[], batch_is_full=False) + for _ in range(self.pp_loop_size) ] + mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size pp_outputs: Optional[PPProxyTensors] = None + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() + send_req_work = [] + send_proxy_work = [] + send_output_work = [] + event = None while True: server_is_idle = True - for mb_id in range(self.pp_size): + for mb_id in range(self.pp_loop_size): self.running_batch = self.running_mbs[mb_id] self.last_batch = last_mbs[mb_id] - - recv_reqs = self.recv_requests() - self.process_input_requests(recv_reqs) - mbs[mb_id] = self.get_next_batch_to_run() + 
next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size + next_mb_id = (mb_id + 1) % self.pp_loop_size + with torch.profiler.record_function("recv_requests"): + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) + with torch.profiler.record_function("send_reqs_to_next_stage"): + send_req_work = self._pp_send_pyobj_to_next_stage( + recv_reqs, + async_send=True, + ) + with torch.profiler.record_function("get_next_batch_to_run"): + mbs[mb_id] = self.get_next_batch_to_run() self.running_mbs[mb_id] = self.running_batch - - self.cur_batch = mbs[mb_id] + self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False - result = self.run_batch(self.cur_batch) - - # (last rank) send the outputs to the next step - if self.pp_group.is_last_rank: - if self.cur_batch: - next_token_ids = result.next_token_ids - if self.cur_batch.return_logprob: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - "extend_input_len_per_req": result.extend_input_len_per_req, - "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req, - } - | ( - { - f"logits_output.{k}": v - for k, v in result.logits_output.__dict__.items() - } - if result.logits_output is not None - else {} - ) - ) - else: - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - } - ) - # send the output from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - - # receive outputs and post-process (filter finished reqs) the coming microbatch - next_mb_id = (mb_id + 1) % self.pp_size + pp_proxy_tensors = self._pp_recv_proxy_tensors() next_pp_outputs = None - if mbs[next_mb_id] is not None: - next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group + 
next_batch_result = None + d2h_event = None + if self.server_args.pp_async_batch_depth > 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, ) ) - mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] - logits_output_args = { - k[len("logits_output.") :]: v - for k, v in next_pp_outputs.tensors.items() - if k.startswith("logits_output.") - } - if len(logits_output_args) > 0: - logits_output = LogitsProcessorOutput(**logits_output_args) - else: - logits_output = None - - output_result = GenerationBatchResult.from_pp_proxy( - logits_output=logits_output, - next_pp_outputs=next_pp_outputs, - can_run_cuda_graph=result.can_run_cuda_graph, + self._pp_commit_comm_work(send_proxy_work) + if self.cur_batch: + result, event = self._pp_launch_batch( + mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue ) - self.process_batch_result(mbs[next_mb_id], output_result) - last_mbs[next_mb_id] = mbs[next_mb_id] - - # (not last rank) - if not self.pp_group.is_last_rank: - # carry the outputs to the next stage - # send the outputs from the last round to let the next stage worker run post processing - if pp_outputs: - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, + if self.server_args.pp_async_batch_depth == 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, ) - - # send out reqs to the next stage - dp_offset = self.attn_dp_rank * self.attn_tp_size - if self.attn_tp_rank == 0: - point_to_point_pyobj( - recv_reqs, - self.pp_rank * self.tp_size + dp_offset, - 
self.world_group.device_group, - self.pp_rank * self.tp_size + dp_offset, - (self.pp_rank + 1) * self.tp_size + dp_offset, + ) + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + with torch.profiler.record_function("process_batch_result"): + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, ) - - # send out proxy tensors to the next stage + last_mbs[next_mb_id] = mbs[next_mb_id] + if not self.pp_group.is_last_rank: if self.cur_batch: - self.pp_group.send_tensor_dict( - result.pp_hidden_states_proxy_tensors.tensors, - all_gather_group=self.attn_tp_group, - ) + torch.cuda.current_stream().wait_event(event) + with torch.profiler.record_function( + "send_proxy_dict_to_next_stage" + ): + send_proxy_work = self._pp_send_dict_to_next_stage( + result.pp_hidden_states_proxy_tensors.tensors, + async_send=True, + ) + + # if self.delayed_weight_sync_fn: + # self.delayed_weight_sync_fn() + # self.delayed_weight_sync_fn = None pp_outputs = next_pp_outputs # When the server is idle, self-check and re-init some states if server_is_idle: - # When the server is idle, do self-check and re-init some states - self.self_check_during_idle() + self.check_memory() + self.check_tree_cache() + self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() + + def process_bootstrapped_queue( + self: Scheduler, bootstrapped_rids: Optional[List[str]] + ): + # finished consensus bootstrapped reqs and prepare the waiting queue + if bootstrapped_rids is not None: + ( + good_consensus_bootstrapped_rids, + bad_consensus_bootstrapped_rids, + ) = bootstrapped_rids + good_reqs, failed_reqs = ( + self.disagg_prefill_bootstrap_queue.pop_bootstrapped( + return_failed_reqs=True, + rids_to_check=good_consensus_bootstrapped_rids + + bad_consensus_bootstrapped_rids, + ) + ) + self.waiting_queue.extend(good_reqs) + return [[req.rid for req in good_reqs], [req.rid for req in failed_reqs]] + return None + + def _pp_pd_get_bootstrapped_ids(self: Scheduler): + # 
communicate pre-consensus bootstrap reqs + if self.pp_group.is_first_rank: + # First rank, pop the bootstrap reqs from the bootstrap queue + good_bootstrapped_rids, bad_bootstrapped_rids = self.get_rids( + self.disagg_prefill_bootstrap_queue.queue, + [KVPoll.WaitingForInput], + [KVPoll.Failed], + ) + else: + # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus + prev_bootstrapped_rids = self._pp_recv_pyobj_from_prev_stage() + prev_good_bootstrapped_rids, prev_bad_bootstrapped_rids = ( + prev_bootstrapped_rids + ) + curr_good_bootstrapped_rids, curr_bad_bootstrapped_rids = self.get_rids( + self.disagg_prefill_bootstrap_queue.queue, + [KVPoll.WaitingForInput], + [KVPoll.Failed], + ) + good_bootstrapped_rids = list( + set(prev_good_bootstrapped_rids) & set(curr_good_bootstrapped_rids) + ) + bad_bootstrapped_rids = list( + set(prev_bad_bootstrapped_rids) | set(curr_bad_bootstrapped_rids) + ) + return [good_bootstrapped_rids, bad_bootstrapped_rids] + + def _pp_pd_get_transferred_ids(self: Scheduler): + # get the current stage transfer success + if self.pp_group.is_first_rank: + transferred_rids = self.get_rids( + self.disagg_prefill_inflight_queue, + [KVPoll.Success, KVPoll.Failed], + ) + # if other ranks, do intersection with the previous rank's transferred rids + else: + # 2 (Release): Receive the transferred rids from the previous rank + # 1. recv previous stage's transferred reqs info + prev_transferred_rids = self._pp_recv_pyobj_from_prev_stage() + # 2. get the current stage's transferred reqs info + curr_transferred_rids = self.get_rids( + self.disagg_prefill_inflight_queue, + [KVPoll.Success, KVPoll.Failed], + ) + # 3. 
new consensus rids = intersection(previous consensus rids, transfer finished rids) + transferred_rids = list( + set(prev_transferred_rids) & set(curr_transferred_rids) + ) + return transferred_rids + + def _pp_pd_send_consensus_bootstrapped_ids( + self: Scheduler, + bmbs: List[List[str]], + next_first_rank_mb_id: int, + consensus_bootstrapped_rids: List[str], + bootstrapped_rids: List[str], + ): + # 3 (Release): send the release rids from last stage to the first stage + send_consensus_bootstrapped_work = [] + if self.pp_group.is_last_rank: + if bmbs[next_first_rank_mb_id] is not None: + consensus_bootstrapped_rids = bootstrapped_rids + send_consensus_bootstrapped_work = self._pp_send_pyobj_to_next_stage( + consensus_bootstrapped_rids, async_send=True + ) + # 4 (Release): send the release rids from non last rank to the next rank + else: + if consensus_bootstrapped_rids is not None: + send_consensus_bootstrapped_work = self._pp_send_pyobj_to_next_stage( + consensus_bootstrapped_rids, async_send=True + ) + return send_consensus_bootstrapped_work, consensus_bootstrapped_rids + + def _pp_pd_send_consensus_release_ids( + self: Scheduler, + tmbs: List[List[str]], + next_first_rank_mb_id: int, + release_rids: List[str], + transferred_rids: List[str], + ): + send_release_work = [] + if self.pp_group.is_last_rank: + if tmbs[next_first_rank_mb_id] is not None: + release_rids = transferred_rids + send_release_work = self._pp_send_pyobj_to_next_stage( + release_rids, async_send=True + ) + # 4 (Release): send the release rids from non last rank to the next rank + else: + if release_rids is not None: + send_release_work = self._pp_send_pyobj_to_next_stage( + release_rids, async_send=True + ) + return send_release_work, release_rids @DynamicGradMode() def event_loop_pp_disagg_prefill(self: Scheduler): """ - An event loop for the prefill server in pipeline parallelism. + This is the prefill server event loop for pipeline parallelism. - Rules: - 1. 
Each stage runs in the same order and is notified by the previous stage. - 2. Each send/recv operation is blocking and matched by the neighboring stage. - - Regular Schedule: - ==================================================================== - Stage i | Stage i+1 - send ith req | recv ith req - send ith proxy | recv ith proxy - send prev (i+1)th carry | recv prev (i+1)th carry - ==================================================================== + Notes: + 1. Following the same rules as the event_loop_pp. + 2. Adds extra steps for KV transfer process: bootstrap + release. Prefill Server Schedule: ==================================================================== - Stage i | Stage i+1 - send ith req | recv ith req - send ith bootstrap req | recv ith bootstrap req - send ith transferred req | recv ith transferred req - send ith proxy | recv ith proxy - send prev (i+1)th carry | recv prev (i+1)th carry - send prev (i+1)th release req | recv prev (i+1)th release req + Stage P + recv ith req from previous stage + recv ith bootstrap req from previous stage + recv ith transferred req from previous stage + recv ith proxy from previous stage + run ith batch + recv prev (i+1) % mb_size th consensus bootstrapped req from previous stage + local consensus on bootstrapped req + recv prev (i+1) % mb_size th release req from previous stage + local consensus on release req + recv prev (i+1) % mb_size th outputs + process batch result of prev (i+1)% mb_size th batch (can be run in parallel with the curr batch GPU computation) + send ith req to next stage + send ith bootstrap req to next stage + send ith transferred req to next stage + send ith proxy to next stage + send current stage's outputs to next stage (can be stashed and delayed to send later) + + the above order can be optimized and reordered to minimize communication-related CPU stall and overhead bubbles. 
==================================================================== There are two additional elements compared to the regular schedule: - 1. Bootstrap Requests: - a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization. - b. The first stage polls the status and propagates the bootstrapped requests down to all other stages. - c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together. + Bootstrap Requests + Release Requests: + - Both can fail locally and require consensus. PP needs to guarantee eventual consistency of local failures and flush malfunctioning requests out as a soft error. - 2. Transferred Requests + Release Requests: - a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage. - b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory. - c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
""" - mbs = [None] * self.pp_size - last_mbs = [None] * self.pp_size + self.pp_loop_size: int = self.pp_size + self.server_args.pp_async_batch_depth + mbs = [None] * self.pp_loop_size + last_mbs = [None] * self.pp_loop_size self.running_mbs = [ - ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ScheduleBatch(reqs=[], batch_is_full=False) + for _ in range(self.pp_loop_size) ] + mb_metadata: List[Optional[PPBatchMetadata]] = [None] * self.pp_loop_size pp_outputs: Optional[PPProxyTensors] = None + last_rank_comm_queue: deque[Tuple[torch.cuda.Event, PPProxyTensors]] = deque() - # Either success or failed - bootstrapped_rids: List[str] = [] + # PD additional + consensus_bootstrapped_rids: Optional[List[str]] = None transferred_rids: List[str] = [] release_rids: Optional[List[str]] = None + tmbs = [None] * self.pp_loop_size + bmbs = [None] * self.pp_loop_size - # transferred microbatch - tmbs = [None] * self.pp_size - - ENABLE_RELEASE = True # For debug + send_req_work = [] + send_bootstrapped_work = [] + send_consensus_bootstrapped_work = [] + send_proxy_work = [] + send_output_work = [] + send_release_work = [] + send_transfer_work = [] while True: server_is_idle = True - - for mb_id in range(self.pp_size): + for mb_id in range(self.pp_loop_size): self.running_batch = self.running_mbs[mb_id] self.last_batch = last_mbs[mb_id] + next_first_rank_mb_id = (mb_id + self.pp_size) % self.pp_loop_size + next_mb_id = (mb_id + 1) % self.pp_loop_size - recv_reqs = self.recv_requests() + next_pp_outputs = None + next_release_rids = None + next_consensus_bootstrapped_rids = None + d2h_event = None + next_batch_result = None + recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) - if self.pp_group.is_first_rank: - # First rank, pop the bootstrap reqs from the bootstrap queue - bootstrapped_reqs, failed_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - return_failed_reqs=True - ) - ) - bootstrapped_rids = [req.rid for req 
in bootstrapped_reqs] + [ - req.rid for req in failed_reqs - ] - self.waiting_queue.extend(bootstrapped_reqs) - else: - # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus - bootstrapped_rids = self.recv_pyobj_from_prev_stage() - bootstrapped_reqs = ( - self.disagg_prefill_bootstrap_queue.pop_bootstrapped( - rids_to_check=bootstrapped_rids - ) - ) - self.waiting_queue.extend(bootstrapped_reqs) - - if self.pp_group.is_first_rank: - transferred_rids = self.get_transferred_rids() - # if other ranks, - else: - # 1. recv previous stage's transferred reqs info - prev_transferred_rids = self.recv_pyobj_from_prev_stage() - # 2. get the current stage's transferred reqs info - curr_transferred_rids = self.get_transferred_rids() - # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids) - transferred_rids = list( - set(prev_transferred_rids) & set(curr_transferred_rids) - ) + if not self.pp_group.is_last_rank: + self._pp_commit_comm_work(send_req_work) + + bootstrapped_rids = self._pp_pd_get_bootstrapped_ids() + bmbs[mb_id] = bootstrapped_rids + self._pp_commit_comm_work(send_bootstrapped_work) + transferred_rids = self._pp_pd_get_transferred_ids() + self._pp_commit_comm_work(send_transfer_work) tmbs[mb_id] = transferred_rids self.process_prefill_chunk() - batch = self.get_new_batch_prefill() if self.require_mlp_sync: batch = self.prepare_mlp_sync_batch(batch) mbs[mb_id] = batch - self.running_mbs[mb_id] = self.running_batch - self.cur_batch = mbs[mb_id] + self.cur_batch: Optional[ScheduleBatch] = mbs[mb_id] if self.cur_batch: server_is_idle = False - result = self.run_batch(self.cur_batch) + pp_proxy_tensors = self._pp_recv_proxy_tensors() - # send the outputs to the next step - if self.pp_group.is_last_rank: - if self.cur_batch: - next_token_ids = result.next_token_ids - pp_outputs = PPProxyTensors( - { - "next_token_ids": next_token_ids, - } - ) - # send the output from the last round to let the next 
stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, + if self.server_args.pp_async_batch_depth > 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, ) - - if ENABLE_RELEASE: - if self.pp_group.is_last_rank: - # At the last stage, all stages has reached the consensus to release memory for transferred_rids - release_rids = transferred_rids - # send to the first rank - self.send_pyobj_to_next_stage(release_rids) - - # receive outputs and post-process (filter finished reqs) the coming microbatch - next_mb_id = (mb_id + 1) % self.pp_size - next_pp_outputs = None - next_release_rids = None - - if mbs[next_mb_id] is not None: - next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.attn_tp_group + ) + self._pp_commit_comm_work(send_proxy_work) + if self.cur_batch: + result, event = self._pp_launch_batch( + mb_id, pp_proxy_tensors, mb_metadata, last_rank_comm_queue + ) + if self.server_args.pp_async_batch_depth == 0: + self._pp_commit_comm_work(work=send_output_work) + next_pp_outputs, next_batch_result, d2h_event, send_output_work = ( + self._pp_send_recv_and_preprocess_output_tensors( + next_first_rank_mb_id, + next_mb_id, + mbs, + mb_metadata, + last_rank_comm_queue, + pp_outputs, ) ) - mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] - output_result = GenerationBatchResult( - logits_output=None, - pp_hidden_states_proxy_tensors=None, - next_token_ids=next_pp_outputs["next_token_ids"], - extend_input_len_per_req=None, - extend_logprob_start_len_per_req=None, - can_run_cuda_graph=result.can_run_cuda_graph, + send_consensus_bootstrapped_work, consensus_bootstrapped_rids = ( + 
self._pp_pd_send_consensus_bootstrapped_ids( + bmbs, + next_first_rank_mb_id, + consensus_bootstrapped_rids, + bootstrapped_rids, ) - self.process_batch_result_disagg_prefill( - mbs[next_mb_id], output_result + ) + send_release_work, release_rids = ( + self._pp_pd_send_consensus_release_ids( + tmbs, next_first_rank_mb_id, release_rids, transferred_rids ) + ) + if bmbs[next_mb_id] is not None: + next_consensus_bootstrapped_rids = ( + self._pp_recv_pyobj_from_prev_stage() + ) + next_consensus_bootstrapped_rids = self.process_bootstrapped_queue( + next_consensus_bootstrapped_rids + ) + self._pp_commit_comm_work(send_consensus_bootstrapped_work) + if tmbs[next_mb_id] is not None: + next_release_rids = self._pp_recv_pyobj_from_prev_stage() + self._pp_commit_comm_work(send_release_work) + # post-process the coming microbatch + if mbs[next_mb_id] is not None: + d2h_event.synchronize() + self._pp_process_batch_result( + mbs[next_mb_id], + next_batch_result, + ) last_mbs[next_mb_id] = mbs[next_mb_id] - if ENABLE_RELEASE: - if tmbs[next_mb_id] is not None: - # recv consensus rids from the previous rank - next_release_rids = self.recv_pyobj_from_prev_stage() - self.process_disagg_prefill_inflight_queue(next_release_rids) - - # carry the outputs to the next stage - if not self.pp_group.is_last_rank: - if pp_outputs: - # send the outputs from the last round to let the next stage worker run post processing - self.pp_group.send_tensor_dict( - pp_outputs.tensors, - all_gather_group=self.attn_tp_group, - ) - if ENABLE_RELEASE: - if release_rids is not None: - self.send_pyobj_to_next_stage(release_rids) - + if tmbs[next_mb_id] is not None: + self.process_disagg_prefill_inflight_queue(next_release_rids) if not self.pp_group.is_last_rank: - # send out reqs to the next stage - self.send_pyobj_to_next_stage(recv_reqs) - self.send_pyobj_to_next_stage(bootstrapped_rids) - self.send_pyobj_to_next_stage(transferred_rids) - - # send out proxy tensors to the next stage + send_req_work = 
self._pp_send_pyobj_to_next_stage( + recv_reqs, async_send=True + ) + send_bootstrapped_work = self._pp_send_pyobj_to_next_stage( + bootstrapped_rids, async_send=True + ) + send_transfer_work = self._pp_send_pyobj_to_next_stage( + transferred_rids, async_send=True + ) if self.cur_batch: - # FIXME(lsyin): remove this assert - assert result.pp_hidden_states_proxy_tensors.tensors is not None - self.pp_group.send_tensor_dict( + torch.cuda.current_stream().wait_event(event) + send_proxy_work = self._pp_send_dict_to_next_stage( result.pp_hidden_states_proxy_tensors.tensors, - all_gather_group=self.attn_tp_group, + async_send=True, ) + if hasattr(self, "delayed_weight_sync_fn"): + self.delayed_weight_sync_fn() + self.delayed_weight_sync_fn = None + pp_outputs = next_pp_outputs release_rids = next_release_rids + consensus_bootstrapped_rids = next_consensus_bootstrapped_rids self.running_batch.batch_is_full = False - if not ENABLE_RELEASE: - if len(self.disagg_prefill_inflight_queue) > 0: - self.process_disagg_prefill_inflight_queue() - # When the server is idle, self-check and re-init some states if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0: self.check_memory() self.check_tree_cache() self.new_token_ratio = self.init_new_token_ratio + self.maybe_sleep_on_idle() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 758f0ffc9571..8853f5ba198d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -370,6 +370,7 @@ def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, forward_batch: Optional[ForwardBatch] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, is_verify: bool = False, skip_attn_backend_init=False, ) -> GenerationBatchResult: @@ -385,14 +386,6 @@ def forward_batch_generation( # FIXME(lsyin): unify the interface of forward_batch assert forward_batch is not None - pp_proxy_tensors = None - if not self.pp_group.is_first_rank: 
- pp_proxy_tensors = PPProxyTensors( - self.pp_group.recv_tensor_dict( - all_gather_group=self.get_attention_tp_group() - ) - ) - if self.pp_group.is_last_rank: if self.is_dllm(): return self._forward_batch_generation_dllm(forward_batch) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 70e03119a355..9cba37440250 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -287,6 +287,7 @@ class ServerArgs: max_queued_requests: Optional[int] = None max_total_tokens: Optional[int] = None chunked_prefill_size: Optional[int] = None + enable_dynamic_chunking: bool = False max_prefill_tokens: int = 16384 schedule_policy: str = "fcfs" enable_priority_scheduling: bool = False @@ -305,6 +306,7 @@ class ServerArgs: tp_size: int = 1 pp_size: int = 1 pp_max_micro_batch_size: Optional[int] = None + pp_async_batch_depth: int = 0 stream_interval: int = 1 stream_output: bool = False random_seed: Optional[int] = None @@ -2385,6 +2387,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.chunked_prefill_size, help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.", ) + parser.add_argument( + "--enable-dynamic-chunking", + action="store_true", + default=ServerArgs.enable_dynamic_chunking, + help="Enable dynamic chunk size adjustment for pipeline parallelism. 
When enabled, chunk sizes are dynamically calculated based on fitted function to maintain consistent execution time across chunks.", + ) parser.add_argument( "--max-prefill-tokens", type=int, @@ -2493,6 +2501,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.pp_max_micro_batch_size, help="The maximum micro batch size in pipeline parallelism.", ) + parser.add_argument( + "--pp-async-batch-depth", + type=int, + default=ServerArgs.pp_async_batch_depth, + help="The async batch depth of pipeline parallelism.", + ) parser.add_argument( "--stream-interval", type=int, diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 5b44d3d19fbf..e032df97337d 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -1262,41 +1262,61 @@ def point_to_point_pyobj( group: Optional[torch.distributed.ProcessGroup] = None, src: int = 0, dst: int = 1, + async_send: bool = False, ): - """Send data from src to dst in group using DeviceToDevice communication.""" - device = torch.get_device_module().current_device() + """Send data from src to dst in group.""" + from sglang.srt.distributed.parallel_state import P2PWork + + if async_send: + send_func = dist.isend + else: + send_func = dist.send if rank == src: + p2p_works = [] if len(data) == 0: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) - dist.send(tensor_size, dst=dst, group=group) + tensor_size = torch.tensor( + [0], + dtype=torch.long, + ) + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) - ).to( - device=device - ) # Move to GPU - tensor_size = torch.tensor([size], dtype=torch.long, device=device) + ) + tensor_size = torch.tensor([size], dtype=torch.long) - dist.send(tensor_size, dst=dst, group=group) - 
dist.send(tensor_data, dst=dst, group=group) - return data + work = send_func(tensor_size, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_size)) + work = send_func(tensor_data, dst, group=group) + if async_send: + p2p_works.append(P2PWork(work, tensor_data)) + return p2p_works elif rank == dst: - tensor_size = torch.tensor([0], dtype=torch.long, device=device) - dist.recv(tensor_size, src=src, group=group) + tensor_size = torch.tensor( + [0], + dtype=torch.long, + ) + work = dist.irecv(tensor_size, src=src, group=group) + work.wait() size = tensor_size.item() if size == 0: return [] - tensor_data = torch.empty(size, dtype=torch.uint8, device=device) - dist.recv(tensor_data, src=src, group=group) + tensor_data = torch.empty( + size, + dtype=torch.uint8, + ) + work = dist.irecv(tensor_data, src=src, group=group) + work.wait() - serialized_data = bytes( - tensor_data.cpu().numpy() - ) # Move back to host for deserialization + serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) return data