Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
db9ef5e
[scheduler] scheduler overlap
liupeng374 Dec 31, 2025
843c48e
add ENABLE_SPECULATIVE_OVERLAP_REFLOW and run success
litmei Jan 6, 2026
606e622
extract draft_v2 function
litmei Jan 8, 2026
d58c18e
extract draft_v2 function
litmei Jan 8, 2026
0fbe1e9
run draft_v2 with acl graph
litmei Jan 12, 2026
86fc8f3
extracting an individual method with design logic
litmei Jan 15, 2026
ae9b07f
extracting an individual method with design logic
litmei Jan 15, 2026
828cf17
extracting an individual method with design logic
litmei Jan 15, 2026
dcf1e90
compatible with the original logic
litmei Jan 15, 2026
7ece9a4
compatible with the original logic
litmei Jan 15, 2026
546f941
update
litmei Jan 15, 2026
c11e9d1
use F.pad repeat cat
litmei Jan 15, 2026
ab17246
update capture branch info
litmei Jan 15, 2026
a1f576a
fix self.speculative_num_steps == 1 error
litmei Jan 15, 2026
ab48727
use wrapper to simplify draft
litmei Jan 15, 2026
8883b2f
reuse draft_forward on prepare_verify_reflow
litmei Jan 15, 2026
d4f7260
remove save hidden_states todo
litmei Jan 16, 2026
7e1d116
add is_draft_v2 judgement
litmei Jan 16, 2026
1a279f7
add envs parm
litmei Jan 16, 2026
8a16507
Merge remote-tracking branch 'official/main' into mtp_scheduler_overlap
litmei Jan 16, 2026
b1f2400
Merge remote-tracking branch 'official/main' into mtp_scheduler_overlap
litmei Jan 19, 2026
d59463d
fix tp_worker.get_tp_group().cpu_group
litmei Jan 19, 2026
1a70944
pre-commit
litmei Jan 19, 2026
16e886a
add assert
litmei Jan 19, 2026
24f671d
add assert
litmei Jan 19, 2026
66422ba
Merge remote-tracking branch 'official/main' into mtp_scheduler_overlap
litmei Jan 28, 2026
801482f
fix precision
litmei Feb 4, 2026
abe6ed8
Merge branch 'mtp_scheduler_overlap' into mtp_scheduler_overlap_self
litmei Feb 27, 2026
bc93650
update reflow: prepare does not go through init
litmei Mar 4, 2026
608af80
pd update
litmei Mar 4, 2026
a5eb6b9
hang
litmei Mar 12, 2026
2e2ce12
Merge branch 'mtp_scheduler_overlap_self' into mtp_scheduler_overlap_…
litmei Mar 12, 2026
1977214
Merge branch 'mtp_scheduler_overlap_self_2' into mtp_v2
litmei Mar 27, 2026
87e5864
update name
litmei Mar 28, 2026
95bb24f
fix low acc lens
litmei Mar 28, 2026
402c6a9
rm seq_lens_cpu sync on draft_v2
litmei Mar 30, 2026
a3a1b66
rename feature name
litmei Mar 30, 2026
9bd6e74
update comment
litmei Mar 30, 2026
1c9d43c
rename func
litmei Mar 30, 2026
2adee15
add test case
litmei Apr 2, 2026
985d796
comments fix
litmei Apr 2, 2026
94a32a3
Merge branch 'main' into mtp_v2_push
litmei Apr 7, 2026
0e74f9c
fix test case
litmei Apr 8, 2026
27d442d
fix xeon platform
litmei Apr 7, 2026
39cc874
update is_pin_memory_available
litmei Apr 8, 2026
f0661f0
tiny update
litmei Apr 8, 2026
1a1c836
Merge branch 'main' into mtp_v2_push
litmei Apr 10, 2026
4da0ca2
fix pin_memory not work
litmei Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from sglang.srt.environ import envs
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

Expand Down Expand Up @@ -162,6 +164,16 @@ def process_prebuilt(
hidden_states_list = [req.hidden_states_tensor for req in self.reqs]
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)

enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()

if enable_spec_v2_zero_bubble and server_args.speculative_num_steps > 1:
topk_pad_size = (
server_args.speculative_num_steps * num_states - topk_p.shape[-1]
)

topk_p = F.pad(topk_p, (0, topk_pad_size))
topk_index = F.pad(topk_index, (0, topk_pad_size))

# local import to avoid circular import
from sglang.srt.speculative.eagle_info import EagleDraftInput

Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ class Envs:
SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK = EnvBool(True)
SGLANG_SPEC_NAN_DETECTION = EnvBool(False)
SGLANG_SPEC_OOB_DETECTION = EnvBool(False)
SGLANG_SPEC_V2_ZERO_BUBBLE = EnvBool(False)

