Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
66e8b7a
introduce EagleDraftExtendInput; split phase from EagleDraftInput
hnyls2002 May 9, 2026
057b8f7
rename extend_input -> draft_extend_input
hnyls2002 May 9, 2026
f5b85fd
add isinstance asserts at draft_extend_input phase boundaries
hnyls2002 May 9, 2026
0634680
move draft_extend_input install out of verify() into forward_draft_ex…
hnyls2002 May 9, 2026
1fc71dd
move spec_info phase install to executor (forward_batch_generation)
hnyls2002 May 9, 2026
88fffbe
V1: forward_draft_extend_after_decode returns next_draft_input; execu…
hnyls2002 May 9, 2026
958365e
drop unused model_worker_batch from verify() return
hnyls2002 May 9, 2026
bfabaa7
drop redundant spec_info.positions = None
hnyls2002 May 9, 2026
74c9ca9
drop stale draft_extend shape note on EagleDraftInput.hidden_states
hnyls2002 May 9, 2026
5f80483
move generate_attn_arg_prefill to EagleDraftExtendInput; tighten sele…
hnyls2002 May 9, 2026
6ad0e49
drop redundant spec_info args; read spec_info from batch
hnyls2002 May 9, 2026
1ff8f8f
move verify->extend handoff fields onto EagleDraftExtendInput
hnyls2002 May 9, 2026
7bfe3c0
forward_batch_info: getattr-guard num_accepted_drafts on draft-phase …
hnyls2002 May 9, 2026
4c680f1
forward_batch_info: getattr-guard topk_p/topk_index on draft-extend s…
hnyls2002 May 9, 2026
490bcc0
v1: install empty EagleDraftInput when extend skipped (retract edge c…
hnyls2002 May 9, 2026
a6c4467
restore num_accepted_drafts/tokens on EagleDraftInput for V2
hnyls2002 May 9, 2026
3d123cd
stash spec_info comments; drop unused batch param
hnyls2002 May 9, 2026
c26ca56
cleanup forward_batch_info comments; drop dead server_args param
hnyls2002 May 9, 2026
be72838
drop redundant spec_info.positions = None
hnyls2002 May 9, 2026
dd35129
drop unused model_worker_batch from verify() return
hnyls2002 May 9, 2026
32b40e2
drop dead server_args param from enable_num_token_non_padded
hnyls2002 May 9, 2026
56d55d9
drop dead batch param from check_forward_draft_extend_after_decode
hnyls2002 May 9, 2026
2caa926
merge lsyin/spec-drop-dead-code
hnyls2002 May 9, 2026
7c8825f
merge origin/main
hnyls2002 May 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,17 +989,18 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs):
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

if self.spec_info is not None and self.spec_info.is_draft_input():
# FIXME(lsyin): remove this isinstance logic
spec_info = self.spec_info
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states
if spec_info.topk_p is not None:
# spec_info is EagleDraftInput | EagleDraftExtendInput; each carries
# a disjoint subset of the fields below, so getattr-guard each one.
if getattr(spec_info, "topk_p", None) is not None:
spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
if spec_info.topk_index is not None:
if getattr(spec_info, "topk_index", None) is not None:
spec_info.topk_index = self._pad_tensor_to_size(
spec_info.topk_index, bs
)
if spec_info.num_accepted_drafts is not None:
if getattr(spec_info, "num_accepted_drafts", None) is not None:
spec_info.num_accepted_drafts = self._pad_tensor_to_size(
spec_info.num_accepted_drafts, bs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
ForwardMode,
)
from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.eagle_info import EagleDraftExtendInput
from sglang.srt.speculative.spec_utils import fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
Expand Down Expand Up @@ -360,7 +360,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0
else:
global_dp_buffer_len = None

