Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
cf28d0a
spec_v2: seq_lens through future_map; drop verify_done.wait
hnyls2002 May 21, 2026
3da55d9
future_map: resolve_seq_lens_cpu helper; allocate verify_done in sche…
hnyls2002 May 21, 2026
c6bd952
future: attach done event to FutureIndices, drop EagleDraftInput.veri…
hnyls2002 May 21, 2026
ed61801
future_map: own _last_store_done; drop per-future done
hnyls2002 May 21, 2026
ed42f62
Merge branch 'main' into lsyin/draft-prefix-lens
hnyls2002 May 21, 2026
427a1b4
fence at verify-end via on_verify_complete; split buf writes
hnyls2002 May 21, 2026
3e94adc
store_to_map always writes all fields; fence flag drives record
hnyls2002 May 21, 2026
7905aa3
Merge branch 'main' into lsyin/draft-prefix-lens
hnyls2002 May 21, 2026
7ced612
skip redundant new_seq_lens write when post_verify ran
hnyls2002 May 21, 2026
a915bbc
Merge branch 'main' into lsyin/draft-prefix-lens
hnyls2002 May 21, 2026
fd3153b
trim comments in FutureMap
hnyls2002 May 21, 2026
ab9b4df
uniform on_verify_complete; drop fence flag
hnyls2002 May 21, 2026
af9eb45
rename: _last_store_done -> _post_verify_buf_ready
hnyls2002 May 21, 2026
2783b6f
eager-alloc schedule-consumed bufs; drop first-iter catch-up
hnyls2002 May 21, 2026
7500073
rename future_map.store_* to publish/stash
hnyls2002 May 21, 2026
2964be5
trim init comments to one-liners
hnyls2002 May 21, 2026
e707182
move bonus_tokens_buf alloc into lazy forward-only init
hnyls2002 May 21, 2026
107b788
inline refresh_seq_lens_cpu and event record; rename to publish_ready
hnyls2002 May 21, 2026
c08b14d
gate resolve_seq_lens_cpu by is_spec_v2
hnyls2002 May 21, 2026
8ec9115
fix
hnyls2002 May 21, 2026
4accaab
refine publish/stash docstrings
hnyls2002 May 21, 2026
74a625f
use functools.partial for publish callback
hnyls2002 May 21, 2026
62b9185
trim on_verify_complete callsite comments
hnyls2002 May 21, 2026
3f8eedc
Merge branch 'main' into lsyin/draft-prefix-lens
hnyls2002 May 21, 2026
a8b4b3c
refine seq_lens sentinel comment
hnyls2002 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,8 @@ def process_prebuilt(
from sglang.srt.managers.overlap_utils import FutureIndices

spec_info.future_indices = FutureIndices(indices=self.req_pool_indices)
future_map.store_to_map_for_new_batch(
spec_info.future_indices, spec_info
)
future_map.publish(spec_info.future_indices, spec_info.new_seq_lens)
future_map.stash(spec_info.future_indices, spec_info)
self.spec_info = spec_info
else:
# Non-spec: input_ids feeds the next decode forward directly.
Expand Down
102 changes: 50 additions & 52 deletions python/sglang/srt/managers/overlap_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch

Expand All @@ -10,7 +10,6 @@

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
Expand Down Expand Up @@ -57,21 +56,25 @@ def __init__(
self.req_pool_size = req_to_token_pool.req_to_token.shape[0]

if self.spec_algo.is_none():
self.buf_initialized = True
self.token_ids_buf = torch.empty(
(self.req_pool_size,), dtype=torch.int64, device=self.device
)
else:
self.buf_initialized = False
# Schedule-consumed buf: allocated eagerly with a fixed dtype (int64).
self.new_seq_lens_buf = torch.empty(
(self.req_pool_size,), dtype=torch.int64, device=self.device
)
# Forward-only bufs are lazy (worker-dependent shape).
self._forward_buf_initialized = False

def _lazy_init_buf(self, draft_input: EagleDraftInput):
self.buf_initialized = True
# Fences the schedule-consumed buf fields.
self.publish_ready: Optional[torch.cuda.Event] = None

def _lazy_init_forward_buf(self, draft_input: EagleDraftInput):
self._forward_buf_initialized = True

topk_p0 = draft_input.topk_p[0]
topk_index0 = draft_input.topk_index[0]
bonus_token0 = draft_input.bonus_tokens[0]
new_seq_lens0 = draft_input.new_seq_lens[0]

self.topk_p_buf = torch.empty(
(self.req_pool_size, *topk_p0.shape),
dtype=topk_p0.dtype,
Expand All @@ -83,16 +86,8 @@ def _lazy_init_buf(self, draft_input: EagleDraftInput):
device=self.device,
)
self.bonus_tokens_buf = torch.empty(
(self.req_pool_size, *bonus_token0.shape),
dtype=bonus_token0.dtype,
device=self.device,
)
self.new_seq_lens_buf = torch.empty(
(self.req_pool_size, *new_seq_lens0.shape),
dtype=new_seq_lens0.dtype,
device=self.device,
(self.req_pool_size,), dtype=torch.int64, device=self.device
)

if spec_need_hidden_states():
hidden_states0 = draft_input.hidden_states[0]
self.hidden_states_buf = torch.empty(
Expand All @@ -118,51 +113,54 @@ def resolve_future(self, batch: ScheduleBatch):
draft_input.topk_index = self.topk_index_buf[indices]
draft_input.bonus_tokens = self.bonus_tokens_buf[indices]
draft_input.new_seq_lens = self.new_seq_lens_buf[indices]
# Resolve seq_lens placeholder (-indices) to the post-verify view.
batch.seq_lens = draft_input.new_seq_lens
if spec_need_hidden_states():
draft_input.hidden_states = self.hidden_states_buf[indices]

def store_to_map(
self, future_indices: FutureIndices, batch_result: GenerationBatchResult
):
def resolve_seq_lens_cpu(self, batch: ScheduleBatch) -> None:
fi = batch.spec_info.future_indices if batch.spec_info is not None else None
if fi is None:
return
if self.publish_ready is not None:
self.publish_ready.wait()
batch.seq_lens_cpu = self.new_seq_lens_buf[fi.indices].cpu()
batch.seq_lens_sum = int(batch.seq_lens_cpu.sum())

def publish(
self, future_indices: FutureIndices, new_seq_lens: torch.Tensor
) -> None:
"""Store schedule-consumed fields and signal publish_ready."""
if self.spec_algo.is_none():
indices = future_indices.indices
if indices.shape[0] == 0:
# DP attention idle rank: indices is empty but next_token_ids
# may carry padded values from sibling ranks. Nothing to store
# for this rank.
return
# next_token_ids is int32; buf is int64. Slice assignment used to
# cast implicitly, but advanced indexing requires an explicit match.
self.token_ids_buf[indices] = batch_result.next_token_ids.to(torch.int64)
else:
draft_input: EagleDraftInput = batch_result.next_draft_input
self.store_to_map_for_new_batch(future_indices, draft_input)

def store_to_map_for_new_batch(
self, future_indices: FutureIndices, draft_input: EagleDraftInput
):
return
indices = future_indices.indices
if indices.shape[0] == 0:
return # DP idle
self.new_seq_lens_buf[indices] = new_seq_lens.to(self.new_seq_lens_buf.dtype)
if self.publish_ready is None:
self.publish_ready = torch.get_device_module(self.device).Event()
self.publish_ready.record()

def stash(self, future_indices: FutureIndices, payload) -> None:
"""Store forward-only fields for the next forward batch to pick up."""
indices = future_indices.indices
if indices.shape[0] == 0:
# DP idle rank: draft_input fields are empty stubs without a usable
# shape, so _lazy_init_buf's shape peek (draft_input.topk_p[0])
# would IndexError. Defer init until a real batch arrives.
return # DP idle
if self.spec_algo.is_none():
# next_token_ids is int32; buf is int64. Advanced indexing requires
# an explicit cast.
self.token_ids_buf[indices] = payload.to(torch.int64)
return

if not self.buf_initialized:
self._lazy_init_buf(draft_input)

# Slice assignment used to coerce src dtype to buf dtype implicitly;
# advanced index requires an explicit cast. bonus_tokens / new_seq_lens
# in particular differ across disagg (int64) and forward (int32) paths.
self.topk_p_buf[indices] = draft_input.topk_p.to(self.topk_p_buf.dtype)
self.topk_index_buf[indices] = draft_input.topk_index.to(
self.topk_index_buf.dtype
)
draft_input: EagleDraftInput = payload
if not self._forward_buf_initialized:
self._lazy_init_forward_buf(draft_input)
self.bonus_tokens_buf[indices] = draft_input.bonus_tokens.to(
self.bonus_tokens_buf.dtype
)
self.new_seq_lens_buf[indices] = draft_input.new_seq_lens.to(
self.new_seq_lens_buf.dtype
self.topk_p_buf[indices] = draft_input.topk_p.to(self.topk_p_buf.dtype)
self.topk_index_buf[indices] = draft_input.topk_index.to(
self.topk_index_buf.dtype
)
if spec_need_hidden_states():
self.hidden_states_buf[indices] = draft_input.hidden_states.to(
Expand Down
41 changes: 3 additions & 38 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2420,8 +2420,7 @@ def prepare_for_decode(self):
self.seq_lens.add_(1)
self.seq_lens_cpu.add_(1)
self.orig_seq_lens.add_(1)
# Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py
# or lazily in ForwardBatch.init_new).
# Sum is recomputed lazily by ForwardBatch.init_new.
self.seq_lens_sum = None

if self.hisparse_coordinator is not None:
Expand All @@ -2447,36 +2446,13 @@ def prepare_for_decode(self):
.to(device=self.device, non_blocking=True)
)

def maybe_wait_verify_done(self):
# Use event.wait() (stream-level wait) instead of .synchronize()
# (CPU block). Schedule-stream prep ops following this call get
# ordered after the forward-stream verify via the wait; CPU is not
# blocked. Subsequent .cpu()/.item() naturally sync the stream.
if self.is_spec_v2:
draft_input: EagleDraftInput = self.spec_info
if draft_input.verify_done is not None:
draft_input.verify_done.wait()

def refresh_seq_lens_cpu(self, sync: bool = True):
# sync=True: D2H from seq_lens (needed when seq_lens_cpu is stale
# relative to seq_lens, i.e. spec v2's mid-forward GPU rebind).
# sync=False: caller asserts seq_lens_cpu already fresh — skip D2H,
# only recompute the cached sum.
if sync and self.is_spec_v2:
self.seq_lens_cpu = self.seq_lens.cpu()
self.seq_lens_sum = int(self.seq_lens_cpu.sum())

def filter_batch(
self,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
keep_indices: Optional[List[int]] = None,
# FIXME(lsyin): deprecate this API after spec v1 is deprecated
v1_spec_info_filtered: Optional[bool] = False,
):
# FIXME(lsyin): used here to get the correct seq_lens
# The batch has been launched but we need it verified to get correct next batch info
self.maybe_wait_verify_done()

if keep_indices is None:
if isinstance(chunked_req_to_exclude, Req):
chunked_req_to_exclude = [chunked_req_to_exclude]
Expand Down Expand Up @@ -2516,8 +2492,7 @@ def filter_batch(
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
# Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py
# or lazily in ForwardBatch.init_new).
# Sum is recomputed lazily by ForwardBatch.init_new.
self.seq_lens_sum = None

if self.input_ids is not None:
Expand Down Expand Up @@ -2553,15 +2528,6 @@ def filter_batch(
)

def merge_batch(self, other: "ScheduleBatch"):
# In the regular scheduler path:
# 1) self is always prefill, whose seq_lens is not a future
# 2) other is always decode, which is finished in previous step
# so verify_done is already synced and this is a no-op.
# In disagg decode + overlap, merge_batch can be called before
# filter_batch, so running_batch.seq_lens may still be a forward_stream
# future. Synchronize here to avoid a cross-stream data race.
self.maybe_wait_verify_done()

# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
# orchestrator.merge() depends on Batch.reqs during preparation of each penalizers, so it
# needs to be called with pre-merged Batch.reqs.
Expand All @@ -2578,8 +2544,7 @@ def merge_batch(self, other: "ScheduleBatch"):
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
self.out_cache_loc = None
# Defer compute to refresh_seq_lens_cpu (either pre-forward in scheduler.py
# or lazily in ForwardBatch.init_new).
# Sum is recomputed lazily by ForwardBatch.init_new.
self.seq_lens_sum = None
if self.input_ids is not None:
self.input_ids = torch.cat([self.input_ids, other.input_ids])
Expand Down
49 changes: 36 additions & 13 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import time
from collections import deque
from contextlib import contextmanager, nullcontext
from functools import partial
from http import HTTPStatus
from typing import Any, Deque, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -2834,18 +2835,36 @@ def run_batch(
# Run forward
if self.is_generation:
if self.enable_overlap:
# Refresh BEFORE _overlap_forward_isolation so snapshot
# captures fresh values and restore preserves them.
batch.refresh_seq_lens_cpu()
# Spec v2 pre-isolation CPU mirror prep: D2H new_seq_lens_buf
# into batch.seq_lens_cpu + set seq_lens_sum. For non-spec_v2,
# ForwardBatch.init_new lazily computes the sum.
if batch.is_spec_v2:
# FIXME: make this optional for different backends.
self.future_map.resolve_seq_lens_cpu(batch)

with self._overlap_forward_isolation(batch):
future_indices = FutureIndices(indices=batch.req_pool_indices)

# The spec_v2 worker fires on_verify_complete between sample-end and
# draft_extend; binding it to publish moves the fence to verify-end so
# schedule prep can overlap with draft_extend.
fwd_kwargs = (
{
"on_verify_complete": partial(
self.future_map.publish, future_indices
)
}
if batch.is_spec_v2
else {}
)

with self.forward_stream_ctx:
self.forward_stream.wait_stream(self.schedule_stream)
self.future_map.resolve_future(batch)
# FIXME: pp is not compatible with overlap
batch_result = self.model_worker.forward_batch_generation(batch)
batch_result = self.model_worker.forward_batch_generation(
batch, **fwd_kwargs
)
# Park any refs the worker wants kept alive for 2 iters
# (cross-stream tensor lifetime; pinned in the same
# ring slot as the SB attr snapshot).
Expand All @@ -2856,7 +2875,12 @@ def run_batch(
# FIXME(lsyin): maybe move this to forward_batch_generation
batch_result.copy_done = self.device_module.Event()
if batch_result.delay_sample_func is None:
self.future_map.store_to_map(future_indices, batch_result)
stash_payload = (
batch_result.next_draft_input
if batch.is_spec_v2
else batch_result.next_token_ids
)
self.future_map.stash(future_indices, stash_payload)
batch_result.copy_to_cpu(
return_logprob=batch.return_logprob,
return_hidden_states=batch.return_hidden_states,
Expand All @@ -2869,15 +2893,11 @@ def run_batch(
batch.input_ids = -future_indices.indices

if batch.is_spec_v2:
# FIXME(lsyin): tmp code for spec v2
# We only keep future indices for next draft input

batch.spec_info = batch_result.next_draft_input
batch.spec_info.future_indices = future_indices

# The future value, usually for next batch preparation
# Current implementation strictly synchronizes the seq_lens
batch.seq_lens = batch_result.next_draft_input.new_seq_lens
# Schedule-stream sentinel between iters; next iter's
# resolve_future reassigns batch.seq_lens from new_seq_lens_buf.
batch.seq_lens = -future_indices.indices
elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
batch_result = self.tp_worker.forward_batch_split_prefill(batch)
if isinstance(batch_result.next_token_ids, torch.Tensor):
Expand Down Expand Up @@ -2961,7 +2981,10 @@ def launch_batch_sample_if_needed(
self.forward_stream.wait_stream(self.schedule_stream)
_batch_result = batch_result.delay_sample_func()
assert _batch_result is batch_result
self.future_map.store_to_map(batch_result.future_indices, batch_result)
# Delay-sample is non-spec only; stash takes next_token_ids tensor.
self.future_map.stash(
batch_result.future_indices, batch_result.next_token_ids
)
batch_result.copy_to_cpu(
return_logprob=self.cur_batch.return_logprob,
return_hidden_states=self.cur_batch.return_hidden_states,
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def init_new(
seq_lens_cpu = batch.seq_lens_cpu

if batch.seq_lens_sum is None:
batch.refresh_seq_lens_cpu(sync=False)
batch.seq_lens_sum = int(batch.seq_lens_cpu.sum())

ret = cls(
forward_mode=batch.forward_mode,
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,6 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
# V2 overlap worker only
future_indices: Optional[FutureIndices] = None
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
# V2 reuses `EagleDraftInput` across phases (V1 has a separate
# `EagleDraftExtendInput` for these). Set during V2's draft-extend.
num_correct_drafts: Optional[torch.Tensor] = None
Expand Down
7 changes: 1 addition & 6 deletions python/sglang/srt/speculative/eagle_info_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@ def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):

bs = batch.batch_size()

# Now seq_lens is correct
batch.maybe_wait_verify_done()

# Accumulate penalty
# This is a relaxed version of penalties for speculative decoding.
if batch.sampling_info.penalizer_orchestrator.is_required:
Expand Down Expand Up @@ -231,9 +228,7 @@ def prepare_for_extend_to_fill_draft_kvcache(
batch.input_ids = predict
batch.seq_lens = batch.seq_lens + num_draft_tokens
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
# seq_lens_cpu was just CPU-updated in tandem — sync=False avoids
# a redundant D2H on the draft hot path.
batch.refresh_seq_lens_cpu(sync=False)
batch.seq_lens_sum = int(batch.seq_lens_cpu.sum())
batch.extend_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
batch.prefix_lens = seq_lens_cpu_.tolist()
batch.extend_num_tokens = extend_num_tokens
Expand Down
Loading
Loading