# VLM
SGLANG_VLM_CACHE_SIZE_MB = EnvInt(100)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def replay(
# Replay
if not is_deepseek_nsa(self.model_runner.model_config.hf_config):
if forward_batch.forward_mode.is_target_verify():
seq_lens_cpu = forward_batch.seq_lens.cpu() + self.num_tokens_per_bs
seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_tokens_per_bs
seq_lens = seq_lens_cpu.tolist() + [0] * (self.bs - self.raw_bs)
else:
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
seq_lens = forward_batch.seq_lens_cpu.tolist() + [0] * (
self.bs - self.raw_bs
)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
Expand Down
40 changes: 28 additions & 12 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
is_npu,
support_triton,
)
from sglang.srt.utils.common import ceil_align
from sglang.srt.utils.common import ceil_align, is_pin_memory_available

if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
Expand Down Expand Up @@ -486,6 +486,7 @@ def init_new(
rids=[req.rid for req in batch.reqs],
)
device = model_runner.device
_pin = is_pin_memory_available(device)

if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = (
Expand All @@ -494,9 +495,9 @@ def init_new(

num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0
if enable_num_token_non_padded(model_runner.server_args):
ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to(
device, non_blocking=True
)
ret.num_token_non_padded = torch.tensor(
num_tokens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.num_token_non_padded_cpu = num_tokens

# For MLP sync
Expand All @@ -516,15 +517,18 @@ def init_new(
ret.original_global_num_tokens_cpu = batch.global_num_tokens
ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
global_num_tokens, dtype=torch.int64
global_num_tokens, dtype=torch.int64, pin_memory=_pin
).to(device, non_blocking=True)

ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
global_num_tokens_for_logprob, dtype=torch.int64
global_num_tokens_for_logprob, dtype=torch.int64, pin_memory=_pin
).to(device, non_blocking=True)

if ret.forward_mode.is_idle():
if _is_npu:
# This synchronize is necessary to prevent the system from hanging on npu.
torch.npu.synchronize()
ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
return ret

Expand All @@ -540,6 +544,7 @@ def init_new(
for i in range(block_offset, block_offset + block_size)
],
dtype=positions_dtype,
pin_memory=_pin,
).to(device, non_blocking=True)
elif (
ret.spec_info is not None
Expand All @@ -555,10 +560,10 @@ def init_new(
assert isinstance(batch.extend_seq_lens, list)
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
batch.extend_seq_lens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
batch.extend_prefix_lens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position(
Expand Down Expand Up @@ -761,6 +766,7 @@ def _compute_mrope_positions(
# batch_size * [3 * seq_len]
batch_size = self.seq_lens_cpu.shape[0]
mrope_positions_list = [[]] * batch_size
_pin = is_pin_memory_available(model_runner.device)
for batch_idx in range(batch_size):
mm_input = batch.multimodal_inputs[batch_idx]
if self.forward_mode.is_decode():
Expand Down Expand Up @@ -812,10 +818,20 @@ def _compute_mrope_positions(
)
mrope_positions_list[batch_idx] = mrope_positions

self.mrope_positions = torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
if _pin:
self.mrope_positions = (
torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
)
.pin_memory()
.to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
)
else:
self.mrope_positions = torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)

def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
if value == 0:
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def __init__(
self.init_new_workspace = False
self.draft_model_idx = draft_model_idx
self.enable_hisparse = server_args.enable_hisparse
self.enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()

self.remote_instance_transfer_engine = None
self.remote_instance_transfer_engine_session_id = ""
Expand Down Expand Up @@ -2920,6 +2921,7 @@ def _forward_raw(
and forward_batch.global_num_tokens_gpu is not None
and require_gathered_buffer(self.server_args)
and not is_nsa_enable_prefill_cp()
and not self.enable_spec_v2_zero_bubble
):
forward_batch.adjust_num_token_non_padded_for_attn_tp(
server_args=self.server_args,
Expand Down
14 changes: 13 additions & 1 deletion python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import torch

from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self, eagle_worker: EAGLEWorker):
)
self.enable_pdmux = False
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
self.enable_spec_v2_zero_bubble = envs.SGLANG_SPEC_V2_ZERO_BUBBLE.get()

# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
Expand Down Expand Up @@ -329,7 +331,13 @@ def run_once():
output_cache_loc_backup = forward_batch.out_cache_loc
hidden_states_backup = forward_batch.spec_info.hidden_states

ret = self.eagle_worker.draft_forward(forward_batch)
if self.enable_spec_v2_zero_bubble:
assert hasattr(
self.eagle_worker, "draft_forward_zero_bubble"
), "`Spec v2 zero bubble` just support when enable `overlap scheduler` and enable `eagle algorithm` now"
ret = self.eagle_worker.draft_forward_zero_bubble(forward_batch)
else:
ret = self.eagle_worker.draft_forward(forward_batch)

forward_batch.out_cache_loc = output_cache_loc_backup
forward_batch.spec_info.hidden_states = hidden_states_backup
Expand All @@ -348,6 +356,10 @@ def run_once():

def _postprocess_output_to_raw_bs(self, out, raw_bs):
# Keep the variables name for readability
if self.enable_spec_v2_zero_bubble:
ret_topk_p_list, ret_topk_index_list = (t[:raw_bs] for t in out)
return ret_topk_p_list, ret_topk_index_list

parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out)
return parent_list, top_scores_index, draft_tokens

Expand Down
Loading
Loading