Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
66e8b7a
introduce EagleDraftExtendInput; split phase from EagleDraftInput
hnyls2002 May 9, 2026
057b8f7
rename extend_input -> draft_extend_input
hnyls2002 May 9, 2026
f5b85fd
add isinstance asserts at draft_extend_input phase boundaries
hnyls2002 May 9, 2026
0634680
move draft_extend_input install out of verify() into forward_draft_ex…
hnyls2002 May 9, 2026
1fc71dd
move spec_info phase install to executor (forward_batch_generation)
hnyls2002 May 9, 2026
88fffbe
V1: forward_draft_extend_after_decode returns next_draft_input; execu…
hnyls2002 May 9, 2026
958365e
drop unused model_worker_batch from verify() return
hnyls2002 May 9, 2026
bfabaa7
drop redundant spec_info.positions = None
hnyls2002 May 9, 2026
74c9ca9
drop stale draft_extend shape note on EagleDraftInput.hidden_states
hnyls2002 May 9, 2026
5f80483
move generate_attn_arg_prefill to EagleDraftExtendInput; tighten sele…
hnyls2002 May 9, 2026
6ad0e49
drop redundant spec_info args; read spec_info from batch
hnyls2002 May 9, 2026
1ff8f8f
move verify->extend handoff fields onto EagleDraftExtendInput
hnyls2002 May 9, 2026
7bfe3c0
forward_batch_info: getattr-guard num_accepted_drafts on draft-phase …
hnyls2002 May 9, 2026
4c680f1
forward_batch_info: getattr-guard topk_p/topk_index on draft-extend s…
hnyls2002 May 9, 2026
490bcc0
v1: install empty EagleDraftInput when extend skipped (retract edge c…
hnyls2002 May 9, 2026
a6c4467
restore num_accepted_drafts/tokens on EagleDraftInput for V2
hnyls2002 May 9, 2026
6b31206
v2: install EagleDraftExtendInput as draft-extend spec_info
hnyls2002 May 9, 2026
54101ee
v2 sample returns EagleVerifyOutput; thread draft_extend_input via ba…
hnyls2002 May 9, 2026
ee814cb
fold can_run_cuda_graph into EagleVerifyOutput; verify returns single…
hnyls2002 May 9, 2026
9f363a4
drop GenerationBatchResult.next_draft_extend_input; v2 verify returns…
hnyls2002 May 9, 2026
5f76cf8
group spec_info fields by common / v2 / v1 sections
hnyls2002 May 9, 2026
afa7356
merge main
hnyls2002 May 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.server_args import get_global_server_args
from sglang.srt.speculative.eagle_info_v2 import (
EagleDraftExtendInputV2Mixin,
EagleDraftInputV2Mixin,
EagleVerifyInputV2Mixin,
)
Expand Down Expand Up @@ -819,7 +820,7 @@ def merge_batch(self, spec_info: "EagleDraftInput"):


@dataclass
class EagleDraftExtendInput(SpecInput):
class EagleDraftExtendInput(SpecInput, EagleDraftExtendInputV2Mixin):
"""Inputs to the draft-extend forward (the per-accepted-token pass after verify).

Produced by `EagleVerifyInput.verify`, installed on `batch.spec_info` for
Expand Down Expand Up @@ -1007,6 +1008,10 @@ class EagleVerifyOutput:
num_correct_drafts_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accept_indices: torch.Tensor
# Whether the target verify forward ran a captured cuda graph. Set by
# the worker after `EagleVerifyInput.verify` returns; default kept so
# idle / direct constructions don't have to pass it.
can_run_cuda_graph: bool = False

@classmethod
def create_idle(
Expand Down
70 changes: 55 additions & 15 deletions python/sglang/srt/speculative/eagle_info_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Tuple

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -47,7 +47,12 @@
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info import (
EagleDraftExtendInput,
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
)

if is_cuda() or is_musa():
from sgl_kernel import (
Expand Down Expand Up @@ -217,8 +222,11 @@ def prepare_for_v2_draft(
can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
return forward_batch, can_cuda_graph


@dataclass
class EagleDraftExtendInputV2Mixin:
def prepare_for_extend_to_fill_draft_kvcache(
self,
self: EagleDraftExtendInput,
batch: ModelWorkerBatch,
predict: torch.Tensor,
num_draft_tokens: int,
Expand Down Expand Up @@ -335,20 +343,37 @@ def sample(
batch: ModelWorkerBatch,
logits_output: LogitsProcessorOutput,
vocab_mask: torch.Tensor = None,
):
) -> Tuple["EagleVerifyOutput", torch.Tensor]:
"""
Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).

