From 66e8b7abd5b082ec6c9375c3db9c23728619dc1a Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 8 May 2026 23:39:34 -0700 Subject: [PATCH 01/22] introduce EagleDraftExtendInput; split phase from EagleDraftInput --- .../eagle_draft_extend_cuda_graph_runner.py | 4 +- python/sglang/srt/speculative/eagle_info.py | 174 +++++++++++------- python/sglang/srt/speculative/eagle_worker.py | 55 ++++-- .../srt/speculative/frozen_kv_mtp_info.py | 27 ++- .../srt/speculative/frozen_kv_mtp_utils.py | 5 +- .../srt/speculative/frozen_kv_mtp_worker.py | 20 +- ...er_eagle_draft_extend_cuda_graph_runner.py | 4 +- .../speculative/multi_layer_eagle_worker.py | 37 ++-- python/sglang/srt/speculative/spec_info.py | 4 + 9 files changed, 209 insertions(+), 121 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index a5ae5b5b3e89..a2d620d5c8cb 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -26,7 +26,7 @@ ForwardMode, ) from sglang.srt.model_executor.input_buffers import ForwardInputBuffers -from sglang.srt.speculative.eagle_info import EagleDraftInput +from sglang.srt.speculative.eagle_info import EagleDraftExtendInput from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( require_attn_tp_gather, @@ -360,7 +360,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0 else: global_dp_buffer_len = None - spec_info = EagleDraftInput( + spec_info = EagleDraftExtendInput( hidden_states=hidden_states, num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_tokens, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 205c0610ae38..8a785b385ed9 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ 
b/python/sglang/srt/speculative/eagle_info.py @@ -240,15 +240,14 @@ def verify( accepted token logits. """ if batch.forward_mode.is_idle(): - next_draft_input = EagleDraftInput.create_idle_input( + extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) return EagleVerifyOutput.create_idle( - next_draft_input=next_draft_input, + extend_input=extend_input, logits_output=logits_output, device=batch.device, spec_steps=self.spec_steps, @@ -545,7 +544,7 @@ def verify( batch.seq_lens.add_(num_accepted_drafts + 1) batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) - next_draft_input = EagleDraftInput( + extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[accept_index], num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_drafts + 1, @@ -553,7 +552,7 @@ def verify( ) return EagleVerifyOutput( - next_draft_input=next_draft_input, + extend_input=extend_input, logits_output=logits_output, accept_tokens=accept_tokens, unfinished_accept_tokens=accept_tokens, @@ -620,7 +619,7 @@ def verify( req_pool_indices_for_draft_extend = batch.req_pool_indices[ unfinished_index_device ] - next_draft_input = EagleDraftInput( + extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[ unfinished_accept_index ], @@ -643,16 +642,15 @@ def verify( dtype=batch.req_pool_indices.dtype, device=batch.req_pool_indices.device, ) - next_draft_input = EagleDraftInput.create_idle_input( + extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) return EagleVerifyOutput( - next_draft_input=next_draft_input, + extend_input=extend_input, logits_output=logits_output, accept_tokens=accept_tokens, 
unfinished_accept_tokens=unfinished_accept_tokens, @@ -677,17 +675,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): hidden_states: torch.Tensor = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL - # Inputs for extend - # shape: (b,) - # `num_accepted_drafts` and `num_accepted_tokens` are kept in sync: - # `num_accepted_tokens = num_accepted_drafts + 1` (per-req, one bonus per req). - # Storing both avoids repeated `+ 1` at every consumer (attn backends, kernels). + # Per-req bonus token (the "+1" target prediction at the end of each accept + # chain). Set by `EagleDraftExtendInput.prepare_extend_after_decode`'s kernel + # via the worker's post-extend assembly. bonus_tokens: torch.Tensor = None - num_accepted_drafts: torch.Tensor = None - num_accepted_tokens: torch.Tensor = None - # Read by attention backends during draft-extend forward; kept on the - # dataclass because the backends access it via `forward_batch.spec_info`. - num_accepted_tokens_cpu: List[int] = None # Inputs for the attention backends # shape: (b + 1,) @@ -741,47 +732,6 @@ def create_idle_input( topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), capture_hidden_mode=capture_hidden_mode, new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32), - num_accepted_drafts=torch.empty((0,), device=device, dtype=torch.int32), - num_accepted_tokens=torch.empty((0,), device=device, dtype=torch.int32), - num_accepted_tokens_cpu=[], - ) - - def prepare_extend_after_decode( - self, - batch: ScheduleBatch, - verify_output: "EagleVerifyOutput", - speculative_num_steps: int, - ): - - if batch.forward_mode.is_idle(): - return - - # All transient verify->extend handoff state is read off `verify_output`, - # not from `self`. The kernel below populates `self.bonus_tokens` - # ([bs] per-req) for the next decode round; that is the only state on - # `self` that survives past this method. 
- batch.input_ids = verify_output.unfinished_accept_tokens - batch.extend_lens = batch.spec_info.num_accepted_tokens_cpu - batch.extend_num_tokens = sum(batch.extend_lens) - batch.seq_lens = verify_output.seq_lens_for_draft_extend - batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu - batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend - batch.return_logprob = False - batch.return_hidden_states = False - - self.capture_hidden_mode = CaptureHiddenMode.LAST - self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) - self.bonus_tokens = torch.empty_like( - self.num_accepted_tokens, dtype=torch.int32 - ) - - create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( - batch.input_ids, - batch.seq_lens, - self.num_accepted_tokens, - self.positions, - self.bonus_tokens, - next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), ) def generate_attn_arg_prefill( @@ -871,10 +821,106 @@ def merge_batch(self, spec_info: "EagleDraftInput"): self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) +@dataclass +class EagleDraftExtendInput(SpecInput): + """Inputs to the draft-extend forward (the per-accepted-token pass after verify). + + Lives on `batch.spec_info` only during the draft-extend forward pass. + After draft-extend completes, the worker assembles a fresh `EagleDraftInput` + for next iter and replaces `batch.spec_info`. + + All fields here have single shapes (no phase-shift across phases). + """ + + # shape: (total_accepted, hidden_size). Sliced from verify-time hidden_states + # by accept_index; consumed by the draft-extend forward. + hidden_states: torch.Tensor = None + + # Per-req accept counts. `num_accepted_tokens = num_accepted_drafts + 1` + # (one bonus per req). Both retained for cuda-graph buffer indexing and + # the `create_extend_after_decode_spec_info` kernel. 
+ num_accepted_drafts: torch.Tensor = None + num_accepted_tokens: torch.Tensor = None + # CPU view, read by attention backends during the extend forward. + num_accepted_tokens_cpu: List[int] = None + + # Set by `prepare_extend_after_decode`: + # - positions: kernel-written, shape `[total_accepted]`. + # - bonus_tokens: kernel-written, shape `[bs]`. The worker reads this + # post-extend to populate next iter's `EagleDraftInput.bonus_tokens`. + positions: Optional[torch.Tensor] = None + bonus_tokens: Optional[torch.Tensor] = None + + # Forward-pass config (set during prepare_extend / construction). + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST + num_tokens_per_req: int = -1 + num_tokens_for_logprob_per_req: int = 1 + + def __post_init__(self): + super().__init__(SpecInputType.EAGLE_DRAFT_EXTEND) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.num_tokens_per_req, self.num_tokens_for_logprob_per_req + + @classmethod + def create_idle_input( + cls, + device: torch.device, + hidden_size: int, + dtype: torch.dtype, + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST, + ) -> "EagleDraftExtendInput": + return cls( + hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), + num_accepted_drafts=torch.empty((0,), device=device, dtype=torch.int32), + num_accepted_tokens=torch.empty((0,), device=device, dtype=torch.int32), + num_accepted_tokens_cpu=[], + capture_hidden_mode=capture_hidden_mode, + ) + + def prepare_extend_after_decode( + self, + batch: ScheduleBatch, + verify_output: "EagleVerifyOutput", + speculative_num_steps: int, + ): + if batch.forward_mode.is_idle(): + return + + # All transient verify->extend handoff state is read off `verify_output`. + # The kernel below populates `self.positions` and `self.bonus_tokens`; + # the worker reads `self.bonus_tokens` to construct next iter's + # `EagleDraftInput`. 
+ batch.input_ids = verify_output.unfinished_accept_tokens + batch.extend_lens = self.num_accepted_tokens_cpu + batch.extend_num_tokens = sum(batch.extend_lens) + batch.seq_lens = verify_output.seq_lens_for_draft_extend + batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu + batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend + batch.return_logprob = False + batch.return_hidden_states = False + + self.capture_hidden_mode = CaptureHiddenMode.LAST + self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) + self.bonus_tokens = torch.empty_like( + self.num_accepted_tokens, dtype=torch.int32 + ) + + create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( + batch.input_ids, + batch.seq_lens, + self.num_accepted_tokens, + self.positions, + self.bonus_tokens, + next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), + ) + + @dataclass class EagleVerifyOutput: - # Next iter's persistent draft state, ready to be installed as `batch.spec_info`. - next_draft_input: EagleDraftInput + # Next iter's draft-extend input, ready to be installed as `batch.spec_info` + # for the draft-extend forward. + extend_input: EagleDraftExtendInput # Logit outputs from target worker. logits_output: LogitsProcessorOutput # All accepted tokens flat across all reqs incl. 
those that finished this @@ -903,13 +949,13 @@ class EagleVerifyOutput: def create_idle( cls, *, - next_draft_input: EagleDraftInput, + extend_input: EagleDraftExtendInput, logits_output: LogitsProcessorOutput, device: torch.device, spec_steps: int, ) -> "EagleVerifyOutput": return cls( - next_draft_input=next_draft_input, + extend_input=extend_input, logits_output=logits_output, accept_tokens=torch.empty(0, dtype=torch.long, device=device), unfinished_accept_tokens=torch.empty(0, dtype=torch.long, device=device), diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 3c480e9a1bfc..cd30f049a483 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -46,6 +46,7 @@ EAGLEDraftExtendCudaGraphRunner, ) from sglang.srt.speculative.eagle_info import ( + EagleDraftExtendInput, EagleDraftInput, EagleVerifyInput, EagleVerifyOutput, @@ -979,7 +980,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.next_draft_input + batch.spec_info = res.extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -1109,18 +1110,19 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - assert isinstance(batch.spec_info, EagleDraftInput) + assert isinstance(batch.spec_info, EagleDraftExtendInput) + extend_input: EagleDraftExtendInput = batch.spec_info + # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() seq_lens_cpu_backup = batch.seq_lens_cpu.clone() req_pool_indices_backup = batch.req_pool_indices - num_accepted_drafts_backup = batch.spec_info.num_accepted_drafts.clone() - num_accepted_tokens_backup = batch.spec_info.num_accepted_tokens.clone() return_logprob_backup = batch.return_logprob 
input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: + # All reqs finished this verify; swap to an idle ExtendInput. batch = batch.copy() batch.prepare_for_idle() hidden_size = ( @@ -1129,17 +1131,18 @@ def forward_draft_extend_after_decode( and self.eagle_use_aux_hidden_state else self.model_config.spec_hidden_size ) - batch.spec_info = EagleDraftInput.create_idle_input( + extend_input = EagleDraftExtendInput.create_idle_input( device=self.device, hidden_size=hidden_size, dtype=self.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) + batch.spec_info = extend_input - batch.spec_info.num_tokens_per_req = self.speculative_num_steps + 1 - batch.spec_info.num_tokens_for_logprob_per_req = 1 - batch.spec_info.prepare_extend_after_decode( + # Phase 1: prepare extend (kernel writes extend_input.{positions, bonus_tokens}) + extend_input.num_tokens_per_req = self.speculative_num_steps + 1 + extend_input.num_tokens_for_logprob_per_req = 1 + extend_input.prepare_extend_after_decode( batch, verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, @@ -1161,7 +1164,7 @@ def forward_draft_extend_after_decode( else: forward_batch.seq_lens_sum = batch.seq_lens.sum().item() - # Run + # Phase 2: run draft-extend forward can_cuda_graph = ( self.cuda_graph_runner_for_draft_extend and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch) @@ -1170,11 +1173,10 @@ def forward_draft_extend_after_decode( logits_output = self.cuda_graph_runner_for_draft_extend.replay( forward_batch ) - forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = ( - logits_output.topk_p, - logits_output.topk_index, - ) - forward_batch.spec_info.hidden_states = logits_output.hidden_states + # cuda-graph replay populates logits_output.{topk_p, topk_index, hidden_states}. 
+ topk_p = logits_output.topk_p + topk_index = logits_output.topk_index + hidden_states = logits_output.hidden_states else: forward_batch.can_run_dp_cuda_graph = False if not forward_batch.forward_mode.is_idle(): @@ -1187,23 +1189,36 @@ def forward_draft_extend_after_decode( logits_output = self.draft_model_runner.forward( forward_batch, skip_attn_backend_init=True ).logits_output - self.capture_for_decode(logits_output, forward_batch.spec_info) + # Non-cuda-graph path: compute topk_p / topk_index inline (used to be + # `capture_for_decode` which mutated spec_info; we instead carry the + # values to next-iter `EagleDraftInput` assembly below). + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + hidden_states = logits_output.hidden_states maybe_detect_nan( logits_output.next_token_logits, f"draft_extend_after_decode (cuda_graph={can_cuda_graph})", ) - # Restore backup. - # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + # Phase 3: assemble next-iter EagleDraftInput from extend output + next_draft_input = EagleDraftInput( + bonus_tokens=extend_input.bonus_tokens, + hidden_states=hidden_states, + topk_p=topk_p, + topk_index=topk_index, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) + + # Restore batch fields and install the new draft input. + # `seq_lens` etc. were modified by `prepare_extend_after_decode`. 
batch.forward_mode = ( ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE ) batch.seq_lens = seq_lens_backup batch.seq_lens_cpu = seq_lens_cpu_backup batch.req_pool_indices = req_pool_indices_backup - batch.spec_info.num_accepted_drafts = num_accepted_drafts_backup - batch.spec_info.num_accepted_tokens = num_accepted_tokens_backup + batch.spec_info = next_draft_input batch.return_logprob = return_logprob_backup def capture_for_decode( diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_info.py b/python/sglang/srt/speculative/frozen_kv_mtp_info.py index d092446168bd..78fd0bbb5e9c 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_info.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_info.py @@ -18,6 +18,7 @@ from sglang.srt.mem_cache.memory_pool import KVCache from sglang.srt.speculative.eagle_info import ( + EagleDraftExtendInput, EagleDraftInput, EagleVerifyInput, EagleVerifyOutput, @@ -53,6 +54,14 @@ def __post_init__(self): SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT) +@dataclass +class FrozenKVMTPDraftExtendInput(EagleDraftExtendInput): + """Draft-extend input for Frozen-KV MTP. 
Tag-only subclass.""" + + def __post_init__(self): + SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND) + + @dataclass class FrozenKVMTPVerifyInput(EagleVerifyInput): """Verify input for Frozen-KV MTP.""" @@ -62,21 +71,21 @@ def __post_init__(self): def verify(self, *args, **kwargs) -> EagleVerifyOutput: output = super().verify(*args, **kwargs) - output.next_draft_input = _to_frozen_kv_mtp_draft_input(output.next_draft_input) + output.extend_input = _to_frozen_kv_mtp_draft_extend_input(output.extend_input) return output FrozenKVMTPVerifyOutput = EagleVerifyOutput -def _to_frozen_kv_mtp_draft_input( - draft_input: EagleDraftInput, -) -> FrozenKVMTPDraftInput: - if isinstance(draft_input, FrozenKVMTPDraftInput): - return draft_input - return FrozenKVMTPDraftInput( +def _to_frozen_kv_mtp_draft_extend_input( + extend_input: EagleDraftExtendInput, +) -> FrozenKVMTPDraftExtendInput: + if isinstance(extend_input, FrozenKVMTPDraftExtendInput): + return extend_input + return FrozenKVMTPDraftExtendInput( **{ - field.name: getattr(draft_input, field.name) - for field in fields(EagleDraftInput) + field.name: getattr(extend_input, field.name) + for field in fields(EagleDraftExtendInput) } ) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index 05512ddf555c..cf591f1b33f6 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -14,7 +14,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Tuple +from typing import Tuple, Union import torch @@ -23,6 +23,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.speculative.frozen_kv_mtp_info import ( FrozenKVMTPContext, + FrozenKVMTPDraftExtendInput, FrozenKVMTPDraftInput, ) from sglang.srt.speculative.spec_utils import fast_topk @@ -134,7 +135,7 @@ def select_last_extend_hidden( def 
select_last_verified_seed( - draft_input: FrozenKVMTPDraftInput, + draft_input: Union[FrozenKVMTPDraftInput, FrozenKVMTPDraftExtendInput], ) -> Tuple[torch.Tensor, torch.Tensor]: if draft_input.num_accepted_tokens is None: return draft_input.bonus_tokens, draft_input.hidden_states diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index 9039577cc976..e7bd590c2d13 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -50,6 +50,7 @@ ) from sglang.srt.speculative.frozen_kv_mtp_info import ( FrozenKVMTPContext, + FrozenKVMTPDraftExtendInput, FrozenKVMTPDraftInput, FrozenKVMTPVerifyInput, FrozenKVMTPVerifyOutput, @@ -462,7 +463,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): if ( self.server_args.enable_dp_attention - or batch.spec_info.bonus_tokens.numel() + or verify_output.unfinished_accept_tokens.numel() > 0 ): self.forward_draft_extend_after_decode(batch, verify_output) set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True) @@ -508,9 +509,13 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> None: - assert isinstance(batch.spec_info, FrozenKVMTPDraftInput) + assert isinstance(batch.spec_info, FrozenKVMTPDraftExtendInput) + extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle and batch.spec_info.bonus_tokens.numel() == 0: + + if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: + # All reqs finished. Install an idle FrozenKVMTPDraftInput so the + # next-iter draft sees a valid spec_info. 
batch = batch.copy() batch.prepare_for_idle() batch.spec_info = FrozenKVMTPDraftInput.create_idle_input( @@ -520,11 +525,11 @@ def forward_draft_extend_after_decode( topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) + return if batch.forward_mode.is_idle(): return - draft_input = batch.spec_info seq_lens_backup = batch.seq_lens.clone() seq_lens_cpu_backup = batch.seq_lens_cpu.clone() req_pool_indices_backup = batch.req_pool_indices @@ -536,13 +541,14 @@ def forward_draft_extend_after_decode( batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend - last_token_ids, last_hidden = self._select_last_verified_seed(draft_input) + last_token_ids, last_hidden = self._select_last_verified_seed(extend_input) + # `_run_assistant_seed_step` constructs a fresh `FrozenKVMTPDraftInput` + # and installs it on `batch.spec_info` for next iter. self._run_assistant_seed_step( batch, last_token_ids, last_hidden, seq_lens_cpu=verify_output.seq_lens_for_draft_extend_cpu, - draft_input=draft_input, ) finally: batch.seq_lens = seq_lens_backup @@ -768,7 +774,7 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.next_draft_input + batch.spec_info = res.extend_input del seq_lens_pre_verify return logits_output, res, model_worker_batch, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py index 295c1009bd7d..18050f9e0c2c 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py @@ -41,7 +41,7 @@ ForwardMode, ) from sglang.srt.model_executor.input_buffers import ForwardInputBuffers -from 
sglang.srt.speculative.eagle_info import EagleDraftInput +from sglang.srt.speculative.eagle_info import EagleDraftExtendInput from sglang.srt.speculative.multi_layer_eagle_utils import assign_new_state_triton from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( @@ -349,7 +349,7 @@ def get_forward_batch(self, bs: int) -> ForwardBatch: else: global_dp_buffer_len = None - spec_info = EagleDraftInput( + spec_info = EagleDraftExtendInput( hidden_states=hidden_states, num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_tokens, diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index b03492905f87..4030ce1d1e80 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -34,6 +34,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.draft_utils import DraftBackendFactory from sglang.srt.speculative.eagle_info import ( + EagleDraftExtendInput, EagleDraftInput, EagleVerifyInput, EagleVerifyOutput, @@ -591,7 +592,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.next_draft_input + batch.spec_info = res.extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -657,13 +658,13 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - assert isinstance(batch.spec_info, EagleDraftInput) + assert isinstance(batch.spec_info, EagleDraftExtendInput) + extend_input: EagleDraftExtendInput = batch.spec_info + # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() seq_lens_cpu_backup = batch.seq_lens_cpu.clone() req_pool_indices_backup = batch.req_pool_indices - 
num_accepted_drafts_backup = batch.spec_info.num_accepted_drafts - num_accepted_tokens_backup = batch.spec_info.num_accepted_tokens return_logprob_backup = batch.return_logprob input_is_idle = batch.forward_mode.is_idle() @@ -676,17 +677,18 @@ def forward_draft_extend_after_decode( if self.speculative_algorithm.is_eagle3() else self.model_config.hidden_size ) - batch.spec_info = EagleDraftInput.create_idle_input( + extend_input = EagleDraftExtendInput.create_idle_input( device=self.device, hidden_size=hidden_size, dtype=self.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) + batch.spec_info = extend_input - batch.spec_info.num_tokens_per_req = self.speculative_num_steps + 1 - batch.spec_info.num_tokens_for_logprob_per_req = 1 - batch.spec_info.prepare_extend_after_decode( + # Phase 1: prepare extend (kernel writes extend_input.{positions, bonus_tokens}) + extend_input.num_tokens_per_req = self.speculative_num_steps + 1 + extend_input.num_tokens_for_logprob_per_req = 1 + extend_input.prepare_extend_after_decode( batch, verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, @@ -750,17 +752,22 @@ def forward_draft_extend_after_decode( ) pt += extend_len - forward_batch.spec_info.topk_p = torch.cat(topk_p_list, dim=1) - forward_batch.spec_info.topk_index = torch.cat(topk_index_list, dim=1) + # Phase 3: assemble next-iter EagleDraftInput from extend output + next_draft_input = EagleDraftInput( + bonus_tokens=extend_input.bonus_tokens, + hidden_states=logits_output.hidden_states, + topk_p=torch.cat(topk_p_list, dim=1), + topk_index=torch.cat(topk_index_list, dim=1), + capture_hidden_mode=CaptureHiddenMode.FULL, + ) - # Restore backup. - # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + # Restore batch fields and install the new draft input. + # `seq_lens` etc. were modified by `prepare_extend_after_decode`. 
batch.forward_mode = ( ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE ) batch.seq_lens = seq_lens_backup batch.seq_lens_cpu = seq_lens_cpu_backup batch.req_pool_indices = req_pool_indices_backup - batch.spec_info.num_accepted_drafts = num_accepted_drafts_backup - batch.spec_info.num_accepted_tokens = num_accepted_tokens_backup + batch.spec_info = next_draft_input batch.return_logprob = return_logprob_backup diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 8a2588c0c833..8a0565e32f96 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -195,8 +195,10 @@ class SpecInputType(IntEnum): # NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends. # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() + EAGLE_DRAFT_EXTEND = auto() EAGLE_VERIFY = auto() FROZEN_KV_MTP_DRAFT = auto() + FROZEN_KV_MTP_DRAFT_EXTEND = auto() FROZEN_KV_MTP_VERIFY = auto() DFLASH_DRAFT = auto() DFLASH_VERIFY = auto() @@ -212,7 +214,9 @@ def is_draft_input(self) -> bool: # or use another variable name like `draft_input` to substitute `spec_info` return self.spec_input_type in { SpecInputType.EAGLE_DRAFT, + SpecInputType.EAGLE_DRAFT_EXTEND, SpecInputType.FROZEN_KV_MTP_DRAFT, + SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND, SpecInputType.DFLASH_DRAFT, } From 057b8f79c2dac181a33682092b67ae6befa47c19 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 8 May 2026 23:45:11 -0700 Subject: [PATCH 02/22] rename extend_input -> draft_extend_input --- python/sglang/srt/speculative/eagle_info.py | 20 +++++++++---------- python/sglang/srt/speculative/eagle_worker.py | 18 ++++++++--------- .../srt/speculative/frozen_kv_mtp_info.py | 12 ++++++----- .../srt/speculative/frozen_kv_mtp_worker.py | 8 +++++--- .../speculative/multi_layer_eagle_worker.py | 18 
++++++++--------- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 8a785b385ed9..a7025b1992ab 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -240,14 +240,14 @@ def verify( accepted token logits. """ if batch.forward_mode.is_idle(): - extend_input = EagleDraftExtendInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, capture_hidden_mode=CaptureHiddenMode.LAST, ) return EagleVerifyOutput.create_idle( - extend_input=extend_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, device=batch.device, spec_steps=self.spec_steps, @@ -544,7 +544,7 @@ def verify( batch.seq_lens.add_(num_accepted_drafts + 1) batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) - extend_input = EagleDraftExtendInput( + draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[accept_index], num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_drafts + 1, @@ -552,7 +552,7 @@ def verify( ) return EagleVerifyOutput( - extend_input=extend_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, unfinished_accept_tokens=accept_tokens, @@ -619,7 +619,7 @@ def verify( req_pool_indices_for_draft_extend = batch.req_pool_indices[ unfinished_index_device ] - extend_input = EagleDraftExtendInput( + draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[ unfinished_accept_index ], @@ -642,7 +642,7 @@ def verify( dtype=batch.req_pool_indices.dtype, device=batch.req_pool_indices.device, ) - extend_input = EagleDraftExtendInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, 
hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, @@ -650,7 +650,7 @@ def verify( ) return EagleVerifyOutput( - extend_input=extend_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, unfinished_accept_tokens=unfinished_accept_tokens, @@ -920,7 +920,7 @@ def prepare_extend_after_decode( class EagleVerifyOutput: # Next iter's draft-extend input, ready to be installed as `batch.spec_info` # for the draft-extend forward. - extend_input: EagleDraftExtendInput + draft_extend_input: EagleDraftExtendInput # Logit outputs from target worker. logits_output: LogitsProcessorOutput # All accepted tokens flat across all reqs incl. those that finished this @@ -949,13 +949,13 @@ class EagleVerifyOutput: def create_idle( cls, *, - extend_input: EagleDraftExtendInput, + draft_extend_input: EagleDraftExtendInput, logits_output: LogitsProcessorOutput, device: torch.device, spec_steps: int, ) -> "EagleVerifyOutput": return cls( - extend_input=extend_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=torch.empty(0, dtype=torch.long, device=device), unfinished_accept_tokens=torch.empty(0, dtype=torch.long, device=device), diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index cd30f049a483..7db2e613197f 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -980,7 +980,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.extend_input + batch.spec_info = res.draft_extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -1111,7 +1111,7 @@ def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): assert isinstance(batch.spec_info, 
EagleDraftExtendInput) - extend_input: EagleDraftExtendInput = batch.spec_info + draft_extend_input: EagleDraftExtendInput = batch.spec_info # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() @@ -1131,18 +1131,18 @@ def forward_draft_extend_after_decode( and self.eagle_use_aux_hidden_state else self.model_config.spec_hidden_size ) - extend_input = EagleDraftExtendInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=self.device, hidden_size=hidden_size, dtype=self.model_config.dtype, capture_hidden_mode=CaptureHiddenMode.LAST, ) - batch.spec_info = extend_input + batch.spec_info = draft_extend_input - # Phase 1: prepare extend (kernel writes extend_input.{positions, bonus_tokens}) - extend_input.num_tokens_per_req = self.speculative_num_steps + 1 - extend_input.num_tokens_for_logprob_per_req = 1 - extend_input.prepare_extend_after_decode( + # Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens}) + draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1 + draft_extend_input.num_tokens_for_logprob_per_req = 1 + draft_extend_input.prepare_extend_after_decode( batch, verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, @@ -1203,7 +1203,7 @@ def forward_draft_extend_after_decode( # Phase 3: assemble next-iter EagleDraftInput from extend output next_draft_input = EagleDraftInput( - bonus_tokens=extend_input.bonus_tokens, + bonus_tokens=draft_extend_input.bonus_tokens, hidden_states=hidden_states, topk_p=topk_p, topk_index=topk_index, diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_info.py b/python/sglang/srt/speculative/frozen_kv_mtp_info.py index 78fd0bbb5e9c..7b562b52094f 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_info.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_info.py @@ -71,7 +71,9 @@ def __post_init__(self): def verify(self, *args, **kwargs) -> EagleVerifyOutput: output = 
super().verify(*args, **kwargs) - output.extend_input = _to_frozen_kv_mtp_draft_extend_input(output.extend_input) + output.draft_extend_input = _to_frozen_kv_mtp_draft_extend_input( + output.draft_extend_input + ) return output @@ -79,13 +81,13 @@ def verify(self, *args, **kwargs) -> EagleVerifyOutput: def _to_frozen_kv_mtp_draft_extend_input( - extend_input: EagleDraftExtendInput, + draft_extend_input: EagleDraftExtendInput, ) -> FrozenKVMTPDraftExtendInput: - if isinstance(extend_input, FrozenKVMTPDraftExtendInput): - return extend_input + if isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput): + return draft_extend_input return FrozenKVMTPDraftExtendInput( **{ - field.name: getattr(extend_input, field.name) + field.name: getattr(draft_extend_input, field.name) for field in fields(EagleDraftExtendInput) } ) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index e7bd590c2d13..c7df723780d1 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -510,7 +510,7 @@ def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> None: assert isinstance(batch.spec_info, FrozenKVMTPDraftExtendInput) - extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info + draft_extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: @@ -541,7 +541,9 @@ def forward_draft_extend_after_decode( batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend - last_token_ids, last_hidden = self._select_last_verified_seed(extend_input) + last_token_ids, last_hidden = self._select_last_verified_seed( + draft_extend_input + ) # `_run_assistant_seed_step` constructs a fresh `FrozenKVMTPDraftInput` # 
and installs it on `batch.spec_info` for next iter. self._run_assistant_seed_step( @@ -774,7 +776,7 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.extend_input + batch.spec_info = res.draft_extend_input del seq_lens_pre_verify return logits_output, res, model_worker_batch, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 4030ce1d1e80..55bb94f6f9e9 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -592,7 +592,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.extend_input + batch.spec_info = res.draft_extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -659,7 +659,7 @@ def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): assert isinstance(batch.spec_info, EagleDraftExtendInput) - extend_input: EagleDraftExtendInput = batch.spec_info + draft_extend_input: EagleDraftExtendInput = batch.spec_info # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() @@ -677,18 +677,18 @@ def forward_draft_extend_after_decode( if self.speculative_algorithm.is_eagle3() else self.model_config.hidden_size ) - extend_input = EagleDraftExtendInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=self.device, hidden_size=hidden_size, dtype=self.model_config.dtype, capture_hidden_mode=CaptureHiddenMode.LAST, ) - batch.spec_info = extend_input + batch.spec_info = draft_extend_input - # Phase 1: prepare extend (kernel writes extend_input.{positions, 
bonus_tokens}) - extend_input.num_tokens_per_req = self.speculative_num_steps + 1 - extend_input.num_tokens_for_logprob_per_req = 1 - extend_input.prepare_extend_after_decode( + # Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens}) + draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1 + draft_extend_input.num_tokens_for_logprob_per_req = 1 + draft_extend_input.prepare_extend_after_decode( batch, verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, @@ -754,7 +754,7 @@ def forward_draft_extend_after_decode( # Phase 3: assemble next-iter EagleDraftInput from extend output next_draft_input = EagleDraftInput( - bonus_tokens=extend_input.bonus_tokens, + bonus_tokens=draft_extend_input.bonus_tokens, hidden_states=logits_output.hidden_states, topk_p=torch.cat(topk_p_list, dim=1), topk_index=torch.cat(topk_index_list, dim=1), From f5b85fd898205b663e92db1d1f7f6cbd246db75f Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 00:02:13 -0700 Subject: [PATCH 03/22] add isinstance asserts at draft_extend_input phase boundaries --- python/sglang/srt/speculative/eagle_info.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index a7025b1992ab..e62ec09059c5 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -884,6 +884,8 @@ def prepare_extend_after_decode( verify_output: "EagleVerifyOutput", speculative_num_steps: int, ): + # Caller must have installed `self` as `batch.spec_info` before calling. 
+ assert batch.spec_info is self if batch.forward_mode.is_idle(): return From 063468033abaa75f1aee7dc31b495a28ea516383 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Fri, 8 May 2026 23:55:38 -0700 Subject: [PATCH 04/22] move draft_extend_input install out of verify() into forward_draft_extend_after_decode --- python/sglang/srt/speculative/eagle_worker.py | 7 ++++--- python/sglang/srt/speculative/frozen_kv_mtp_worker.py | 9 ++++++--- .../sglang/srt/speculative/multi_layer_eagle_worker.py | 7 ++++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 7db2e613197f..311d664db319 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -980,7 +980,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.draft_extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -1110,8 +1109,10 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - assert isinstance(batch.spec_info, EagleDraftExtendInput) - draft_extend_input: EagleDraftExtendInput = batch.spec_info + # Install the draft-extend input as `batch.spec_info` for this method's + # forward pass. Replaced with a fresh `EagleDraftInput` post-extend. 
+ draft_extend_input: EagleDraftExtendInput = verify_output.draft_extend_input + batch.spec_info = draft_extend_input # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index c7df723780d1..fe1ecbf48d4e 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -509,8 +509,12 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> None: - assert isinstance(batch.spec_info, FrozenKVMTPDraftExtendInput) - draft_extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info + # Install the draft-extend input as `batch.spec_info` for the seed step. + # `_run_assistant_seed_step` will replace it with a fresh + # `FrozenKVMTPDraftInput` for the next iter's draft. + draft_extend_input = verify_output.draft_extend_input + assert isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput) + batch.spec_info = draft_extend_input input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: @@ -776,7 +780,6 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): batch.forward_mode = ( ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.draft_extend_input del seq_lens_pre_verify return logits_output, res, model_worker_batch, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 55bb94f6f9e9..4cecab6d16c0 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -592,7 +592,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): batch.forward_mode = ( 
ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = res.draft_extend_input return logits_output, res, model_worker_batch, can_run_cuda_graph @@ -658,8 +657,10 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - assert isinstance(batch.spec_info, EagleDraftExtendInput) - draft_extend_input: EagleDraftExtendInput = batch.spec_info + # Install the draft-extend input as `batch.spec_info` for this method's + # forward pass. Replaced with a fresh `EagleDraftInput` post-extend. + draft_extend_input: EagleDraftExtendInput = verify_output.draft_extend_input + batch.spec_info = draft_extend_input # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() From 1fc71dd00a84e2115c2da56920551ff23a7fb117 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 00:22:37 -0700 Subject: [PATCH 05/22] move spec_info phase install to executor (forward_batch_generation) --- python/sglang/srt/speculative/eagle_worker.py | 23 +++++++++++++------ .../srt/speculative/frozen_kv_mtp_worker.py | 19 +++++++++++---- .../speculative/multi_layer_eagle_worker.py | 23 +++++++++++++------ 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 311d664db319..45ba94afab86 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -481,13 +481,15 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul with self.draft_tp_context( self.draft_model_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): - spec_info = self.draft(batch) + verify_input = self.draft(batch) set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) + # 
Install verify_input as `batch.spec_info` for the verify forward. + batch.spec_info = verify_input logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, spec_info) + self.verify(batch, verify_input) ) if get_global_tracing_enabled(): @@ -509,6 +511,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul or verify_output.unfinished_accept_tokens.shape[0] > 0 ): # decode is not finished + # Install draft_extend_input as `batch.spec_info` for the + # draft-extend forward (replaced post-extend with a fresh + # EagleDraftInput by `forward_draft_extend_after_decode`). + batch.spec_info = verify_output.draft_extend_input self.forward_draft_extend_after_decode(batch, verify_output) set_time_batch( @@ -894,6 +900,9 @@ def clear_cache_pool(self): pass def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): + # Caller (forward_batch_generation) is responsible for installing + # `spec_info` as `batch.spec_info` before calling. + assert batch.spec_info is spec_info seq_lens_pre_verify = batch.seq_lens.clone() spec_info.prepare_for_verify(batch, self.page_size) spec_info.num_tokens_per_req = self.speculative_num_steps + 1 @@ -903,7 +912,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu @@ -1109,10 +1117,11 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - # Install the draft-extend input as `batch.spec_info` for this method's - # forward pass. Replaced with a fresh `EagleDraftInput` post-extend. 
- draft_extend_input: EagleDraftExtendInput = verify_output.draft_extend_input - batch.spec_info = draft_extend_input + # Caller (forward_batch_generation) is responsible for installing + # verify_output.draft_extend_input as batch.spec_info before calling. + draft_extend_input = verify_output.draft_extend_input + assert isinstance(draft_extend_input, EagleDraftExtendInput) + assert batch.spec_info is draft_extend_input # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index fe1ecbf48d4e..ce443ed4efee 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -444,12 +444,14 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul with self.draft_tp_context( self.draft_model_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): - spec_info = self.draft(batch) + verify_input = self.draft(batch) set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) + # Install verify_input as `batch.spec_info` for the verify forward. + batch.spec_info = verify_input logits_output, verify_output, _, can_run_cuda_graph = self.verify( - batch, spec_info + batch, verify_input ) if get_global_tracing_enabled(): @@ -465,6 +467,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.server_args.enable_dp_attention or verify_output.unfinished_accept_tokens.numel() > 0 ): + # Install draft_extend_input as `batch.spec_info` for the seed + # step (`_run_assistant_seed_step` replaces it with a fresh + # `FrozenKVMTPDraftInput` for next iter). 
+ batch.spec_info = verify_output.draft_extend_input self.forward_draft_extend_after_decode(batch, verify_output) set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True) @@ -509,12 +515,13 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> None: - # Install the draft-extend input as `batch.spec_info` for the seed step. + # Caller (forward_batch_generation) is responsible for installing + # verify_output.draft_extend_input as batch.spec_info before calling. # `_run_assistant_seed_step` will replace it with a fresh # `FrozenKVMTPDraftInput` for the next iter's draft. draft_extend_input = verify_output.draft_extend_input assert isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput) - batch.spec_info = draft_extend_input + assert batch.spec_info is draft_extend_input input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: @@ -702,6 +709,9 @@ def draft_forward( ) def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): + # Caller (forward_batch_generation) is responsible for installing + # `spec_info` as `batch.spec_info` before calling. 
+ assert batch.spec_info is spec_info seq_lens_pre_verify = batch.seq_lens.clone() spec_info.prepare_for_verify(batch, self.page_size) spec_info.num_tokens_per_req = self.speculative_num_steps + 1 @@ -711,7 +721,6 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 4cecab6d16c0..b70812a76a49 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -272,9 +272,11 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul with self.draft_tp_context( self.mtp_model_runner(0).tp_group ), speculative_moe_backend_context(): - spec_info = self.draft(batch) + verify_input = self.draft(batch) + # Install verify_input as `batch.spec_info` for the verify forward. + batch.spec_info = verify_input logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, spec_info) + self.verify(batch, verify_input) ) with self.draft_tp_context( @@ -287,6 +289,10 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul or verify_output.unfinished_accept_tokens.shape[0] > 0 ): # decode is not finished + # Install draft_extend_input as `batch.spec_info` for the + # draft-extend forward (replaced post-extend with a fresh + # EagleDraftInput by `forward_draft_extend_after_decode`). 
+ batch.spec_info = verify_output.draft_extend_input self.forward_draft_extend_after_decode(batch, verify_output) return GenerationBatchResult( @@ -475,6 +481,9 @@ def clear_cache_pool(self): pass def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): + # Caller (forward_batch_generation) is responsible for installing + # `spec_info` as `batch.spec_info` before calling. + assert batch.spec_info is spec_info spec_info.prepare_for_verify(batch, self.page_size) batch.return_hidden_states = False batch.forward_mode = ( @@ -482,7 +491,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - batch.spec_info = spec_info model_worker_batch = batch.get_model_worker_batch( seq_lens_cpu_cache=spec_info.seq_lens_cpu @@ -657,10 +665,11 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - # Install the draft-extend input as `batch.spec_info` for this method's - # forward pass. Replaced with a fresh `EagleDraftInput` post-extend. - draft_extend_input: EagleDraftExtendInput = verify_output.draft_extend_input - batch.spec_info = draft_extend_input + # Caller (forward_batch_generation) is responsible for installing + # verify_output.draft_extend_input as batch.spec_info before calling. 
+ draft_extend_input = verify_output.draft_extend_input + assert isinstance(draft_extend_input, EagleDraftExtendInput) + assert batch.spec_info is draft_extend_input # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() From 88fffbec107dfa2be749f51af982516fd6187f05 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 00:27:29 -0700 Subject: [PATCH 06/22] V1: forward_draft_extend_after_decode returns next_draft_input; executor installs --- python/sglang/srt/speculative/eagle_worker.py | 19 +++++++++++-------- .../speculative/multi_layer_eagle_worker.py | 19 +++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 45ba94afab86..582eb84494d5 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -511,11 +511,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul or verify_output.unfinished_accept_tokens.shape[0] > 0 ): # decode is not finished - # Install draft_extend_input as `batch.spec_info` for the - # draft-extend forward (replaced post-extend with a fresh - # EagleDraftInput by `forward_draft_extend_after_decode`). + # Install draft_extend_input for the extend forward, then + # install the assembled next-iter EagleDraftInput it returns. 
batch.spec_info = verify_output.draft_extend_input - self.forward_draft_extend_after_decode(batch, verify_output) + next_draft_input = self.forward_draft_extend_after_decode( + batch, verify_output + ) + batch.spec_info = next_draft_input set_time_batch( batch.reqs, "set_spec_draft_extend_end_time", trace_only=True @@ -1116,7 +1118,7 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + ) -> EagleDraftInput: # Caller (forward_batch_generation) is responsible for installing # verify_output.draft_extend_input as batch.spec_info before calling. draft_extend_input = verify_output.draft_extend_input @@ -1220,16 +1222,17 @@ def forward_draft_extend_after_decode( capture_hidden_mode=CaptureHiddenMode.FULL, ) - # Restore batch fields and install the new draft input. - # `seq_lens` etc. were modified by `prepare_extend_after_decode`. + # Restore batch fields. `seq_lens` etc. were modified by + # `prepare_extend_after_decode`. Caller installs `next_draft_input` as + # `batch.spec_info`. 
batch.forward_mode = ( ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE ) batch.seq_lens = seq_lens_backup batch.seq_lens_cpu = seq_lens_cpu_backup batch.req_pool_indices = req_pool_indices_backup - batch.spec_info = next_draft_input batch.return_logprob = return_logprob_backup + return next_draft_input def capture_for_decode( self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index b70812a76a49..f15c74fba9cf 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -289,11 +289,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul or verify_output.unfinished_accept_tokens.shape[0] > 0 ): # decode is not finished - # Install draft_extend_input as `batch.spec_info` for the - # draft-extend forward (replaced post-extend with a fresh - # EagleDraftInput by `forward_draft_extend_after_decode`). + # Install draft_extend_input for the extend forward, then + # install the assembled next-iter EagleDraftInput it returns. batch.spec_info = verify_output.draft_extend_input - self.forward_draft_extend_after_decode(batch, verify_output) + next_draft_input = self.forward_draft_extend_after_decode( + batch, verify_output + ) + batch.spec_info = next_draft_input return GenerationBatchResult( logits_output=logits_output, @@ -664,7 +666,7 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + ) -> EagleDraftInput: # Caller (forward_batch_generation) is responsible for installing # verify_output.draft_extend_input as batch.spec_info before calling. 
draft_extend_input = verify_output.draft_extend_input @@ -771,13 +773,14 @@ def forward_draft_extend_after_decode( capture_hidden_mode=CaptureHiddenMode.FULL, ) - # Restore batch fields and install the new draft input. - # `seq_lens` etc. were modified by `prepare_extend_after_decode`. + # Restore batch fields. `seq_lens` etc. were modified by + # `prepare_extend_after_decode`. Caller installs `next_draft_input` as + # `batch.spec_info`. batch.forward_mode = ( ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE ) batch.seq_lens = seq_lens_backup batch.seq_lens_cpu = seq_lens_cpu_backup batch.req_pool_indices = req_pool_indices_backup - batch.spec_info = next_draft_input batch.return_logprob = return_logprob_backup + return next_draft_input From 958365e52bf8b918fff6c8ab8736269e5f539a01 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 01:52:35 -0700 Subject: [PATCH 07/22] drop unused model_worker_batch from verify() return --- python/sglang/srt/speculative/eagle_worker.py | 6 +++--- python/sglang/srt/speculative/frozen_kv_mtp_worker.py | 4 ++-- python/sglang/srt/speculative/multi_layer_eagle_worker.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 582eb84494d5..ce5f3062fdf9 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -488,8 +488,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # Install verify_input as `batch.spec_info` for the verify forward. 
batch.spec_info = verify_input - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, verify_input) + logits_output, verify_output, can_run_cuda_graph = self.verify( + batch, verify_input ) if get_global_tracing_enabled(): @@ -991,7 +991,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph def _mamba_verify_update( self, diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index ce443ed4efee..c77dc334c9b6 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -450,7 +450,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input - logits_output, verify_output, _, can_run_cuda_graph = self.verify( + logits_output, verify_output, can_run_cuda_graph = self.verify( batch, verify_input ) @@ -791,4 +791,4 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): ) del seq_lens_pre_verify - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index f15c74fba9cf..8baf0cea8a3e 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -275,8 +275,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul verify_input = self.draft(batch) # Install verify_input as `batch.spec_info` for the verify forward. 
batch.spec_info = verify_input - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, verify_input) + logits_output, verify_output, can_run_cuda_graph = self.verify( + batch, verify_input ) with self.draft_tp_context( @@ -603,7 +603,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE ) - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph def forward_draft_extend( self, From bfabaa76386cf6935db500f7eefadb93f0415641 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 01:54:57 -0700 Subject: [PATCH 08/22] drop redundant spec_info.positions = None --- .../srt/speculative/eagle_draft_extend_cuda_graph_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index a2d620d5c8cb..2da921adf695 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -365,7 +365,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0 num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_tokens, ) - spec_info.positions = None self.deepep_adapter.capture(is_extend_in_batch=True) From 74c9ca96627ce4fb9b0ef81793960101c2c43dc4 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 02:01:33 -0700 Subject: [PATCH 09/22] drop stale draft_extend shape note on EagleDraftInput.hidden_states --- python/sglang/srt/speculative/eagle_info.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index e62ec09059c5..a589e971c100 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ 
b/python/sglang/srt/speculative/eagle_info.py @@ -668,10 +668,7 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): # shape: (b, topk) topk_p: torch.Tensor = None topk_index: torch.Tensor = None - # shape: (b, hidden_size) when consumed by `draft` forward (one hidden per req); - # shape: (total_accepted, hidden_size) when consumed by `draft_extend` forward - # (one hidden per accepted token). Workers maintain this invariant locally; - # there is no type-level guard. Don't add new readers without checking phase. + # shape: (b, hidden_size) - one hidden per req, consumed by `draft` forward. hidden_states: torch.Tensor = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL From 5f80483cfebb9f299d10e5c1174d16786a482e41 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 02:09:25 -0700 Subject: [PATCH 10/22] move generate_attn_arg_prefill to EagleDraftExtendInput; tighten select_last_verified_seed --- python/sglang/srt/speculative/eagle_info.py | 64 +++++++++---------- .../srt/speculative/frozen_kv_mtp_utils.py | 7 +- .../srt/speculative/frozen_kv_mtp_worker.py | 2 +- 3 files changed, 35 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index a589e971c100..ca2e662a0a45 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -731,38 +731,6 @@ def create_idle_input( new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32), ) - def generate_attn_arg_prefill( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - paged_kernel_lens_sum: int, - req_to_token: torch.Tensor, - ): - device = req_pool_indices.device - bs = self.num_accepted_drafts.numel() - qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) - qo_indptr[1:] = torch.cumsum(self.num_accepted_tokens, dim=0) - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) - 
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - if paged_kernel_lens_sum is None: - paged_kernel_lens_sum = cum_kv_seq_len[-1] - - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device=device - ) - - create_flashinfer_kv_indices_triton[(bs,)]( - req_to_token, - req_pool_indices, - paged_kernel_lens, - cum_kv_seq_len, - None, - kv_indices, - req_to_token.size(1), - ) - return kv_indices, cum_kv_seq_len, qo_indptr, None - def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): if self.future_indices is not None: self.future_indices.indices = self.future_indices.indices[new_indices] @@ -914,6 +882,38 @@ def prepare_extend_after_decode( next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), ) + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: Optional[int], + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = self.num_accepted_drafts.numel() + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + qo_indptr[1:] = torch.cumsum(self.num_accepted_tokens, dim=0) + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + if paged_kernel_lens_sum is None: + paged_kernel_lens_sum = cum_kv_seq_len[-1] + + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device=device + ) + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, qo_indptr, None + @dataclass class EagleVerifyOutput: diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index cf591f1b33f6..74ff0ef7ee70 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ 
b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -14,7 +14,7 @@ from __future__ import annotations from contextlib import contextmanager -from typing import Tuple, Union +from typing import Tuple import torch @@ -135,11 +135,8 @@ def select_last_extend_hidden( def select_last_verified_seed( - draft_input: Union[FrozenKVMTPDraftInput, FrozenKVMTPDraftExtendInput], + draft_input: FrozenKVMTPDraftExtendInput, ) -> Tuple[torch.Tensor, torch.Tensor]: - if draft_input.num_accepted_tokens is None: - return draft_input.bonus_tokens, draft_input.hidden_states - counts = draft_input.num_accepted_tokens.to(torch.long) last_indices = torch.cumsum(counts, dim=0) - 1 return ( diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index c77dc334c9b6..81143f960515 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -336,7 +336,7 @@ def _select_last_extend_hidden( return select_last_extend_hidden(batch, hidden_states) def _select_last_verified_seed( - self, draft_input: FrozenKVMTPDraftInput + self, draft_input: FrozenKVMTPDraftExtendInput ) -> Tuple[torch.Tensor, torch.Tensor]: return select_last_verified_seed(draft_input) From 6ad0e497777f7154cfa227b53a325547c0094c11 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 02:10:24 -0700 Subject: [PATCH 11/22] drop redundant spec_info args; read spec_info from batch --- python/sglang/srt/speculative/eagle_worker.py | 16 ++++------------ .../srt/speculative/frozen_kv_mtp_worker.py | 18 +++++------------- .../speculative/multi_layer_eagle_worker.py | 16 ++++------------ 3 files changed, 13 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index ce5f3062fdf9..28d43230cf3e 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ 
-488,9 +488,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input - logits_output, verify_output, can_run_cuda_graph = self.verify( - batch, verify_input - ) + logits_output, verify_output, can_run_cuda_graph = self.verify(batch) if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): @@ -901,10 +899,8 @@ def clear_cache_pool(self): # allocator and kv cache pool are shared with target worker pass - def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): - # Caller (forward_batch_generation) is responsible for installing - # `spec_info` as `batch.spec_info` before calling. - assert batch.spec_info is spec_info + def verify(self, batch: ScheduleBatch): + spec_info: EagleVerifyInput = batch.spec_info seq_lens_pre_verify = batch.seq_lens.clone() spec_info.prepare_for_verify(batch, self.page_size) spec_info.num_tokens_per_req = self.speculative_num_steps + 1 @@ -1119,11 +1115,7 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> EagleDraftInput: - # Caller (forward_batch_generation) is responsible for installing - # verify_output.draft_extend_input as batch.spec_info before calling. 
- draft_extend_input = verify_output.draft_extend_input - assert isinstance(draft_extend_input, EagleDraftExtendInput) - assert batch.spec_info is draft_extend_input + draft_extend_input: EagleDraftExtendInput = batch.spec_info # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index 81143f960515..0251bf0c8fc7 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -450,9 +450,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input - logits_output, verify_output, can_run_cuda_graph = self.verify( - batch, verify_input - ) + logits_output, verify_output, can_run_cuda_graph = self.verify(batch) if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): @@ -515,13 +513,9 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> None: - # Caller (forward_batch_generation) is responsible for installing - # verify_output.draft_extend_input as batch.spec_info before calling. - # `_run_assistant_seed_step` will replace it with a fresh + # `_run_assistant_seed_step` will replace `batch.spec_info` with a fresh # `FrozenKVMTPDraftInput` for the next iter's draft. 
- draft_extend_input = verify_output.draft_extend_input - assert isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput) - assert batch.spec_info is draft_extend_input + draft_extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: @@ -708,10 +702,8 @@ def draft_forward( score_list, token_list, parents_list, self.speculative_num_draft_tokens ) - def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): - # Caller (forward_batch_generation) is responsible for installing - # `spec_info` as `batch.spec_info` before calling. - assert batch.spec_info is spec_info + def verify(self, batch: ScheduleBatch): + spec_info: FrozenKVMTPVerifyInput = batch.spec_info seq_lens_pre_verify = batch.seq_lens.clone() spec_info.prepare_for_verify(batch, self.page_size) spec_info.num_tokens_per_req = self.speculative_num_steps + 1 diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 8baf0cea8a3e..56953b1566b7 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -275,9 +275,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul verify_input = self.draft(batch) # Install verify_input as `batch.spec_info` for the verify forward. 
batch.spec_info = verify_input - logits_output, verify_output, can_run_cuda_graph = self.verify( - batch, verify_input - ) + logits_output, verify_output, can_run_cuda_graph = self.verify(batch) with self.draft_tp_context( self.mtp_model_runner(0).tp_group @@ -482,10 +480,8 @@ def clear_cache_pool(self): # allocator and kv cache pool are shared with target worker pass - def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): - # Caller (forward_batch_generation) is responsible for installing - # `spec_info` as `batch.spec_info` before calling. - assert batch.spec_info is spec_info + def verify(self, batch: ScheduleBatch): + spec_info: EagleVerifyInput = batch.spec_info spec_info.prepare_for_verify(batch, self.page_size) batch.return_hidden_states = False batch.forward_mode = ( @@ -667,11 +663,7 @@ def forward_draft_extend( def forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ) -> EagleDraftInput: - # Caller (forward_batch_generation) is responsible for installing - # verify_output.draft_extend_input as batch.spec_info before calling. 
- draft_extend_input = verify_output.draft_extend_input - assert isinstance(draft_extend_input, EagleDraftExtendInput) - assert batch.spec_info is draft_extend_input + draft_extend_input: EagleDraftExtendInput = batch.spec_info # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() From 1ff8f8f9d7191274c70dd33380ea97e79aca95d3 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 02:23:57 -0700 Subject: [PATCH 12/22] move verify->extend handoff fields onto EagleDraftExtendInput --- python/sglang/srt/speculative/eagle_info.py | 111 ++++++------------ python/sglang/srt/speculative/eagle_worker.py | 20 ++-- .../srt/speculative/frozen_kv_mtp_worker.py | 26 ++-- .../speculative/multi_layer_eagle_worker.py | 16 ++- 4 files changed, 65 insertions(+), 108 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index ca2e662a0a45..73ce3cdc2674 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -549,16 +549,16 @@ def verify( num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_drafts + 1, num_accepted_tokens_cpu=num_accepted_tokens_list, + input_ids=accept_tokens, + seq_lens=batch.seq_lens, + seq_lens_cpu=batch.seq_lens_cpu, + req_pool_indices=batch.req_pool_indices, ) return EagleVerifyOutput( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - unfinished_accept_tokens=accept_tokens, - seq_lens_for_draft_extend=batch.seq_lens, - seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu, - req_pool_indices_for_draft_extend=batch.req_pool_indices, num_accepted_drafts_per_req_cpu=num_accepted_drafts_list, accepted_indices=accept_index, ) @@ -613,12 +613,6 @@ def verify( unfinished_num_accepted_drafts = num_accepted_drafts[ unfinished_index_device ] - unfinished_accept_tokens = predict[unfinished_accept_index] - seq_lens_for_draft_extend = 
batch.seq_lens[unfinished_index_device] - seq_lens_for_draft_extend_cpu = batch.seq_lens_cpu[unfinished_index] - req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index_device - ] draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[ unfinished_accept_index @@ -626,22 +620,12 @@ def verify( num_accepted_tokens_cpu=draft_input_num_accepted_tokens_cpu, num_accepted_drafts=unfinished_num_accepted_drafts, num_accepted_tokens=unfinished_num_accepted_drafts + 1, + input_ids=predict[unfinished_accept_index], + seq_lens=batch.seq_lens[unfinished_index_device], + seq_lens_cpu=batch.seq_lens_cpu[unfinished_index], + req_pool_indices=batch.req_pool_indices[unfinished_index_device], ) else: - unfinished_accept_tokens = torch.empty( - (0,), dtype=accept_tokens.dtype, device=accept_tokens.device - ) - seq_lens_for_draft_extend = torch.empty( - (0,), dtype=batch.seq_lens.dtype, device=batch.seq_lens.device - ) - seq_lens_for_draft_extend_cpu = torch.empty( - (0,), dtype=batch.seq_lens_cpu.dtype - ) - req_pool_indices_for_draft_extend = torch.empty( - (0,), - dtype=batch.req_pool_indices.dtype, - device=batch.req_pool_indices.device, - ) draft_extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, @@ -653,10 +637,6 @@ def verify( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - unfinished_accept_tokens=unfinished_accept_tokens, - seq_lens_for_draft_extend=seq_lens_for_draft_extend, - seq_lens_for_draft_extend_cpu=seq_lens_for_draft_extend_cpu, - req_pool_indices_for_draft_extend=req_pool_indices_for_draft_extend, num_accepted_drafts_per_req_cpu=num_accepted_drafts_list, accepted_indices=accept_index, ) @@ -664,7 +644,6 @@ def verify( @dataclass class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): - # The inputs for decode # shape: (b, topk) topk_p: torch.Tensor = None topk_index: torch.Tensor = None 
@@ -672,21 +651,19 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): hidden_states: torch.Tensor = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL - # Per-req bonus token (the "+1" target prediction at the end of each accept - # chain). Set by `EagleDraftExtendInput.prepare_extend_after_decode`'s kernel - # via the worker's post-extend assembly. + # Per-req bonus token (the "+1" target prediction at end of each accept + # chain). Written by `EagleDraftExtendInput.prepare_extend_after_decode`; + # the worker copies it here for next iter's draft. bonus_tokens: torch.Tensor = None - # Inputs for the attention backends # shape: (b + 1,) kv_indptr: torch.Tensor = None kv_indices: torch.Tensor = None - # Shape info for padding num_tokens_per_req: int = -1 num_tokens_for_logprob_per_req: int = -1 - # Inputs for V2 overlap worker + # V2 overlap worker only future_indices: Optional[FutureIndices] = None new_seq_lens: Optional[torch.Tensor] = None verify_done: Optional[torch.cuda.Event] = None @@ -790,25 +767,34 @@ def merge_batch(self, spec_info: "EagleDraftInput"): class EagleDraftExtendInput(SpecInput): """Inputs to the draft-extend forward (the per-accepted-token pass after verify). - Lives on `batch.spec_info` only during the draft-extend forward pass. - After draft-extend completes, the worker assembles a fresh `EagleDraftInput` - for next iter and replaces `batch.spec_info`. - - All fields here have single shapes (no phase-shift across phases). + Produced by `EagleVerifyInput.verify`, installed on `batch.spec_info` for + the draft-extend forward, then replaced with a fresh `EagleDraftInput` for + the next iter's draft. """ # shape: (total_accepted, hidden_size). Sliced from verify-time hidden_states # by accept_index; consumed by the draft-extend forward. hidden_states: torch.Tensor = None - # Per-req accept counts. `num_accepted_tokens = num_accepted_drafts + 1` - # (one bonus per req). 
Both retained for cuda-graph buffer indexing and - # the `create_extend_after_decode_spec_info` kernel. + # Per-req accept counts. `num_accepted_tokens = num_accepted_drafts + 1`. + # Both kept for cuda-graph buffer indexing and the + # `create_extend_after_decode_spec_info` kernel. num_accepted_drafts: torch.Tensor = None num_accepted_tokens: torch.Tensor = None # CPU view, read by attention backends during the extend forward. num_accepted_tokens_cpu: List[int] = None + # Batch-state slices for the draft-extend forward. Set by verify (sliced to + # reqs continuing into next iter). `prepare_extend_after_decode` copies + # these onto `batch.{input_ids, seq_lens, seq_lens_cpu, req_pool_indices}`. + # - input_ids: accept tokens flat over surviving reqs + # - seq_lens / _cpu: per-req sequence length (post-accept) + # - req_pool_indices: per-req kv-pool slot + input_ids: torch.Tensor = None + seq_lens: torch.Tensor = None + seq_lens_cpu: torch.Tensor = None + req_pool_indices: torch.Tensor = None + # Set by `prepare_extend_after_decode`: # - positions: kernel-written, shape `[total_accepted]`. # - bonus_tokens: kernel-written, shape `[bs]`. The worker reads this @@ -816,7 +802,6 @@ class EagleDraftExtendInput(SpecInput): positions: Optional[torch.Tensor] = None bonus_tokens: Optional[torch.Tensor] = None - # Forward-pass config (set during prepare_extend / construction). 
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST num_tokens_per_req: int = -1 num_tokens_for_logprob_per_req: int = 1 @@ -840,13 +825,16 @@ def create_idle_input( num_accepted_drafts=torch.empty((0,), device=device, dtype=torch.int32), num_accepted_tokens=torch.empty((0,), device=device, dtype=torch.int32), num_accepted_tokens_cpu=[], + input_ids=torch.empty((0,), device=device, dtype=torch.long), + seq_lens=torch.empty((0,), device=device, dtype=torch.int32), + seq_lens_cpu=torch.empty((0,), dtype=torch.int32), + req_pool_indices=torch.empty((0,), device=device, dtype=torch.int64), capture_hidden_mode=capture_hidden_mode, ) def prepare_extend_after_decode( self, batch: ScheduleBatch, - verify_output: "EagleVerifyOutput", speculative_num_steps: int, ): # Caller must have installed `self` as `batch.spec_info` before calling. @@ -854,16 +842,15 @@ def prepare_extend_after_decode( if batch.forward_mode.is_idle(): return - # All transient verify->extend handoff state is read off `verify_output`. # The kernel below populates `self.positions` and `self.bonus_tokens`; # the worker reads `self.bonus_tokens` to construct next iter's # `EagleDraftInput`. - batch.input_ids = verify_output.unfinished_accept_tokens + batch.input_ids = self.input_ids batch.extend_lens = self.num_accepted_tokens_cpu batch.extend_num_tokens = sum(batch.extend_lens) - batch.seq_lens = verify_output.seq_lens_for_draft_extend - batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu - batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend + batch.seq_lens = self.seq_lens + batch.seq_lens_cpu = self.seq_lens_cpu + batch.req_pool_indices = self.req_pool_indices batch.return_logprob = False batch.return_hidden_states = False @@ -917,28 +904,14 @@ def generate_attn_arg_prefill( @dataclass class EagleVerifyOutput: - # Next iter's draft-extend input, ready to be installed as `batch.spec_info` - # for the draft-extend forward. 
+ # Next iter's draft-extend input, installed as `batch.spec_info` for the + # draft-extend forward. draft_extend_input: EagleDraftExtendInput # Logit outputs from target worker. logits_output: LogitsProcessorOutput # All accepted tokens flat across all reqs incl. those that finished this # step. Includes the bonus token. Used for output processing. accept_tokens: torch.Tensor - # Below are transient handoff fields for the next iter's draft-extend pass. - # They are scoped to the verify -> prepare_extend_after_decode window only; - # `prepare_extend_after_decode` reads them off this object via method arg - # rather than smuggling them through `EagleDraftInput`. - # - # Subset of `accept_tokens` for reqs continuing into next iter's draft-extend - # forward (= `accept_tokens` when no req finished; flat over unfinished - # reqs only otherwise). Becomes `batch.input_ids` for that forward pass. - unfinished_accept_tokens: torch.Tensor - # `batch.seq_lens` / `batch.seq_lens_cpu` / `batch.req_pool_indices` to - # use for the next iter's draft-extend forward; sliced to surviving reqs. - seq_lens_for_draft_extend: torch.Tensor - seq_lens_for_draft_extend_cpu: torch.Tensor - req_pool_indices_for_draft_extend: torch.Tensor # Accepted token length per sequence in a batch in CPU (full set). 
num_accepted_drafts_per_req_cpu: List[int] # Accepted indices from logits_output.next_token_logits @@ -957,12 +930,6 @@ def create_idle( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=torch.empty(0, dtype=torch.long, device=device), - unfinished_accept_tokens=torch.empty(0, dtype=torch.long, device=device), - seq_lens_for_draft_extend=torch.empty(0, dtype=torch.int32, device=device), - seq_lens_for_draft_extend_cpu=torch.empty(0, dtype=torch.int32), - req_pool_indices_for_draft_extend=torch.empty( - 0, dtype=torch.int64, device=device - ), num_accepted_drafts_per_req_cpu=[], accepted_indices=torch.full( (0, spec_steps + 1), -1, dtype=torch.int32, device=device diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 28d43230cf3e..f4de48e8f1dd 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -504,17 +504,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): # NOTE: We should use `check_forward_draft_extend_after_decode` # when DP attention is enabled, but it is slow. Skip it for now. + draft_extend_input = verify_output.draft_extend_input if ( self.server_args.enable_dp_attention - or verify_output.unfinished_accept_tokens.shape[0] > 0 + or draft_extend_input.input_ids.shape[0] > 0 ): # decode is not finished # Install draft_extend_input for the extend forward, then # install the assembled next-iter EagleDraftInput it returns. 
- batch.spec_info = verify_output.draft_extend_input - next_draft_input = self.forward_draft_extend_after_decode( - batch, verify_output - ) + batch.spec_info = draft_extend_input + next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input set_time_batch( @@ -537,7 +536,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul def check_forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0 + local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward @@ -1113,7 +1112,7 @@ def forward_draft_extend( self.capture_for_decode(logits_output, forward_batch.spec_info) def forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput + self, batch: ScheduleBatch ) -> EagleDraftInput: draft_extend_input: EagleDraftExtendInput = batch.spec_info @@ -1125,7 +1124,7 @@ def forward_draft_extend_after_decode( input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: + if not input_is_idle and draft_extend_input.input_ids.numel() == 0: # All reqs finished this verify; swap to an idle ExtendInput. 
batch = batch.copy() batch.prepare_for_idle() @@ -1148,7 +1147,6 @@ def forward_draft_extend_after_decode( draft_extend_input.num_tokens_for_logprob_per_req = 1 draft_extend_input.prepare_extend_after_decode( batch, - verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, ) batch.forward_mode = ( @@ -1193,9 +1191,7 @@ def forward_draft_extend_after_decode( logits_output = self.draft_model_runner.forward( forward_batch, skip_attn_backend_init=True ).logits_output - # Non-cuda-graph path: compute topk_p / topk_index inline (used to be - # `capture_for_decode` which mutated spec_info; we instead carry the - # values to next-iter `EagleDraftInput` assembly below). + # Non-cuda-graph path: compute topk_p / topk_index inline. probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) hidden_states = logits_output.hidden_states diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index 0251bf0c8fc7..cb0aa68c66e0 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -43,7 +43,6 @@ from sglang.srt.observability.req_time_stats import set_time_batch from sglang.srt.observability.trace import get_global_tracing_enabled from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.eagle_info import EagleVerifyOutput from sglang.srt.speculative.eagle_utils import ( build_tree_kernel_efficient, organize_draft_results, @@ -461,15 +460,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul with self.draft_tp_context( self.draft_model_runner.tp_group ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + draft_extend_input = verify_output.draft_extend_input if ( self.server_args.enable_dp_attention - or verify_output.unfinished_accept_tokens.numel() > 0 + or 
draft_extend_input.input_ids.numel() > 0 ): # Install draft_extend_input as `batch.spec_info` for the seed # step (`_run_assistant_seed_step` replaces it with a fresh # `FrozenKVMTPDraftInput` for next iter). - batch.spec_info = verify_output.draft_extend_input - self.forward_draft_extend_after_decode(batch, verify_output) + batch.spec_info = draft_extend_input + self.forward_draft_extend_after_decode(batch) set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True) return GenerationBatchResult( @@ -510,15 +510,11 @@ def forward_draft_extend( mm_input_embeds=mm_input_embeds, ) - def forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ) -> None: - # `_run_assistant_seed_step` will replace `batch.spec_info` with a fresh - # `FrozenKVMTPDraftInput` for the next iter's draft. + def forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> None: draft_extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: + if not input_is_idle and draft_extend_input.input_ids.numel() == 0: # All reqs finished. Install an idle FrozenKVMTPDraftInput so the # next-iter draft sees a valid spec_info. batch = batch.copy() @@ -541,10 +537,10 @@ def forward_draft_extend_after_decode( try: # Verify may leave finished requests in ScheduleBatch; seed only - # the unfinished requests carried by `verify_output`. - batch.seq_lens = verify_output.seq_lens_for_draft_extend - batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu - batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend + # the unfinished reqs carried by `draft_extend_input`. 
+ batch.seq_lens = draft_extend_input.seq_lens + batch.seq_lens_cpu = draft_extend_input.seq_lens_cpu + batch.req_pool_indices = draft_extend_input.req_pool_indices last_token_ids, last_hidden = self._select_last_verified_seed( draft_extend_input @@ -555,7 +551,7 @@ def forward_draft_extend_after_decode( batch, last_token_ids, last_hidden, - seq_lens_cpu=verify_output.seq_lens_for_draft_extend_cpu, + seq_lens_cpu=draft_extend_input.seq_lens_cpu, ) finally: batch.seq_lens = seq_lens_backup diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 56953b1566b7..64e79be2c8d8 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -282,17 +282,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul ), speculative_moe_backend_context(): # NOTE: We should use `check_forward_draft_extend_after_decode` # when DP attention is enabled, but it is slow. Skip it for now. + draft_extend_input = verify_output.draft_extend_input if ( self.server_args.enable_dp_attention - or verify_output.unfinished_accept_tokens.shape[0] > 0 + or draft_extend_input.input_ids.shape[0] > 0 ): # decode is not finished # Install draft_extend_input for the extend forward, then # install the assembled next-iter EagleDraftInput it returns. 
- batch.spec_info = verify_output.draft_extend_input - next_draft_input = self.forward_draft_extend_after_decode( - batch, verify_output - ) + batch.spec_info = draft_extend_input + next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input return GenerationBatchResult( @@ -305,7 +304,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul def check_forward_draft_extend_after_decode( self, batch: ScheduleBatch, verify_output: EagleVerifyOutput ): - local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0 + local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward @@ -661,7 +660,7 @@ def forward_draft_extend( forward_batch.spec_info.topk_index = torch.cat(topk_index_list, dim=1) def forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput + self, batch: ScheduleBatch ) -> EagleDraftInput: draft_extend_input: EagleDraftExtendInput = batch.spec_info @@ -673,7 +672,7 @@ def forward_draft_extend_after_decode( input_is_idle = batch.forward_mode.is_idle() - if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0: + if not input_is_idle and draft_extend_input.input_ids.numel() == 0: batch = batch.copy() batch.prepare_for_idle() hidden_size = ( @@ -694,7 +693,6 @@ def forward_draft_extend_after_decode( draft_extend_input.num_tokens_for_logprob_per_req = 1 draft_extend_input.prepare_extend_after_decode( batch, - verify_output=verify_output, speculative_num_steps=self.speculative_num_steps, ) batch.forward_mode = ( From 7bfe3c02d8b2645703f63f5b7bd54a54b2eb01ac Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 05:21:16 -0700 Subject: [PATCH 13/22] forward_batch_info: getattr-guard num_accepted_drafts on draft-phase spec_info --- python/sglang/srt/model_executor/forward_batch_info.py | 5 ++++- 1 file changed, 4 insertions(+), 1 
deletion(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3192138f7264..e6ba5df7f29b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -999,7 +999,10 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs): spec_info.topk_index = self._pad_tensor_to_size( spec_info.topk_index, bs ) - if spec_info.num_accepted_drafts is not None: + # `num_accepted_*` only live on `EagleDraftExtendInput` (draft-extend + # phase). `EagleDraftInput` (draft phase) doesn't have these fields, + # so use `getattr` to skip when spec_info is the latter. + if getattr(spec_info, "num_accepted_drafts", None) is not None: spec_info.num_accepted_drafts = self._pad_tensor_to_size( spec_info.num_accepted_drafts, bs ) From 4c680f16d8449289084fe083edeacee725185682 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 05:56:22 -0700 Subject: [PATCH 14/22] forward_batch_info: getattr-guard topk_p/topk_index on draft-extend spec_info --- python/sglang/srt/model_executor/forward_batch_info.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e6ba5df7f29b..a4570a477e9b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -993,9 +993,12 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs): spec_info = self.spec_info self.output_cache_loc_backup = self.out_cache_loc self.hidden_states_backup = spec_info.hidden_states - if spec_info.topk_p is not None: + # `topk_p` / `topk_index` only live on `EagleDraftInput` (draft phase). + # `EagleDraftExtendInput` (draft-extend phase) doesn't have these, + # so use `getattr` so the guard skips cleanly there. 
+ if getattr(spec_info, "topk_p", None) is not None: spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs) - if spec_info.topk_index is not None: + if getattr(spec_info, "topk_index", None) is not None: spec_info.topk_index = self._pad_tensor_to_size( spec_info.topk_index, bs ) From 490bcc0b86cd3faf345c39cd323aad7d869317ab Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 15:10:53 -0700 Subject: [PATCH 15/22] v1: install empty EagleDraftInput when extend skipped (retract edge case) --- python/sglang/srt/speculative/eagle_worker.py | 8 ++++++++ python/sglang/srt/speculative/multi_layer_eagle_worker.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f4de48e8f1dd..4f459c00f32c 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -515,6 +515,14 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul batch.spec_info = draft_extend_input next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input + else: + # All reqs finished this verify and dp_attention is not + # forcing the forward. Install an empty EagleDraftInput so + # next iter's merge_batch short-circuits on None + # hidden_states (EagleVerifyInput has no merge_batch). 
+ batch.spec_info = EagleDraftInput( + capture_hidden_mode=CaptureHiddenMode.LAST, + ) set_time_batch( batch.reqs, "set_spec_draft_extend_end_time", trace_only=True diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 64e79be2c8d8..b869fec6045f 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -293,6 +293,14 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul batch.spec_info = draft_extend_input next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input + else: + # All reqs finished this verify and dp_attention is not + # forcing the forward. Install an empty EagleDraftInput so + # next iter's merge_batch short-circuits on None + # hidden_states (EagleVerifyInput has no merge_batch). + batch.spec_info = EagleDraftInput( + capture_hidden_mode=CaptureHiddenMode.LAST, + ) return GenerationBatchResult( logits_output=logits_output, From a6c4467d29f585f0e239c1d4053cfc9c439bec00 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 02:27:09 -0700 Subject: [PATCH 16/22] restore num_accepted_drafts/tokens on EagleDraftInput for V2 --- python/sglang/srt/speculative/eagle_info.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 73ce3cdc2674..593e5392d854 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -667,6 +667,10 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): future_indices: Optional[FutureIndices] = None new_seq_lens: Optional[torch.Tensor] = None verify_done: Optional[torch.cuda.Event] = None + # V2 reuses `EagleDraftInput` across phases (V1 has a separate + # `EagleDraftExtendInput` for these). Set during V2's draft-extend. 
+ num_accepted_drafts: Optional[torch.Tensor] = None + num_accepted_tokens: Optional[torch.Tensor] = None def __post_init__(self): super().__init__(SpecInputType.EAGLE_DRAFT) From 3d123cd42a572b2443ac9e8bbd589b659f2c767d Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 15:22:40 -0700 Subject: [PATCH 17/22] stash spec_info comments; drop unused batch param --- python/sglang/srt/speculative/eagle_worker.py | 18 +++++++----------- .../srt/speculative/frozen_kv_mtp_worker.py | 8 +++----- .../speculative/multi_layer_eagle_worker.py | 18 +++++++----------- 3 files changed, 17 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 4f459c00f32c..515ef3796739 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -486,7 +486,6 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) - # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input logits_output, verify_output, can_run_cuda_graph = self.verify(batch) @@ -509,17 +508,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.server_args.enable_dp_attention or draft_extend_input.input_ids.shape[0] > 0 ): - # decode is not finished - # Install draft_extend_input for the extend forward, then - # install the assembled next-iter EagleDraftInput it returns. + # decode is not finished; stash for extend, then restash + # the next-iter EagleDraftInput it returns. batch.spec_info = draft_extend_input next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input else: - # All reqs finished this verify and dp_attention is not - # forcing the forward. 
Install an empty EagleDraftInput so - # next iter's merge_batch short-circuits on None - # hidden_states (EagleVerifyInput has no merge_batch). + # All reqs finished and dp_attention isn't forcing extend. + # Stash an empty EagleDraftInput so next iter's merge_batch + # short-circuits on None hidden_states (EagleVerifyInput + # has no merge_batch). batch.spec_info = EagleDraftInput( capture_hidden_mode=CaptureHiddenMode.LAST, ) @@ -541,9 +539,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) - def check_forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput): local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index cb0aa68c66e0..b603800726b6 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -447,7 +447,6 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) - # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input logits_output, verify_output, can_run_cuda_graph = self.verify(batch) @@ -465,9 +464,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.server_args.enable_dp_attention or draft_extend_input.input_ids.numel() > 0 ): - # Install draft_extend_input as `batch.spec_info` for the seed - # step (`_run_assistant_seed_step` replaces it with a fresh - # `FrozenKVMTPDraftInput` for next iter). 
+ # Stash for the seed step; _run_assistant_seed_step swaps in + # a fresh FrozenKVMTPDraftInput for next iter. batch.spec_info = draft_extend_input self.forward_draft_extend_after_decode(batch) set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True) @@ -515,7 +513,7 @@ def forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> None: input_is_idle = batch.forward_mode.is_idle() if not input_is_idle and draft_extend_input.input_ids.numel() == 0: - # All reqs finished. Install an idle FrozenKVMTPDraftInput so the + # All reqs finished; stash an idle FrozenKVMTPDraftInput so the # next-iter draft sees a valid spec_info. batch = batch.copy() batch.prepare_for_idle() diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index b869fec6045f..eaca6d74d0c5 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -273,7 +273,6 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.mtp_model_runner(0).tp_group ), speculative_moe_backend_context(): verify_input = self.draft(batch) - # Install verify_input as `batch.spec_info` for the verify forward. batch.spec_info = verify_input logits_output, verify_output, can_run_cuda_graph = self.verify(batch) @@ -287,17 +286,16 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.server_args.enable_dp_attention or draft_extend_input.input_ids.shape[0] > 0 ): - # decode is not finished - # Install draft_extend_input for the extend forward, then - # install the assembled next-iter EagleDraftInput it returns. + # decode is not finished; stash for extend, then restash + # the next-iter EagleDraftInput it returns. 
batch.spec_info = draft_extend_input next_draft_input = self.forward_draft_extend_after_decode(batch) batch.spec_info = next_draft_input else: - # All reqs finished this verify and dp_attention is not - # forcing the forward. Install an empty EagleDraftInput so - # next iter's merge_batch short-circuits on None - # hidden_states (EagleVerifyInput has no merge_batch). + # All reqs finished and dp_attention isn't forcing extend. + # Stash an empty EagleDraftInput so next iter's merge_batch + # short-circuits on None hidden_states (EagleVerifyInput + # has no merge_batch). batch.spec_info = EagleDraftInput( capture_hidden_mode=CaptureHiddenMode.LAST, ) @@ -309,9 +307,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) - def check_forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput): local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward From c26ca56705c983d5987102425e9208e2d35a1838 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 15:28:05 -0700 Subject: [PATCH 18/22] cleanup forward_batch_info comments; drop dead server_args param --- .../sglang/srt/model_executor/cpu_graph_runner.py | 2 +- .../sglang/srt/model_executor/cuda_graph_runner.py | 6 ++---- .../sglang/srt/model_executor/forward_batch_info.py | 13 ++++--------- 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py index 3dc6c11dc2f4..1c5a153a1c9d 100644 --- a/python/sglang/srt/model_executor/cpu_graph_runner.py +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -802,7 +802,7 @@ def prepare_replay( captured_forward_batch.encoder_lens[:raw_bs].copy_( 
forward_batch.encoder_lens ) - if enable_num_token_non_padded(self.model_runner.server_args): + if enable_num_token_non_padded(): captured_forward_batch.num_token_non_padded.copy_( forward_batch.num_token_non_padded ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b04ee47cdd96..e5e6b02bda08 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -987,7 +987,7 @@ def capture_one_batch_size( # populate_from_forward_batch). buffers.num_token_non_padded[...] = num_tokens if ( - enable_num_token_non_padded(self.model_runner.server_args) + enable_num_token_non_padded() and self.require_gathered_buffer and not self.nsa_enable_prefill_cp ): @@ -1255,9 +1255,7 @@ def replay_prepare( require_gathered_buffer=self.require_gathered_buffer, num_tokens_per_bs=self.num_tokens_per_bs, nsa_enable_prefill_cp=self.nsa_enable_prefill_cp, - enable_num_token_non_padded_flag=enable_num_token_non_padded( - self.model_runner.server_args - ), + enable_num_token_non_padded_flag=enable_num_token_non_padded(), pp_proxy_tensors=pp_proxy_tensors, ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a4570a477e9b..c6fac4352ef8 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -499,7 +499,7 @@ def init_new( ) num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0 - if enable_num_token_non_padded(model_runner.server_args): + if enable_num_token_non_padded(): ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to( device, non_blocking=True ) @@ -989,22 +989,17 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs): self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs) if self.spec_info is not None and 
self.spec_info.is_draft_input(): - # FIXME(lsyin): remove this isinstance logic spec_info = self.spec_info self.output_cache_loc_backup = self.out_cache_loc self.hidden_states_backup = spec_info.hidden_states - # `topk_p` / `topk_index` only live on `EagleDraftInput` (draft phase). - # `EagleDraftExtendInput` (draft-extend phase) doesn't have these, - # so use `getattr` so the guard skips cleanly there. + # spec_info is EagleDraftInput | EagleDraftExtendInput; each carries + # a different subset of the fields below, so getattr-guard each one. if getattr(spec_info, "topk_p", None) is not None: spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs) if getattr(spec_info, "topk_index", None) is not None: spec_info.topk_index = self._pad_tensor_to_size( spec_info.topk_index, bs ) - # `num_accepted_*` only live on `EagleDraftExtendInput` (draft-extend - # phase). `EagleDraftInput` (draft phase) doesn't have these fields, - # so use `getattr` to skip when spec_info is the latter. if getattr(spec_info, "num_accepted_drafts", None) is not None: spec_info.num_accepted_drafts = self._pad_tensor_to_size( spec_info.num_accepted_drafts, bs @@ -1092,7 +1087,7 @@ def can_run_tbo(self): return self.tbo_split_seq_index is not None -def enable_num_token_non_padded(server_args): +def enable_num_token_non_padded(): return get_moe_expert_parallel_world_size() > 1 From be72838af5b42ddc0a21e609a1b627cca7dd87ac Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 01:54:57 -0700 Subject: [PATCH 19/22] drop redundant spec_info.positions = None --- .../srt/speculative/eagle_draft_extend_cuda_graph_runner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index a5ae5b5b3e89..f477f4ef932c 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++
b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -365,7 +365,6 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0 num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_tokens, ) - spec_info.positions = None self.deepep_adapter.capture(is_extend_in_batch=True) From dd35129d1443309382350e90cc6cda127f529b69 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 15:34:53 -0700 Subject: [PATCH 20/22] drop unused model_worker_batch from verify() return --- python/sglang/srt/speculative/eagle_worker.py | 6 +++--- python/sglang/srt/speculative/frozen_kv_mtp_worker.py | 6 ++---- python/sglang/srt/speculative/multi_layer_eagle_worker.py | 6 +++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 3c480e9a1bfc..e59de0604ac4 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -485,8 +485,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, spec_info) + logits_output, verify_output, can_run_cuda_graph = self.verify( + batch, spec_info ) if get_global_tracing_enabled(): @@ -981,7 +981,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ) batch.spec_info = res.next_draft_input - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph def _mamba_verify_update( self, diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index 9039577cc976..09ea8f0b98c9 100644 --- 
a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -447,9 +447,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) - logits_output, verify_output, _, can_run_cuda_graph = self.verify( - batch, spec_info - ) + logits_output, verify_output, can_run_cuda_graph = self.verify(batch, spec_info) if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): @@ -771,4 +769,4 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): batch.spec_info = res.next_draft_input del seq_lens_pre_verify - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index b03492905f87..15de2a36f3e8 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -272,8 +272,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul self.mtp_model_runner(0).tp_group ), speculative_moe_backend_context(): spec_info = self.draft(batch) - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( - self.verify(batch, spec_info) + logits_output, verify_output, can_run_cuda_graph = self.verify( + batch, spec_info ) with self.draft_tp_context( @@ -593,7 +593,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ) batch.spec_info = res.next_draft_input - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, can_run_cuda_graph def forward_draft_extend( self, From 32b40e266b26540055e355db11c460e991075300 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 
15:35:40 -0700 Subject: [PATCH 21/22] drop dead server_args param from enable_num_token_non_padded --- python/sglang/srt/model_executor/cpu_graph_runner.py | 2 +- python/sglang/srt/model_executor/cuda_graph_runner.py | 6 ++---- python/sglang/srt/model_executor/forward_batch_info.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py index 3dc6c11dc2f4..1c5a153a1c9d 100644 --- a/python/sglang/srt/model_executor/cpu_graph_runner.py +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -802,7 +802,7 @@ def prepare_replay( captured_forward_batch.encoder_lens[:raw_bs].copy_( forward_batch.encoder_lens ) - if enable_num_token_non_padded(self.model_runner.server_args): + if enable_num_token_non_padded(): captured_forward_batch.num_token_non_padded.copy_( forward_batch.num_token_non_padded ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b04ee47cdd96..e5e6b02bda08 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -987,7 +987,7 @@ def capture_one_batch_size( # populate_from_forward_batch). buffers.num_token_non_padded[...] 
= num_tokens if ( - enable_num_token_non_padded(self.model_runner.server_args) + enable_num_token_non_padded() and self.require_gathered_buffer and not self.nsa_enable_prefill_cp ): @@ -1255,9 +1255,7 @@ def replay_prepare( require_gathered_buffer=self.require_gathered_buffer, num_tokens_per_bs=self.num_tokens_per_bs, nsa_enable_prefill_cp=self.nsa_enable_prefill_cp, - enable_num_token_non_padded_flag=enable_num_token_non_padded( - self.model_runner.server_args - ), + enable_num_token_non_padded_flag=enable_num_token_non_padded(), pp_proxy_tensors=pp_proxy_tensors, ) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3192138f7264..7ea1336c31f2 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -499,7 +499,7 @@ def init_new( ) num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0 - if enable_num_token_non_padded(model_runner.server_args): + if enable_num_token_non_padded(): ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to( device, non_blocking=True ) @@ -1086,7 +1086,7 @@ def can_run_tbo(self): return self.tbo_split_seq_index is not None -def enable_num_token_non_padded(server_args): +def enable_num_token_non_padded(): return get_moe_expert_parallel_world_size() > 1 From 56d55d93fd360af334fe602778465fc142603176 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Sat, 9 May 2026 15:36:07 -0700 Subject: [PATCH 22/22] drop dead batch param from check_forward_draft_extend_after_decode --- python/sglang/srt/speculative/eagle_worker.py | 4 +--- python/sglang/srt/speculative/multi_layer_eagle_worker.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e59de0604ac4..751dee72846b 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ 
b/python/sglang/srt/speculative/eagle_worker.py @@ -527,9 +527,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) - def check_forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput): local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 15de2a36f3e8..8dcb6685c481 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -295,9 +295,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul can_run_cuda_graph=can_run_cuda_graph, ) - def check_forward_draft_extend_after_decode( - self, batch: ScheduleBatch, verify_output: EagleVerifyOutput - ): + def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput): local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0 if not self.server_args.enable_dp_attention: return local_need_forward