spec_info = EagleDraftInput(
spec_info = EagleDraftExtendInput(
hidden_states=hidden_states,
num_accepted_drafts=num_accepted_drafts,
num_accepted_tokens=num_accepted_tokens,
Expand Down
306 changes: 161 additions & 145 deletions python/sglang/srt/speculative/eagle_info.py

Large diffs are not rendered by default.

90 changes: 56 additions & 34 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import (
EagleDraftExtendInput,
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
Expand Down Expand Up @@ -480,14 +481,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
spec_info = self.draft(batch)
verify_input = self.draft(batch)

set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

logits_output, verify_output, can_run_cuda_graph = self.verify(
batch, spec_info
)
batch.spec_info = verify_input
logits_output, verify_output, can_run_cuda_graph = self.verify(batch)

if get_global_tracing_enabled():
for idx, req in enumerate(batch.reqs):
Expand All @@ -503,12 +503,24 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
draft_extend_input = verify_output.draft_extend_input
if (
self.server_args.enable_dp_attention
or verify_output.unfinished_accept_tokens.shape[0] > 0
or draft_extend_input.input_ids.shape[0] > 0
):
# decode is not finished
self.forward_draft_extend_after_decode(batch, verify_output)
# decode is not finished; stash for extend, then restash
# the next-iter EagleDraftInput it returns.
batch.spec_info = draft_extend_input
next_draft_input = self.forward_draft_extend_after_decode(batch)
batch.spec_info = next_draft_input
else:
# All reqs finished and dp_attention isn't forcing extend.
# Stash an empty EagleDraftInput so next iter's merge_batch
# short-circuits on None hidden_states (EagleVerifyInput
# has no merge_batch).
batch.spec_info = EagleDraftInput(
capture_hidden_mode=CaptureHiddenMode.LAST,
)

set_time_batch(
batch.reqs, "set_spec_draft_extend_end_time", trace_only=True
Expand All @@ -528,7 +540,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
)

def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0
local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
if not self.server_args.enable_dp_attention:
return local_need_forward

Expand Down Expand Up @@ -890,7 +902,8 @@ def clear_cache_pool(self):
# allocator and kv cache pool are shared with target worker
pass

def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
def verify(self, batch: ScheduleBatch):
spec_info: EagleVerifyInput = batch.spec_info
seq_lens_pre_verify = batch.seq_lens.clone()
spec_info.prepare_for_verify(batch, self.page_size)
spec_info.num_tokens_per_req = self.speculative_num_steps + 1
Expand All @@ -900,7 +913,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
if not batch.forward_mode.is_idle()
else ForwardMode.IDLE
)
batch.spec_info = spec_info

model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=spec_info.seq_lens_cpu
Expand Down Expand Up @@ -977,7 +989,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
batch.forward_mode = (
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
)
batch.spec_info = res.next_draft_input

return logits_output, res, can_run_cuda_graph

Expand Down Expand Up @@ -1105,20 +1116,20 @@ def forward_draft_extend(
self.capture_for_decode(logits_output, forward_batch.spec_info)

def forward_draft_extend_after_decode(
self, batch: ScheduleBatch, verify_output: EagleVerifyOutput
):
assert isinstance(batch.spec_info, EagleDraftInput)
self, batch: ScheduleBatch
) -> EagleDraftInput:
draft_extend_input: EagleDraftExtendInput = batch.spec_info

# Backup fields that will be modified in-place
seq_lens_backup = batch.seq_lens.clone()
seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
req_pool_indices_backup = batch.req_pool_indices
num_accepted_drafts_backup = batch.spec_info.num_accepted_drafts.clone()
num_accepted_tokens_backup = batch.spec_info.num_accepted_tokens.clone()
return_logprob_backup = batch.return_logprob

input_is_idle = batch.forward_mode.is_idle()

if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0:
if not input_is_idle and draft_extend_input.input_ids.numel() == 0:
# All reqs finished this verify; swap to an idle ExtendInput.
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
Expand All @@ -1127,19 +1138,19 @@ def forward_draft_extend_after_decode(
and self.eagle_use_aux_hidden_state
else self.model_config.spec_hidden_size
)
batch.spec_info = EagleDraftInput.create_idle_input(
draft_extend_input = EagleDraftExtendInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.spec_info = draft_extend_input

batch.spec_info.num_tokens_per_req = self.speculative_num_steps + 1
batch.spec_info.num_tokens_for_logprob_per_req = 1
batch.spec_info.prepare_extend_after_decode(
# Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens})
draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1
draft_extend_input.num_tokens_for_logprob_per_req = 1
draft_extend_input.prepare_extend_after_decode(
batch,
verify_output=verify_output,
speculative_num_steps=self.speculative_num_steps,
)
batch.forward_mode = (
Expand All @@ -1159,7 +1170,7 @@ def forward_draft_extend_after_decode(
else:
forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

# Run
# Phase 2: run draft-extend forward
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
Expand All @@ -1168,11 +1179,10 @@ def forward_draft_extend_after_decode(
logits_output = self.cuda_graph_runner_for_draft_extend.replay(
forward_batch
)
forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
logits_output.topk_p,
logits_output.topk_index,
)
forward_batch.spec_info.hidden_states = logits_output.hidden_states
# cuda-graph replay populates logits_output.{topk_p, topk_index, hidden_states}.
topk_p = logits_output.topk_p
topk_index = logits_output.topk_index
hidden_states = logits_output.hidden_states
else:
forward_batch.can_run_dp_cuda_graph = False
if not forward_batch.forward_mode.is_idle():
Expand All @@ -1185,24 +1195,36 @@ def forward_draft_extend_after_decode(
logits_output = self.draft_model_runner.forward(
forward_batch, skip_attn_backend_init=True
).logits_output
self.capture_for_decode(logits_output, forward_batch.spec_info)
# Non-cuda-graph path: compute topk_p / topk_index inline.
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
hidden_states = logits_output.hidden_states

maybe_detect_nan(
logits_output.next_token_logits,
f"draft_extend_after_decode (cuda_graph={can_cuda_graph})",
)

# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
# Phase 3: assemble next-iter EagleDraftInput from extend output
next_draft_input = EagleDraftInput(
bonus_tokens=draft_extend_input.bonus_tokens,
hidden_states=hidden_states,
topk_p=topk_p,
topk_index=topk_index,
capture_hidden_mode=CaptureHiddenMode.FULL,
)

# Restore batch fields. `seq_lens` etc. were modified by
# `prepare_extend_after_decode`. Caller installs `next_draft_input` as
# `batch.spec_info`.
batch.forward_mode = (
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
)
batch.seq_lens = seq_lens_backup
batch.seq_lens_cpu = seq_lens_cpu_backup
batch.req_pool_indices = req_pool_indices_backup
batch.spec_info.num_accepted_drafts = num_accepted_drafts_backup
batch.spec_info.num_accepted_tokens = num_accepted_tokens_backup
batch.return_logprob = return_logprob_backup
return next_draft_input

def capture_for_decode(
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
Expand Down
29 changes: 20 additions & 9 deletions python/sglang/srt/speculative/frozen_kv_mtp_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.speculative.eagle_info import (
EagleDraftExtendInput,
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
Expand Down Expand Up @@ -53,6 +54,14 @@ def __post_init__(self):
SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT)


@dataclass
class FrozenKVMTPDraftExtendInput(EagleDraftExtendInput):
    """Draft-extend input for Frozen-KV MTP. Tag-only subclass."""

    # Mirrors FrozenKVMTPDraftInput above: call SpecInput.__init__ directly so
    # this instance is tagged FROZEN_KV_MTP_DRAFT_EXTEND instead of whatever
    # tag the EagleDraftExtendInput base would install.
    def __post_init__(self):
        SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND)


@dataclass
class FrozenKVMTPVerifyInput(EagleVerifyInput):
"""Verify input for Frozen-KV MTP."""
Expand All @@ -62,21 +71,23 @@ def __post_init__(self):

def verify(self, *args, **kwargs) -> EagleVerifyOutput:
output = super().verify(*args, **kwargs)
output.next_draft_input = _to_frozen_kv_mtp_draft_input(output.next_draft_input)
output.draft_extend_input = _to_frozen_kv_mtp_draft_extend_input(
output.draft_extend_input
)
return output


FrozenKVMTPVerifyOutput = EagleVerifyOutput


def _to_frozen_kv_mtp_draft_input(
draft_input: EagleDraftInput,
) -> FrozenKVMTPDraftInput:
if isinstance(draft_input, FrozenKVMTPDraftInput):
return draft_input
return FrozenKVMTPDraftInput(
def _to_frozen_kv_mtp_draft_extend_input(
    draft_extend_input: EagleDraftExtendInput,
) -> FrozenKVMTPDraftExtendInput:
    """Re-tag a plain EagleDraftExtendInput as the Frozen-KV MTP variant.

    Returns the argument unchanged if it is already the Frozen-KV subclass;
    otherwise builds a new FrozenKVMTPDraftExtendInput by copying every
    dataclass field declared on EagleDraftExtendInput.
    """
    if isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput):
        return draft_extend_input
    # Copy only the fields declared on the base class; the subclass adds no
    # fields of its own (tag-only), so this transfers the full payload.
    return FrozenKVMTPDraftExtendInput(
        **{
            field.name: getattr(draft_extend_input, field.name)
            for field in fields(EagleDraftExtendInput)
        }
    )
6 changes: 2 additions & 4 deletions python/sglang/srt/speculative/frozen_kv_mtp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.speculative.frozen_kv_mtp_info import (
FrozenKVMTPContext,
FrozenKVMTPDraftExtendInput,
FrozenKVMTPDraftInput,
)
from sglang.srt.speculative.spec_utils import fast_topk
Expand Down Expand Up @@ -134,11 +135,8 @@ def select_last_extend_hidden(


def select_last_verified_seed(
draft_input: FrozenKVMTPDraftInput,
draft_input: FrozenKVMTPDraftExtendInput,
) -> Tuple[torch.Tensor, torch.Tensor]:
if draft_input.num_accepted_tokens is None:
return draft_input.bonus_tokens, draft_input.hidden_states

counts = draft_input.num_accepted_tokens.to(torch.long)
last_indices = torch.cumsum(counts, dim=0) - 1
return (
Expand Down
Loading
Loading