Returns `(verify_output, predict)`, where `predict` holds the full
per-position sampled tokens (the V2 caller uses it as `next_token_ids`) and
`verify_output.accept_tokens` is the V1-style flat accepted slice.
"""
from sglang.srt.speculative.eagle_info import (
EagleDraftExtendInput,
EagleVerifyOutput,
)

device = batch.input_ids.device
if batch.forward_mode.is_idle():
predict = torch.empty(0, dtype=torch.int32, device=batch.input_ids.device)
num_correct_drafts = torch.empty(
0, dtype=torch.int32, device=batch.input_ids.device
)
accept_index = torch.empty(
0, dtype=torch.int32, device=batch.input_ids.device
predict = torch.empty(0, dtype=torch.int32, device=device)
num_correct_drafts = torch.empty(0, dtype=torch.int32, device=device)
accept_index = torch.empty(0, dtype=torch.int32, device=device)
verify_output = EagleVerifyOutput(
draft_extend_input=EagleDraftExtendInput(
hidden_states=logits_output.hidden_states,
num_correct_drafts=num_correct_drafts,
num_accept_tokens=num_correct_drafts + 1,
),
logits_output=logits_output,
accept_tokens=torch.empty(0, dtype=torch.int32, device=device),
num_correct_drafts_per_req_cpu=[],
accept_indices=accept_index,
)
return predict, num_correct_drafts, accept_index
return verify_output, predict

bs = len(batch.seq_lens)
sampling_info = batch.sampling_info
Expand Down Expand Up @@ -481,10 +506,25 @@ def sample(
spec_steps=self.spec_steps,
)

# `num_correct_drafts` stays drafts-only inside this function; the returned
# tensor includes the trailing/bonus token via out-of-place +1 so the
# name no longer flips semantics mid-function (naming doc C2).
return predict, num_correct_drafts + 1, accept_index
# `num_correct_drafts` is drafts-only here; bonus is added via out-of-place
# +1 when packaged into `EagleDraftExtendInput.num_accept_tokens`, so the
# local name does not flip semantics mid-function (naming doc C2).
verify_output = EagleVerifyOutput(
draft_extend_input=EagleDraftExtendInput(
# V2 keeps `hidden_states` as the full target output (shape
# `[bs * draft_token_num, hidden]`); V1 instead stores the
# accept-sliced view. The downstream V2 cuda-graph runner
# expects the full layout.
hidden_states=logits_output.hidden_states,
num_correct_drafts=num_correct_drafts,
num_accept_tokens=num_correct_drafts + 1,
),
logits_output=logits_output,
accept_tokens=predict[accept_index],
num_correct_drafts_per_req_cpu=[],
accept_indices=accept_index,
)
return verify_output, predict


@triton.jit
Expand Down
15 changes: 9 additions & 6 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

# Install verify_input as `batch.spec_info` for the verify forward.
batch.spec_info = verify_input
logits_output, verify_output, can_run_cuda_graph = self.verify(batch)
verify_output = self.verify(batch)

if get_global_tracing_enabled():
for idx, req in enumerate(batch.reqs):
Expand All @@ -512,8 +513,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
self.server_args.enable_dp_attention
or draft_extend_input.input_ids.shape[0] > 0
):
# decode is not finished; stash for extend, then restash
# the next-iter EagleDraftInput it returns.
# decode is not finished; install draft_extend_input for
# the extend forward, then install the next-iter
# EagleDraftInput it returns.
batch.spec_info = draft_extend_input
next_draft_input = self.forward_draft_extend_after_decode(batch)
batch.spec_info = next_draft_input
Expand All @@ -534,11 +536,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
)

return GenerationBatchResult(
logits_output=logits_output,
logits_output=verify_output.logits_output,
next_token_ids=verify_output.accept_tokens,
num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu),
num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu,
can_run_cuda_graph=can_run_cuda_graph,
can_run_cuda_graph=verify_output.can_run_cuda_graph,
)

def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
Expand Down Expand Up @@ -1015,7 +1017,8 @@ def verify(self, batch: ScheduleBatch):
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
)

return logits_output, res, can_run_cuda_graph
res.can_run_cuda_graph = can_run_cuda_graph
return res

def _mamba_verify_update(
self,
Expand Down
85 changes: 48 additions & 37 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info import (
EagleDraftExtendInput,
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
)
from sglang.srt.speculative.eagle_info_v2 import (
assign_extend_cache_locs,
fill_accepted_out_cache_loc,
Expand Down Expand Up @@ -534,17 +539,13 @@ def _draft_extend_for_prefill(
)
pt += extend_len

# Construct spec_info
next_draft_input = EagleDraftInput(
# Install draft-extend spec_info for the extend forward.
extend_input = EagleDraftExtendInput(
hidden_states=target_hidden_states,
bonus_tokens=next_token_ids,
new_seq_lens=batch.seq_lens,
# draft mode is the same as decode mode: only 1 token per req
num_tokens_per_req=1,
num_tokens_for_logprob_per_req=1,
)

batch.spec_info = next_draft_input
batch.spec_info = extend_input

# Run forward
forward_batch = ForwardBatch.init_new(batch, self.draft_runner)
Expand All @@ -554,22 +555,33 @@ def _draft_extend_for_prefill(
logits_output = self.draft_runner.forward(forward_batch).logits_output
maybe_detect_nan(logits_output.next_token_logits, "draft_extend_for_prefill")

# Update spec_info for the next draft step
# Assemble fresh next-iter draft spec_info from the extend output.
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
probs, self.topk, dim=-1
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
next_draft_input = EagleDraftInput(
topk_p=topk_p,
topk_index=topk_index,
hidden_states=logits_output.hidden_states,
bonus_tokens=next_token_ids,
new_seq_lens=batch.seq_lens,
num_tokens_per_req=1,
num_tokens_for_logprob_per_req=1,
)
next_draft_input.hidden_states = logits_output.hidden_states
return next_draft_input

def _draft_extend_for_decode(
self, batch: ModelWorkerBatch, batch_result: GenerationBatchResult
self,
batch: ModelWorkerBatch,
batch_result: GenerationBatchResult,
verify_output: EagleVerifyOutput,
):
# Batch 2: Draft extend
draft_input = EagleDraftInput(
hidden_states=batch_result.logits_output.hidden_states,
num_tokens_per_req=self.speculative_num_steps + 1,
num_tokens_for_logprob_per_req=self.speculative_num_steps + 1,
# Batch 2: Draft extend. verify already built draft_extend_input with
# hidden_states / num_correct_drafts / num_accept_tokens; we only need
# to set the per-req padding info for this forward.
draft_extend_input = verify_output.draft_extend_input
draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1
draft_extend_input.num_tokens_for_logprob_per_req = (
self.speculative_num_steps + 1
)
select_index = (
torch.arange(len(batch.seq_lens), device=self.device)
Expand All @@ -580,7 +592,7 @@ def _draft_extend_for_decode(

# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
forward_batch = draft_extend_input.prepare_for_extend_to_fill_draft_kvcache(
batch,
batch_result.next_token_ids,
self.speculative_num_draft_tokens,
Expand All @@ -593,12 +605,6 @@ def _draft_extend_for_decode(
self.plan_stream
)

if forward_batch.spec_info.num_correct_drafts is None:
# `batch_result.accept_lens` already includes the bonus token, so use it
# directly for `num_accept_tokens` and subtract 1 for `num_correct_drafts`.
forward_batch.spec_info.num_correct_drafts = batch_result.accept_lens - 1
forward_batch.spec_info.num_accept_tokens = batch_result.accept_lens

# Run draft extend batch in the main compute stream
can_cuda_graph = (
self.cuda_graph_runner_for_draft_extend
Expand Down Expand Up @@ -799,12 +805,12 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
self._draft_done_event = torch.get_device_module(self.device).Event()
self._draft_done_event.record()
model_worker_batch.spec_info = verify_input
batch_output = self.verify(model_worker_batch)
batch_output, verify_output = self.verify(model_worker_batch)
with self.draft_worker.draft_tp_context(
self.draft_worker.draft_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.draft_worker._draft_extend_for_decode(
model_worker_batch, batch_output
model_worker_batch, batch_output, verify_output
)

return batch_output
Expand Down Expand Up @@ -1049,11 +1055,8 @@ def verify(self, batch: ModelWorkerBatch):

# Sample
maybe_detect_nan(logits_output.next_token_logits, "verify: target model logits")
(
predict,
accept_lens,
accept_index,
) = verify_input.sample(batch, logits_output, vocab_mask)
verify_output, predict = verify_input.sample(batch, logits_output, vocab_mask)
accept_lens = verify_output.draft_extend_input.num_accept_tokens
new_seq_lens = batch.seq_lens + accept_lens

# Update mamba state for hybrid GDN models after verification
Expand All @@ -1062,17 +1065,20 @@ def verify(self, batch: ModelWorkerBatch):
or self.target_worker.model_runner.mamba2_config is not None
):
self._mamba_verify_update(
batch, verify_input, accept_lens, accept_index, bs
batch,
verify_input,
accept_lens,
verify_output.accept_indices,
bs,
)

verify_done = torch.get_device_module(self.device).Event()
verify_done.record()

if not batch.forward_mode.is_idle():
accept_tokens = predict[accept_index]
bonus_tokens = torch.empty_like(accept_lens, dtype=torch.int32)
fill_bonus_tokens[(bs,)](
accept_tokens,
verify_output.accept_tokens,
accept_lens,
bonus_tokens,
self.speculative_num_draft_tokens,
Expand All @@ -1082,7 +1088,11 @@ def verify(self, batch: ModelWorkerBatch):

if batch.return_logprob and not batch.forward_mode.is_idle():
compute_spec_v2_logprobs(
batch, logits_output, predict, accept_index, self.speculative_num_steps
batch,
logits_output,
predict,
verify_output.accept_indices,
self.speculative_num_steps,
)

# Construct the next draft input
Expand All @@ -1092,7 +1102,7 @@ def verify(self, batch: ModelWorkerBatch):
verify_done=verify_done,
)

return GenerationBatchResult(
batch_result = GenerationBatchResult(
logits_output=logits_output,
next_token_ids=predict,
can_run_cuda_graph=can_run_cuda_graph,
Expand All @@ -1102,6 +1112,7 @@ def verify(self, batch: ModelWorkerBatch):
routed_experts_output=forward_batch_output.routed_experts_output,
indexer_topk_output=forward_batch_output.indexer_topk_output,
)
return batch_result, verify_output

def _mamba_verify_update(
self,
Expand Down
Loading
Loading