diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index d6db74ec0d9e..76cc76e9354a 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -180,9 +180,8 @@ def process_prebuilt( from sglang.srt.managers.overlap_utils import FutureIndices spec_info.future_indices = FutureIndices(indices=self.req_pool_indices) - future_map.store_to_map_for_new_batch( - spec_info.future_indices, spec_info - ) + future_map.publish(spec_info.future_indices, spec_info.new_seq_lens) + future_map.stash(spec_info.future_indices, spec_info) self.spec_info = spec_info else: # Non-spec: input_ids feeds the next decode forward directly. diff --git a/python/sglang/srt/managers/overlap_utils.py b/python/sglang/srt/managers/overlap_utils.py index 795f01b75bae..99b951fab629 100644 --- a/python/sglang/srt/managers/overlap_utils.py +++ b/python/sglang/srt/managers/overlap_utils.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -10,7 +10,6 @@ if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch - from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -57,21 +56,25 @@ def __init__( self.req_pool_size = req_to_token_pool.req_to_token.shape[0] if self.spec_algo.is_none(): - self.buf_initialized = True self.token_ids_buf = torch.empty( (self.req_pool_size,), dtype=torch.int64, device=self.device ) else: - self.buf_initialized = False + # Schedule-consumed buf, eager fixed dtype. + self.new_seq_lens_buf = torch.empty( + (self.req_pool_size,), dtype=torch.int64, device=self.device + ) + # Forward-only bufs are lazy (worker-dependent shape). + self._forward_buf_initialized = False - def _lazy_init_buf(self, draft_input: EagleDraftInput): - self.buf_initialized = True + # Fences the schedule-consumed buf fields. + self.publish_ready: Optional[torch.cuda.Event] = None + + def _lazy_init_forward_buf(self, draft_input: EagleDraftInput): + self._forward_buf_initialized = True topk_p0 = draft_input.topk_p[0] topk_index0 = draft_input.topk_index[0] - bonus_token0 = draft_input.bonus_tokens[0] - new_seq_lens0 = draft_input.new_seq_lens[0] - self.topk_p_buf = torch.empty( (self.req_pool_size, *topk_p0.shape), dtype=topk_p0.dtype, @@ -83,16 +86,8 @@ def _lazy_init_buf(self, draft_input: EagleDraftInput): device=self.device, ) self.bonus_tokens_buf = torch.empty( - (self.req_pool_size, *bonus_token0.shape), - dtype=bonus_token0.dtype, - device=self.device, - ) - self.new_seq_lens_buf = torch.empty( - (self.req_pool_size, *new_seq_lens0.shape), - dtype=new_seq_lens0.dtype, - device=self.device, + (self.req_pool_size,), dtype=torch.int64, device=self.device ) - if spec_need_hidden_states(): hidden_states0 = draft_input.hidden_states[0] self.hidden_states_buf = torch.empty( @@ -118,51 +113,54 @@ def resolve_future(self, batch: ScheduleBatch): draft_input.topk_index = self.topk_index_buf[indices] draft_input.bonus_tokens = self.bonus_tokens_buf[indices] draft_input.new_seq_lens = self.new_seq_lens_buf[indices] + # Resolve seq_lens placeholder (-indices) to the post-verify view. + batch.seq_lens = draft_input.new_seq_lens if spec_need_hidden_states(): draft_input.hidden_states = self.hidden_states_buf[indices] - def store_to_map( - self, future_indices: FutureIndices, batch_result: GenerationBatchResult - ): + def resolve_seq_lens_cpu(self, batch: ScheduleBatch) -> None: + fi = batch.spec_info.future_indices if batch.spec_info is not None else None + if fi is None: + return + if self.publish_ready is not None: + self.publish_ready.wait() + batch.seq_lens_cpu = self.new_seq_lens_buf[fi.indices].cpu() + batch.seq_lens_sum = int(batch.seq_lens_cpu.sum()) + + def publish( + self, future_indices: FutureIndices, new_seq_lens: torch.Tensor + ) -> None: + """Store schedule-consumed fields and signal publish_ready.""" if self.spec_algo.is_none(): - indices = future_indices.indices - if indices.shape[0] == 0: - # DP attention idle rank: indices is empty but next_token_ids - # may carry padded values from sibling ranks. Nothing to store - # for this rank. - return - # next_token_ids is int32; buf is int64. Slice assignment used to - # cast implicitly, but advanced indexing requires an explicit match. - self.token_ids_buf[indices] = batch_result.next_token_ids.to(torch.int64) - else: - draft_input: EagleDraftInput = batch_result.next_draft_input - self.store_to_map_for_new_batch(future_indices, draft_input) - - def store_to_map_for_new_batch( - self, future_indices: FutureIndices, draft_input: EagleDraftInput - ): + return + indices = future_indices.indices + if indices.shape[0] == 0: + return # DP idle + self.new_seq_lens_buf[indices] = new_seq_lens.to(self.new_seq_lens_buf.dtype) + if self.publish_ready is None: + self.publish_ready = torch.get_device_module(self.device).Event() + self.publish_ready.record() + + def stash(self, future_indices: FutureIndices, payload) -> None: + """Store forward-only fields for the next forward batch to pick up.""" indices = future_indices.indices if indices.shape[0] == 0: - # DP idle rank: draft_input fields are empty stubs without a usable - # shape, so _lazy_init_buf's shape peek (draft_input.topk_p[0]) - # would IndexError. Defer init until a real batch arrives. + return # DP idle + if self.spec_algo.is_none(): + # next_token_ids is int32; buf is int64. Advanced indexing requires + # an explicit cast. + self.token_ids_buf[indices] = payload.to(torch.int64) return - if not self.buf_initialized: - self._lazy_init_buf(draft_input) - - # Slice assignment used to coerce src dtype to buf dtype implicitly; - # advanced index requires an explicit cast. bonus_tokens / new_seq_lens - # in particular differ across disagg (int64) and forward (int32) paths. - self.topk_p_buf[indices] = draft_input.topk_p.to(self.topk_p_buf.dtype) - self.topk_index_buf[indices] = draft_input.topk_index.to( - self.topk_index_buf.dtype - ) + draft_input: EagleDraftInput = payload + if not self._forward_buf_initialized: + self._lazy_init_forward_buf(draft_input) self.bonus_tokens_buf[indices] = draft_input.bonus_tokens.to( self.bonus_tokens_buf.dtype ) - self.new_seq_lens_buf[indices] = draft_input.new_seq_lens.to( - self.new_seq_lens_buf.dtype + self.topk_p_buf[indices] = draft_input.topk_p.to(self.topk_p_buf.dtype) + self.topk_index_buf[indices] = draft_input.topk_index.to( + self.topk_index_buf.dtype ) if spec_need_hidden_states(): self.hidden_states_buf[indices] = draft_input.hidden_states.to( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cb616a0396f3..b54e16f7e118 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2420,8 +2420,7 @@ def prepare_for_decode(self): self.seq_lens.add_(1) self.seq_lens_cpu.add_(1) self.orig_seq_lens.add_(1) - # Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py - # or lazily in ForwardBatch.init_new). + # Sum is recomputed lazily by ForwardBatch.init_new. self.seq_lens_sum = None if self.hisparse_coordinator is not None: @@ -2447,25 +2446,6 @@ def prepare_for_decode(self): .to(device=self.device, non_blocking=True) ) - def maybe_wait_verify_done(self): - # Use event.wait() (stream-level wait) instead of .synchronize() - # (CPU block). Schedule-stream prep ops following this call get - # ordered after the forward-stream verify via the wait; CPU is not - # blocked. Subsequent .cpu()/.item() naturally sync the stream. - if self.is_spec_v2: - draft_input: EagleDraftInput = self.spec_info - if draft_input.verify_done is not None: - draft_input.verify_done.wait() - - def refresh_seq_lens_cpu(self, sync: bool = True): - # sync=True: D2H from seq_lens (needed when seq_lens_cpu is stale - # relative to seq_lens, i.e. spec v2's mid-forward GPU rebind). - # sync=False: caller asserts seq_lens_cpu already fresh — skip D2H, - # only recompute the cached sum. - if sync and self.is_spec_v2: - self.seq_lens_cpu = self.seq_lens.cpu() - self.seq_lens_sum = int(self.seq_lens_cpu.sum()) - def filter_batch( self, chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, @@ -2473,10 +2453,6 @@ def filter_batch( # FIXME(lsyin): deprecate this API after spec v1 is deprecated v1_spec_info_filtered: Optional[bool] = False, ): - # FIXME(lsyin): used here to get the correct seq_lens - # The batch has been launched but we need it verified to get correct next batch info - self.maybe_wait_verify_done() - if keep_indices is None: if isinstance(chunked_req_to_exclude, Req): chunked_req_to_exclude = [chunked_req_to_exclude] @@ -2516,8 +2492,7 @@ def filter_batch( self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None - # Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py - # or lazily in ForwardBatch.init_new). + # Sum is recomputed lazily by ForwardBatch.init_new. self.seq_lens_sum = None if self.input_ids is not None: @@ -2553,15 +2528,6 @@ def filter_batch( ) def merge_batch(self, other: "ScheduleBatch"): - # In the regular scheduler path: - # 1) self is always prefill, whose seq_lens is not a future - # 2) other is always decode, which is finished in previous step - # so verify_done is already synced and this is a no-op. - # In disagg decode + overlap, merge_batch can be called before - # filter_batch, so running_batch.seq_lens may still be a forward_stream - # future. Synchronize here to avoid a cross-stream data race. - self.maybe_wait_verify_done() - # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it # needs to be called with pre-merged Batch.reqs. @@ -2578,8 +2544,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu]) self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens]) self.out_cache_loc = None - # Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py - # or lazily in ForwardBatch.init_new). + # Sum is recomputed lazily by ForwardBatch.init_new. self.seq_lens_sum = None if self.input_ids is not None: self.input_ids = torch.cat([self.input_ids, other.input_ids]) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8a08f03dd01d..c4d289826854 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -22,6 +22,7 @@ import time from collections import deque from contextlib import contextmanager, nullcontext +from functools import partial from http import HTTPStatus from typing import Any, Deque, Dict, List, Optional, Tuple, Union @@ -2834,18 +2835,36 @@ def run_batch( # Run forward if self.is_generation: if self.enable_overlap: - # Refresh BEFORE _overlap_forward_isolation so snapshot - # captures fresh values and restore preserves them. - batch.refresh_seq_lens_cpu() + # Spec v2 pre-isolation CPU mirror prep: D2H new_seq_lens_buf + # into batch.seq_lens_cpu + set seq_lens_sum. For non-spec_v2, + # ForwardBatch.init_new lazily computes the sum. + if batch.is_spec_v2: + # FIXME: make this optional to different backends. + self.future_map.resolve_seq_lens_cpu(batch) with self._overlap_forward_isolation(batch): future_indices = FutureIndices(indices=batch.req_pool_indices) + # Spec_v2 worker fires this between sample-end and + # draft_extend; publish moves the fence to verify-end so + # schedule prep can overlap with draft_extend. + fwd_kwargs = ( + { + "on_verify_complete": partial( + self.future_map.publish, future_indices + ) + } + if batch.is_spec_v2 + else {} + ) + with self.forward_stream_ctx: self.forward_stream.wait_stream(self.schedule_stream) self.future_map.resolve_future(batch) # FIXME: pp is not compatible with overlap - batch_result = self.model_worker.forward_batch_generation(batch) + batch_result = self.model_worker.forward_batch_generation( + batch, **fwd_kwargs + ) # Park any refs the worker wants kept alive 2 iters # (cross-stream tensor lifetime; pinned in the same # ring slot as the SB attr snapshot). @@ -2856,7 +2875,12 @@ def run_batch( # FIXME(lsyin): maybe move this to forward_batch_generation batch_result.copy_done = self.device_module.Event() if batch_result.delay_sample_func is None: - self.future_map.store_to_map(future_indices, batch_result) + stash_payload = ( + batch_result.next_draft_input + if batch.is_spec_v2 + else batch_result.next_token_ids + ) + self.future_map.stash(future_indices, stash_payload) batch_result.copy_to_cpu( return_logprob=batch.return_logprob, return_hidden_states=batch.return_hidden_states, @@ -2869,15 +2893,11 @@ def run_batch( batch.input_ids = -future_indices.indices if batch.is_spec_v2: - # FIXME(lsyin): tmp code for spec v2 - # We only keep future indices for next draft input - batch.spec_info = batch_result.next_draft_input batch.spec_info.future_indices = future_indices - - # The future value, usually for next batch preparation - # Current implementation strictly synchronizes the seq_lens - batch.seq_lens = batch_result.next_draft_input.new_seq_lens + # Schedule-stream sentinel between iters; next iter's + # resolve_future reassigns batch.seq_lens from new_seq_lens_buf. + batch.seq_lens = -future_indices.indices elif self.enable_pdmux and batch.forward_mode.is_split_prefill(): batch_result = self.tp_worker.forward_batch_split_prefill(batch) if isinstance(batch_result.next_token_ids, torch.Tensor): @@ -2961,7 +2981,10 @@ def launch_batch_sample_if_needed( self.forward_stream.wait_stream(self.schedule_stream) _batch_result = batch_result.delay_sample_func() assert _batch_result is batch_result - self.future_map.store_to_map(batch_result.future_indices, batch_result) + # Delay-sample is non-spec only; stash takes next_token_ids tensor. + self.future_map.stash( + batch_result.future_indices, batch_result.next_token_ids + ) batch_result.copy_to_cpu( return_logprob=self.cur_batch.return_logprob, return_hidden_states=self.cur_batch.return_hidden_states, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a3e86dc5d973..89f88c235acf 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -502,7 +502,7 @@ def init_new( seq_lens_cpu = batch.seq_lens_cpu if batch.seq_lens_sum is None: - batch.refresh_seq_lens_cpu(sync=False) + batch.seq_lens_sum = int(batch.seq_lens_cpu.sum()) ret = cls( forward_mode=batch.forward_mode, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index f7d767da193f..bcdeaf0693c1 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -696,7 +696,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): # V2 overlap worker only future_indices: Optional[FutureIndices] = None new_seq_lens: Optional[torch.Tensor] = None - verify_done: Optional[torch.cuda.Event] = None # V2 reuses `EagleDraftInput` across phases (V1 has a separate # `EagleDraftExtendInput` for these). Set during V2's draft-extend. num_correct_drafts: Optional[torch.Tensor] = None diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index fb445177ba9b..52d4ad4aabe5 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -96,9 +96,6 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch): bs = batch.batch_size() - # Now seq_lens is correct - batch.maybe_wait_verify_done() - # Accumulate penalty # This is a relaxed version of penalties for speculative decoding. if batch.sampling_info.penalizer_orchestrator.is_required: @@ -231,9 +228,7 @@ def prepare_for_extend_to_fill_draft_kvcache( batch.input_ids = predict batch.seq_lens = batch.seq_lens + num_draft_tokens batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens - # seq_lens_cpu was just CPU-updated in tandem — sync=False avoids - # a redundant D2H on the draft hot path. - batch.refresh_seq_lens_cpu(sync=False) + batch.seq_lens_sum = int(batch.seq_lens_cpu.sum()) batch.extend_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))] batch.prefix_lens = seq_lens_cpu_.tolist() batch.extend_num_tokens = extend_num_tokens diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 57ac6d078965..6e81c5c25ee4 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -756,7 +756,7 @@ def clear_cache_pool(self): # allocator and kv cache pool are shared with target worker, which are cleared in scheduler pass - def forward_batch_generation(self, batch: ScheduleBatch): + def forward_batch_generation(self, batch: ScheduleBatch, on_verify_complete=None): if batch.forward_mode.is_extend() or batch.is_extend_in_batch: # Target prefill target_capture_mode = ( @@ -767,6 +767,10 @@ def forward_batch_generation(self, batch: ScheduleBatch): batch.capture_hidden_mode = target_capture_mode batch_output = self.target_worker.forward_batch_generation(batch) + # Publish before draft_extend so the fence is at target-end. + if on_verify_complete is not None: + on_verify_complete(batch.seq_lens) + # Draft prefill with ( self.draft_worker.draft_tp_context( @@ -809,6 +813,9 @@ def forward_batch_generation(self, batch: ScheduleBatch): assert verify_input.is_verify_input() batch.spec_info = verify_input batch_output = self.verify(batch) + # Publish before draft_extend so the fence is at verify-end. + if on_verify_complete is not None: + on_verify_complete(batch_output.next_draft_input.new_seq_lens) with ( self.draft_worker.draft_tp_context( self.draft_worker.draft_runner.tp_group @@ -1068,9 +1075,6 @@ def verify(self, batch: ScheduleBatch): batch, verify_input, accept_lens, accept_index, bs ) - verify_done = torch.get_device_module(self.device).Event() - verify_done.record() - if not batch.forward_mode.is_idle(): accept_tokens = predict[accept_index] bonus_tokens = torch.empty_like(accept_lens, dtype=torch.int32) @@ -1089,15 +1093,12 @@ def verify(self, batch: ScheduleBatch): ) next_draft_input = EagleDraftInput( - bonus_tokens=bonus_tokens, - new_seq_lens=new_seq_lens, - verify_done=verify_done, + bonus_tokens=bonus_tokens, new_seq_lens=new_seq_lens ) # verify_forward_batch transitively holds verify-time GPU tensors # (draft_token / out_cache_loc / ...) that must outlive the imminent - # batch.input_ids rebind in prepare_for_extend_to_fill_draft_kvcache, - # until the next iter's verify_done.synchronize() in filter_batch. + # batch.input_ids rebind in prepare_for_extend_to_fill_draft_kvcache. # Scheduler pins it in batch_record_buf for the 2-iter window. return GenerationBatchResult( logits_output=logits_output, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py index cb341318a422..35f189c05c5a 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker_v2.py @@ -669,7 +669,7 @@ def clear_cache_pool(self): # allocator and kv cache pool are shared with target worker, which are cleared in scheduler pass - def forward_batch_generation(self, batch: ScheduleBatch): + def forward_batch_generation(self, batch: ScheduleBatch, on_verify_complete=None): if batch.forward_mode.is_extend() or batch.is_extend_in_batch: # Target prefill target_capture_mode = ( @@ -680,6 +680,10 @@ def forward_batch_generation(self, batch: ScheduleBatch): batch.capture_hidden_mode = target_capture_mode batch_output = self.target_worker.forward_batch_generation(batch) + # Publish before draft_extend so the fence is at target-end. + if on_verify_complete is not None: + on_verify_complete(batch.seq_lens) + # Chain-style MTP needs FULL to get all-token hidden states; # non-chain only needs LAST (the target model's hidden states). batch_output.next_draft_input = self.draft_worker._draft_extend_for_prefill( @@ -706,6 +710,9 @@ def forward_batch_generation(self, batch: ScheduleBatch): assert verify_input.is_verify_input() batch.spec_info = verify_input batch_output = self.verify(batch) + # Publish before draft_extend so the fence is at verify-end. + if on_verify_complete is not None: + on_verify_complete(batch_output.next_draft_input.new_seq_lens) self.draft_worker._draft_extend_for_decode(batch, batch_output) return batch_output @@ -767,8 +774,6 @@ def verify( accept_index, ) = verify_input.sample(batch, logits_output) new_seq_lens = batch.seq_lens + accept_lens - verify_done = torch.get_device_module(self.device).Event() - verify_done.record() if not batch.forward_mode.is_idle(): accept_tokens = predict[accept_index] @@ -790,7 +795,6 @@ def verify( next_draft_input = EagleDraftInput( bonus_tokens=bonus_tokens, new_seq_lens=new_seq_lens, - verify_done=verify_done, ) # verify_forward_batch transitively holds verify-time GPU tensors that # must outlive the imminent batch.input_ids rebind; scheduler pins it