diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 7ea1336c31f2..c6fac4352ef8 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -989,17 +989,18 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs):
         self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)

         if self.spec_info is not None and self.spec_info.is_draft_input():
-            # FIXME(lsyin): remove this isinstance logic
             spec_info = self.spec_info
             self.output_cache_loc_backup = self.out_cache_loc
             self.hidden_states_backup = spec_info.hidden_states
-            if spec_info.topk_p is not None:
+            # spec_info is EagleDraftInput | EagleDraftExtendInput; each carries
+            # a disjoint subset of the fields below, so getattr-guard each one.
+            if getattr(spec_info, "topk_p", None) is not None:
                 spec_info.topk_p = self._pad_tensor_to_size(spec_info.topk_p, bs)
-            if spec_info.topk_index is not None:
+            if getattr(spec_info, "topk_index", None) is not None:
                 spec_info.topk_index = self._pad_tensor_to_size(
                     spec_info.topk_index, bs
                 )
-            if spec_info.num_accepted_drafts is not None:
+            if getattr(spec_info, "num_accepted_drafts", None) is not None:
                 spec_info.num_accepted_drafts = self._pad_tensor_to_size(
                     spec_info.num_accepted_drafts, bs
                 )
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index f477f4ef932c..2da921adf695 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -26,7 +26,7 @@
     ForwardMode,
 )
 from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
-from sglang.srt.speculative.eagle_info import EagleDraftInput
+from sglang.srt.speculative.eagle_info import EagleDraftExtendInput
 from sglang.srt.speculative.spec_utils import fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
@@ -360,7 +360,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0
         else:
             global_dp_buffer_len = None

-        spec_info = EagleDraftInput(
+        spec_info = EagleDraftExtendInput(
             hidden_states=hidden_states,
             num_accepted_drafts=num_accepted_drafts,
             num_accepted_tokens=num_accepted_tokens,
diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py
index 205c0610ae38..593e5392d854 100644
--- a/python/sglang/srt/speculative/eagle_info.py
+++ b/python/sglang/srt/speculative/eagle_info.py
@@ -240,15 +240,14 @@ def verify(
             accepted token logits.
""" if batch.forward_mode.is_idle(): - next_draft_input = EagleDraftInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) return EagleVerifyOutput.create_idle( - next_draft_input=next_draft_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, device=batch.device, spec_steps=self.spec_steps, @@ -545,21 +544,21 @@ def verify( batch.seq_lens.add_(num_accepted_drafts + 1) batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) - next_draft_input = EagleDraftInput( + draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[accept_index], num_accepted_drafts=num_accepted_drafts, num_accepted_tokens=num_accepted_drafts + 1, num_accepted_tokens_cpu=num_accepted_tokens_list, + input_ids=accept_tokens, + seq_lens=batch.seq_lens, + seq_lens_cpu=batch.seq_lens_cpu, + req_pool_indices=batch.req_pool_indices, ) return EagleVerifyOutput( - next_draft_input=next_draft_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - unfinished_accept_tokens=accept_tokens, - seq_lens_for_draft_extend=batch.seq_lens, - seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu, - req_pool_indices_for_draft_extend=batch.req_pool_indices, num_accepted_drafts_per_req_cpu=num_accepted_drafts_list, accepted_indices=accept_index, ) @@ -614,51 +613,30 @@ def verify( unfinished_num_accepted_drafts = num_accepted_drafts[ unfinished_index_device ] - unfinished_accept_tokens = predict[unfinished_accept_index] - seq_lens_for_draft_extend = batch.seq_lens[unfinished_index_device] - seq_lens_for_draft_extend_cpu = batch.seq_lens_cpu[unfinished_index] - req_pool_indices_for_draft_extend = batch.req_pool_indices[ - unfinished_index_device - ] - next_draft_input = EagleDraftInput( + draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[ unfinished_accept_index ], num_accepted_tokens_cpu=draft_input_num_accepted_tokens_cpu, num_accepted_drafts=unfinished_num_accepted_drafts, num_accepted_tokens=unfinished_num_accepted_drafts + 1, + input_ids=predict[unfinished_accept_index], + seq_lens=batch.seq_lens[unfinished_index_device], + seq_lens_cpu=batch.seq_lens_cpu[unfinished_index], + req_pool_indices=batch.req_pool_indices[unfinished_index_device], ) else: - unfinished_accept_tokens = torch.empty( - (0,), dtype=accept_tokens.dtype, device=accept_tokens.device - ) - seq_lens_for_draft_extend = torch.empty( - (0,), dtype=batch.seq_lens.dtype, device=batch.seq_lens.device - ) - seq_lens_for_draft_extend_cpu = torch.empty( - (0,), dtype=batch.seq_lens_cpu.dtype - ) - req_pool_indices_for_draft_extend = torch.empty( - (0,), - dtype=batch.req_pool_indices.dtype, - device=batch.req_pool_indices.device, - ) - next_draft_input = EagleDraftInput.create_idle_input( + draft_extend_input = EagleDraftExtendInput.create_idle_input( device=batch.device, hidden_size=batch.model_config.spec_hidden_size, dtype=batch.model_config.dtype, - topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) return EagleVerifyOutput( - next_draft_input=next_draft_input, + draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - unfinished_accept_tokens=unfinished_accept_tokens, - seq_lens_for_draft_extend=seq_lens_for_draft_extend, - seq_lens_for_draft_extend_cpu=seq_lens_for_draft_extend_cpu, - 
-            req_pool_indices_for_draft_extend=req_pool_indices_for_draft_extend,
             num_accepted_drafts_per_req_cpu=num_accepted_drafts_list,
             accepted_indices=accept_index,
         )
@@ -666,42 +644,33 @@ def verify(

 @dataclass
 class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
-    # The inputs for decode
     # shape: (b, topk)
     topk_p: torch.Tensor = None
     topk_index: torch.Tensor = None

-    # shape: (b, hidden_size) when consumed by `draft` forward (one hidden per req);
-    # shape: (total_accepted, hidden_size) when consumed by `draft_extend` forward
-    # (one hidden per accepted token). Workers maintain this invariant locally;
-    # there is no type-level guard. Don't add new readers without checking phase.
+    # shape: (b, hidden_size) - one hidden per req, consumed by `draft` forward.
     hidden_states: torch.Tensor = None
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

-    # Inputs for extend
-    # shape: (b,)
-    # `num_accepted_drafts` and `num_accepted_tokens` are kept in sync:
-    # `num_accepted_tokens = num_accepted_drafts + 1` (per-req, one bonus per req).
-    # Storing both avoids repeated `+ 1` at every consumer (attn backends, kernels).
+    # Per-req bonus token (the "+1" target prediction at end of each accept
+    # chain). Written by `EagleDraftExtendInput.prepare_extend_after_decode`;
+    # the worker copies it here for next iter's draft.
     bonus_tokens: torch.Tensor = None
-    num_accepted_drafts: torch.Tensor = None
-    num_accepted_tokens: torch.Tensor = None
-    # Read by attention backends during draft-extend forward; kept on the
-    # dataclass because the backends access it via `forward_batch.spec_info`.
-    num_accepted_tokens_cpu: List[int] = None

-    # Inputs for the attention backends
     # shape: (b + 1,)
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None

-    # Shape info for padding
     num_tokens_per_req: int = -1
     num_tokens_for_logprob_per_req: int = -1

-    # Inputs for V2 overlap worker
+    # V2 overlap worker only
     future_indices: Optional[FutureIndices] = None
     new_seq_lens: Optional[torch.Tensor] = None
     verify_done: Optional[torch.cuda.Event] = None
+    # V2 reuses `EagleDraftInput` across phases (V1 has a separate
+    # `EagleDraftExtendInput` for these). Set during V2's draft-extend.
+    num_accepted_drafts: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[torch.Tensor] = None

     def __post_init__(self):
         super().__init__(SpecInputType.EAGLE_DRAFT)
@@ -741,31 +710,151 @@ def create_idle_input(
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
             capture_hidden_mode=capture_hidden_mode,
             new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32),
+        )
+
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        if self.future_indices is not None:
+            self.future_indices.indices = self.future_indices.indices[new_indices]
+            return
+
+        strict_check = envs.SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK.get()
+        if has_been_filtered:
+            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
+            # therefore, we don't need to filter the batch again in the scheduler
+            error_msg = f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
+            if len(new_indices) != len(self.topk_p):
+                if strict_check:
+                    raise ValueError(error_msg)
+                else:
+                    logger.warning(error_msg)
+
+            self.topk_p = self.topk_p[: len(new_indices)]
+            self.topk_index = self.topk_index[: len(new_indices)]
+            self.hidden_states = self.hidden_states[: len(new_indices)]
+            self.bonus_tokens = self.bonus_tokens[: len(new_indices)]
+        else:
+            # in some cases (e.g. draft_extend), we have not filtered the batch by `unfinished_index`
+            self.topk_p = self.topk_p[new_indices]
+            self.topk_index = self.topk_index[new_indices]
+            self.hidden_states = self.hidden_states[new_indices]
+            self.bonus_tokens = self.bonus_tokens[new_indices]
+
+    def merge_batch(self, spec_info: "EagleDraftInput"):
+        if self.future_indices is not None:
+            assert spec_info.future_indices is not None
+            self.future_indices = FutureIndices(
+                indices=torch.cat(
+                    [self.future_indices.indices, spec_info.future_indices.indices]
+                )
+            )
+            return
+
+        if self.hidden_states is None:
+            self.hidden_states = spec_info.hidden_states
+            self.bonus_tokens = spec_info.bonus_tokens
+            self.topk_p = spec_info.topk_p
+            self.topk_index = spec_info.topk_index
+            return
+        if spec_info.hidden_states is None:
+            return
+        self.hidden_states = torch.cat(
+            [self.hidden_states, spec_info.hidden_states], axis=0
+        )
+        self.bonus_tokens = torch.cat(
+            [self.bonus_tokens, spec_info.bonus_tokens], axis=0
+        )
+        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
+        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
+
+
+@dataclass
+class EagleDraftExtendInput(SpecInput):
+    """Inputs to the draft-extend forward (the per-accepted-token pass after verify).
+
+    Produced by `EagleVerifyInput.verify`, installed on `batch.spec_info` for
+    the draft-extend forward, then replaced with a fresh `EagleDraftInput` for
+    the next iter's draft.
+    """
+
+    # shape: (total_accepted, hidden_size). Sliced from verify-time hidden_states
+    # by accept_index; consumed by the draft-extend forward.
+    hidden_states: torch.Tensor = None
+
+    # Per-req accept counts. `num_accepted_tokens = num_accepted_drafts + 1`.
+    # Both kept for cuda-graph buffer indexing and the
+    # `create_extend_after_decode_spec_info` kernel.
+    num_accepted_drafts: torch.Tensor = None
+    num_accepted_tokens: torch.Tensor = None
+    # CPU view, read by attention backends during the extend forward.
+    num_accepted_tokens_cpu: List[int] = None
+
+    # Batch-state slices for the draft-extend forward. Set by verify (sliced to
+    # reqs continuing into next iter). `prepare_extend_after_decode` copies
+    # these onto `batch.{input_ids, seq_lens, seq_lens_cpu, req_pool_indices}`.
+    # - input_ids: accept tokens flat over surviving reqs
+    # - seq_lens / _cpu: per-req sequence length (post-accept)
+    # - req_pool_indices: per-req kv-pool slot
+    input_ids: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    seq_lens_cpu: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+
+    # Set by `prepare_extend_after_decode`:
+    # - positions: kernel-written, shape `[total_accepted]`.
+    # - bonus_tokens: kernel-written, shape `[bs]`. The worker reads this
+    #   post-extend to populate next iter's `EagleDraftInput.bonus_tokens`.
+    positions: Optional[torch.Tensor] = None
+    bonus_tokens: Optional[torch.Tensor] = None
+
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST
+    num_tokens_per_req: int = -1
+    num_tokens_for_logprob_per_req: int = 1
+
+    def __post_init__(self):
+        super().__init__(SpecInputType.EAGLE_DRAFT_EXTEND)
+
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        return self.num_tokens_per_req, self.num_tokens_for_logprob_per_req
+
+    @classmethod
+    def create_idle_input(
+        cls,
+        device: torch.device,
+        hidden_size: int,
+        dtype: torch.dtype,
+        capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.LAST,
+    ) -> "EagleDraftExtendInput":
+        return cls(
+            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
             num_accepted_drafts=torch.empty((0,), device=device, dtype=torch.int32),
             num_accepted_tokens=torch.empty((0,), device=device, dtype=torch.int32),
             num_accepted_tokens_cpu=[],
+            input_ids=torch.empty((0,), device=device, dtype=torch.long),
+            seq_lens=torch.empty((0,), device=device, dtype=torch.int32),
+            seq_lens_cpu=torch.empty((0,), dtype=torch.int32),
+            req_pool_indices=torch.empty((0,), device=device, dtype=torch.int64),
+            capture_hidden_mode=capture_hidden_mode,
         )

     def prepare_extend_after_decode(
         self,
         batch: ScheduleBatch,
-        verify_output: "EagleVerifyOutput",
         speculative_num_steps: int,
     ):
-
+        # Caller must have installed `self` as `batch.spec_info` before calling.
+        assert batch.spec_info is self
         if batch.forward_mode.is_idle():
             return

-        # All transient verify->extend handoff state is read off `verify_output`,
-        # not from `self`. The kernel below populates `self.bonus_tokens`
-        # ([bs] per-req) for the next decode round; that is the only state on
-        # `self` that survives past this method.
-        batch.input_ids = verify_output.unfinished_accept_tokens
-        batch.extend_lens = batch.spec_info.num_accepted_tokens_cpu
+        # The kernel below populates `self.positions` and `self.bonus_tokens`;
+        # the worker reads `self.bonus_tokens` to construct next iter's
+        # `EagleDraftInput`.
+        batch.input_ids = self.input_ids
+        batch.extend_lens = self.num_accepted_tokens_cpu
         batch.extend_num_tokens = sum(batch.extend_lens)
-        batch.seq_lens = verify_output.seq_lens_for_draft_extend
-        batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu
-        batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend
+        batch.seq_lens = self.seq_lens
+        batch.seq_lens_cpu = self.seq_lens_cpu
+        batch.req_pool_indices = self.req_pool_indices
         batch.return_logprob = False
         batch.return_hidden_states = False
@@ -788,7 +877,7 @@ def generate_attn_arg_prefill(
         self,
         req_pool_indices: torch.Tensor,
         paged_kernel_lens: torch.Tensor,
-        paged_kernel_lens_sum: int,
+        paged_kernel_lens_sum: Optional[int],
         req_to_token: torch.Tensor,
     ):
         device = req_pool_indices.device
@@ -816,84 +905,17 @@ def generate_attn_arg_prefill(
         )
         return kv_indices, cum_kv_seq_len, qo_indptr, None

-    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
-        if self.future_indices is not None:
-            self.future_indices.indices = self.future_indices.indices[new_indices]
-            return
-
-        strict_check = envs.SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK.get()
-        if has_been_filtered:
-            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
-            # therefore, we don't need to filter the batch again in scheduler
-            error_msg = f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
-            if len(new_indices) != len(self.topk_p):
-                if strict_check:
-                    raise ValueError(error_msg)
-                else:
-                    logger.warning(error_msg)
-
-            self.topk_p = self.topk_p[: len(new_indices)]
-            self.topk_index = self.topk_index[: len(new_indices)]
-            self.hidden_states = self.hidden_states[: len(new_indices)]
-            self.bonus_tokens = self.bonus_tokens[: len(new_indices)]
-        else:
-            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
-            self.topk_p = self.topk_p[new_indices]
-            self.topk_index = self.topk_index[new_indices]
-            self.hidden_states = self.hidden_states[new_indices]
-            self.bonus_tokens = self.bonus_tokens[new_indices]
-
-    def merge_batch(self, spec_info: "EagleDraftInput"):
-        if self.future_indices is not None:
-            assert spec_info.future_indices is not None
-            self.future_indices = FutureIndices(
-                indices=torch.cat(
-                    [self.future_indices.indices, spec_info.future_indices.indices]
-                )
-            )
-            return
-
-        if self.hidden_states is None:
-            self.hidden_states = spec_info.hidden_states
-            self.bonus_tokens = spec_info.bonus_tokens
-            self.topk_p = spec_info.topk_p
-            self.topk_index = spec_info.topk_index
-            return
-        if spec_info.hidden_states is None:
-            return
-        self.hidden_states = torch.cat(
-            [self.hidden_states, spec_info.hidden_states], axis=0
-        )
-        self.bonus_tokens = torch.cat(
-            [self.bonus_tokens, spec_info.bonus_tokens], axis=0
-        )
-        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
-        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
-

 @dataclass
 class EagleVerifyOutput:
-    # Next iter's persistent draft state, ready to be installed as `batch.spec_info`.
-    next_draft_input: EagleDraftInput
+    # Next iter's draft-extend input, installed as `batch.spec_info` for the
+    # draft-extend forward.
+    draft_extend_input: EagleDraftExtendInput
     # Logit outputs from target worker.
     logits_output: LogitsProcessorOutput
     # All accepted tokens flat across all reqs incl. those that finished this
     # step. Includes the bonus token. Used for output processing.
     accept_tokens: torch.Tensor
-    # Below are transient handoff fields for the next iter's draft-extend pass.
-    # They are scoped to the verify -> prepare_extend_after_decode window only;
-    # `prepare_extend_after_decode` reads them off this object via method arg
-    # rather than smuggling them through `EagleDraftInput`.
-    #
-    # Subset of `accept_tokens` for reqs continuing into next iter's draft-extend
-    # forward (= `accept_tokens` when no req finished; flat over unfinished
-    # reqs only otherwise). Becomes `batch.input_ids` for that forward pass.
-    unfinished_accept_tokens: torch.Tensor
-    # `batch.seq_lens` / `batch.seq_lens_cpu` / `batch.req_pool_indices` to
-    # use for the next iter's draft-extend forward; sliced to surviving reqs.
-    seq_lens_for_draft_extend: torch.Tensor
-    seq_lens_for_draft_extend_cpu: torch.Tensor
-    req_pool_indices_for_draft_extend: torch.Tensor
     # Accepted token length per sequence in a batch in CPU (full set).
     num_accepted_drafts_per_req_cpu: List[int]
     # Accepted indices from logits_output.next_token_logits
@@ -903,21 +925,15 @@ class EagleVerifyOutput:
     def create_idle(
         cls,
         *,
-        next_draft_input: EagleDraftInput,
+        draft_extend_input: EagleDraftExtendInput,
         logits_output: LogitsProcessorOutput,
         device: torch.device,
         spec_steps: int,
     ) -> "EagleVerifyOutput":
         return cls(
-            next_draft_input=next_draft_input,
+            draft_extend_input=draft_extend_input,
             logits_output=logits_output,
             accept_tokens=torch.empty(0, dtype=torch.long, device=device),
-            unfinished_accept_tokens=torch.empty(0, dtype=torch.long, device=device),
-            seq_lens_for_draft_extend=torch.empty(0, dtype=torch.int32, device=device),
-            seq_lens_for_draft_extend_cpu=torch.empty(0, dtype=torch.int32),
-            req_pool_indices_for_draft_extend=torch.empty(
-                0, dtype=torch.int64, device=device
-            ),
             num_accepted_drafts_per_req_cpu=[],
             accepted_indices=torch.full(
                 (0, spec_steps + 1), -1, dtype=torch.int32, device=device
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 751dee72846b..515ef3796739 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -46,6 +46,7 @@
     EAGLEDraftExtendCudaGraphRunner,
 )
 from sglang.srt.speculative.eagle_info import (
+    EagleDraftExtendInput,
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
@@ -480,14 +481,13 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         with self.draft_tp_context(
             self.draft_model_runner.tp_group
         ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
-            spec_info = self.draft(batch)
+            verify_input = self.draft(batch)

         set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
         set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

-        logits_output, verify_output, can_run_cuda_graph = self.verify(
-            batch, spec_info
-        )
+        batch.spec_info = verify_input
+        logits_output, verify_output, can_run_cuda_graph = self.verify(batch)

         if get_global_tracing_enabled():
             for idx, req in enumerate(batch.reqs):
@@ -503,12 +503,24 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
             # NOTE: We should use `check_forward_draft_extend_after_decode`
             # when DP attention is enabled, but it is slow. Skip it for now.
+            draft_extend_input = verify_output.draft_extend_input
             if (
                 self.server_args.enable_dp_attention
-                or verify_output.unfinished_accept_tokens.shape[0] > 0
+                or draft_extend_input.input_ids.shape[0] > 0
             ):
-                # decode is not finished
-                self.forward_draft_extend_after_decode(batch, verify_output)
+                # decode is not finished: install the extend input, then the
+                # next-iter EagleDraftInput that the extend pass returns.
+                batch.spec_info = draft_extend_input
+                next_draft_input = self.forward_draft_extend_after_decode(batch)
+                batch.spec_info = next_draft_input
+            else:
+                # All reqs finished and dp_attention isn't forcing extend.
+                # Stash an empty EagleDraftInput so next iter's merge_batch
+                # short-circuits on None hidden_states (EagleVerifyInput
+                # has no merge_batch).
+                batch.spec_info = EagleDraftInput(
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                )

         set_time_batch(
             batch.reqs, "set_spec_draft_extend_end_time", trace_only=True
@@ -528,7 +540,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         )

     def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
-        local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0
+        local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward

@@ -890,7 +902,8 @@ def clear_cache_pool(self):
         # allocator and kv cache pool are shared with target worker
         pass

-    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
+    def verify(self, batch: ScheduleBatch):
+        spec_info: EagleVerifyInput = batch.spec_info
         seq_lens_pre_verify = batch.seq_lens.clone()
         spec_info.prepare_for_verify(batch, self.page_size)
         spec_info.num_tokens_per_req = self.speculative_num_steps + 1
@@ -900,7 +913,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
             if not batch.forward_mode.is_idle()
             else ForwardMode.IDLE
         )
-        batch.spec_info = spec_info

         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
@@ -977,7 +989,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         batch.forward_mode = (
             ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
         )
-        batch.spec_info = res.next_draft_input

         return logits_output, res, can_run_cuda_graph

@@ -1105,20 +1116,20 @@ def forward_draft_extend(
         self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, verify_output: EagleVerifyOutput
-    ):
-        assert isinstance(batch.spec_info, EagleDraftInput)
+        self, batch: ScheduleBatch
+    ) -> EagleDraftInput:
+        draft_extend_input: EagleDraftExtendInput = batch.spec_info
+
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
-        num_accepted_drafts_backup = batch.spec_info.num_accepted_drafts.clone()
-        num_accepted_tokens_backup = batch.spec_info.num_accepted_tokens.clone()
         return_logprob_backup = batch.return_logprob

         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0:
+        if not input_is_idle and draft_extend_input.input_ids.numel() == 0:
+            # All reqs finished this verify; swap to an idle EagleDraftExtendInput.
             batch = batch.copy()
             batch.prepare_for_idle()
             hidden_size = (
@@ -1127,19 +1138,19 @@ def forward_draft_extend_after_decode(
                 and self.eagle_use_aux_hidden_state
                 else self.model_config.spec_hidden_size
             )
-            batch.spec_info = EagleDraftInput.create_idle_input(
+            draft_extend_input = EagleDraftExtendInput.create_idle_input(
                 device=self.device,
                 hidden_size=hidden_size,
                 dtype=self.model_config.dtype,
-                topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
+            batch.spec_info = draft_extend_input

-        batch.spec_info.num_tokens_per_req = self.speculative_num_steps + 1
-        batch.spec_info.num_tokens_for_logprob_per_req = 1
-        batch.spec_info.prepare_extend_after_decode(
+        # Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens})
+        draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1
+        draft_extend_input.num_tokens_for_logprob_per_req = 1
+        draft_extend_input.prepare_extend_after_decode(
             batch,
-            verify_output=verify_output,
             speculative_num_steps=self.speculative_num_steps,
         )
         batch.forward_mode = (
@@ -1159,7 +1170,7 @@ def forward_draft_extend_after_decode(
         else:
             forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

-        # Run
+        # Phase 2: run draft-extend forward
         can_cuda_graph = (
             self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
@@ -1168,11 +1179,10 @@ def forward_draft_extend_after_decode(
             logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                 forward_batch
             )
-            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
-                logits_output.topk_p,
-                logits_output.topk_index,
-            )
-            forward_batch.spec_info.hidden_states = logits_output.hidden_states
+            # cuda-graph replay populates logits_output.{topk_p, topk_index, hidden_states}.
+            topk_p = logits_output.topk_p
+            topk_index = logits_output.topk_index
+            hidden_states = logits_output.hidden_states
         else:
             forward_batch.can_run_dp_cuda_graph = False
             if not forward_batch.forward_mode.is_idle():
@@ -1185,24 +1195,36 @@ def forward_draft_extend_after_decode(
             logits_output = self.draft_model_runner.forward(
                 forward_batch, skip_attn_backend_init=True
             ).logits_output
-            self.capture_for_decode(logits_output, forward_batch.spec_info)
+            # Non-cuda-graph path: compute topk_p / topk_index inline.
+            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
+            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            hidden_states = logits_output.hidden_states

         maybe_detect_nan(
             logits_output.next_token_logits,
             f"draft_extend_after_decode (cuda_graph={can_cuda_graph})",
         )

-        # Restore backup.
-        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        # Phase 3: assemble next-iter EagleDraftInput from extend output
+        next_draft_input = EagleDraftInput(
+            bonus_tokens=draft_extend_input.bonus_tokens,
+            hidden_states=hidden_states,
+            topk_p=topk_p,
+            topk_index=topk_index,
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+        )
+
+        # Restore batch fields. `seq_lens` etc. were modified by
+        # `prepare_extend_after_decode`. Caller installs `next_draft_input` as
+        # `batch.spec_info`.
         batch.forward_mode = (
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
         batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
-        batch.spec_info.num_accepted_drafts = num_accepted_drafts_backup
-        batch.spec_info.num_accepted_tokens = num_accepted_tokens_backup
         batch.return_logprob = return_logprob_backup
+        return next_draft_input

     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_info.py b/python/sglang/srt/speculative/frozen_kv_mtp_info.py
index d092446168bd..7b562b52094f 100644
--- a/python/sglang/srt/speculative/frozen_kv_mtp_info.py
+++ b/python/sglang/srt/speculative/frozen_kv_mtp_info.py
@@ -18,6 +18,7 @@
 from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.speculative.eagle_info import (
+    EagleDraftExtendInput,
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
@@ -53,6 +54,14 @@ def __post_init__(self):
         SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT)


+@dataclass
+class FrozenKVMTPDraftExtendInput(EagleDraftExtendInput):
+    """Draft-extend input for Frozen-KV MTP. Tag-only subclass."""
+
+    def __post_init__(self):
+        SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND)
+
+
 @dataclass
 class FrozenKVMTPVerifyInput(EagleVerifyInput):
     """Verify input for Frozen-KV MTP."""
@@ -62,21 +71,23 @@ def __post_init__(self):

     def verify(self, *args, **kwargs) -> EagleVerifyOutput:
         output = super().verify(*args, **kwargs)
-        output.next_draft_input = _to_frozen_kv_mtp_draft_input(output.next_draft_input)
+        output.draft_extend_input = _to_frozen_kv_mtp_draft_extend_input(
+            output.draft_extend_input
+        )
         return output


 FrozenKVMTPVerifyOutput = EagleVerifyOutput


-def _to_frozen_kv_mtp_draft_input(
-    draft_input: EagleDraftInput,
-) -> FrozenKVMTPDraftInput:
-    if isinstance(draft_input, FrozenKVMTPDraftInput):
-        return draft_input
-    return FrozenKVMTPDraftInput(
+def _to_frozen_kv_mtp_draft_extend_input(
+    draft_extend_input: EagleDraftExtendInput,
+) -> FrozenKVMTPDraftExtendInput:
+    if isinstance(draft_extend_input, FrozenKVMTPDraftExtendInput):
+        return draft_extend_input
+    return FrozenKVMTPDraftExtendInput(
         **{
-            field.name: getattr(draft_input, field.name)
-            for field in fields(EagleDraftInput)
+            field.name: getattr(draft_extend_input, field.name)
+            for field in fields(EagleDraftExtendInput)
         }
     )
diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py
index 05512ddf555c..74ff0ef7ee70 100644
--- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py
+++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py
@@ -23,6 +23,7 @@
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.speculative.frozen_kv_mtp_info import (
     FrozenKVMTPContext,
+    FrozenKVMTPDraftExtendInput,
     FrozenKVMTPDraftInput,
 )
 from sglang.srt.speculative.spec_utils import fast_topk
@@ -134,11 +135,8 @@ def select_last_extend_hidden(


 def select_last_verified_seed(
-    draft_input: FrozenKVMTPDraftInput,
+    draft_input: FrozenKVMTPDraftExtendInput,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if draft_input.num_accepted_tokens is None:
-        return draft_input.bonus_tokens, draft_input.hidden_states
-
     counts = draft_input.num_accepted_tokens.to(torch.long)
     last_indices = torch.cumsum(counts, dim=0) - 1
     return (
diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
index 09ea8f0b98c9..b603800726b6 100644
--- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
+++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
@@ -43,13 +43,13 @@
 from sglang.srt.observability.req_time_stats import set_time_batch
 from sglang.srt.observability.trace import get_global_tracing_enabled
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.speculative.eagle_info import EagleVerifyOutput
 from sglang.srt.speculative.eagle_utils import (
     build_tree_kernel_efficient,
     organize_draft_results,
 )
 from sglang.srt.speculative.frozen_kv_mtp_info import (
     FrozenKVMTPContext,
+    FrozenKVMTPDraftExtendInput,
     FrozenKVMTPDraftInput,
     FrozenKVMTPVerifyInput,
     FrozenKVMTPVerifyOutput,
@@ -335,7 +335,7 @@ def _select_last_extend_hidden(
         return select_last_extend_hidden(batch, hidden_states)

     def _select_last_verified_seed(
-        self, draft_input: FrozenKVMTPDraftInput
+        self, draft_input: FrozenKVMTPDraftExtendInput
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         return select_last_verified_seed(draft_input)

@@ -443,11 +443,12 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         with self.draft_tp_context(
             self.draft_model_runner.tp_group
         ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
-            spec_info = self.draft(batch)
+            verify_input = self.draft(batch)

         set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True)
         set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True)

-        logits_output, verify_output, can_run_cuda_graph = self.verify(batch, spec_info)
+        batch.spec_info = verify_input
+        logits_output, verify_output, can_run_cuda_graph = self.verify(batch)

         if get_global_tracing_enabled():
             for idx, req in enumerate(batch.reqs):
@@ -458,11 +459,15 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         with self.draft_tp_context(
             self.draft_model_runner.tp_group
         ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
+            draft_extend_input = verify_output.draft_extend_input
             if (
                 self.server_args.enable_dp_attention
-                or batch.spec_info.bonus_tokens.numel()
+                or draft_extend_input.input_ids.numel() > 0
             ):
-                self.forward_draft_extend_after_decode(batch, verify_output)
+                # Stash for the seed step; _run_assistant_seed_step swaps in
+                # a fresh FrozenKVMTPDraftInput for next iter.
+                batch.spec_info = draft_extend_input
+                self.forward_draft_extend_after_decode(batch)

         set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True)

         return GenerationBatchResult(
@@ -503,12 +508,13 @@ def forward_draft_extend(
             mm_input_embeds=mm_input_embeds,
         )

-    def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, verify_output: EagleVerifyOutput
-    ) -> None:
-        assert isinstance(batch.spec_info, FrozenKVMTPDraftInput)
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> None:
+        draft_extend_input: FrozenKVMTPDraftExtendInput = batch.spec_info
         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle and batch.spec_info.bonus_tokens.numel() == 0:
+
+        if not input_is_idle and draft_extend_input.input_ids.numel() == 0:
+            # All reqs finished; stash an idle FrozenKVMTPDraftInput so the
+            # next-iter draft sees a valid spec_info.
             batch = batch.copy()
             batch.prepare_for_idle()
             batch.spec_info = FrozenKVMTPDraftInput.create_idle_input(
@@ -518,29 +524,32 @@ def forward_draft_extend_after_decode(
                 topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
+            return

         if batch.forward_mode.is_idle():
             return

-        draft_input = batch.spec_info
         seq_lens_backup = batch.seq_lens.clone()
         seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices

         try:
             # Verify may leave finished requests in ScheduleBatch; seed only
-            # the unfinished requests carried by `verify_output`.
-            batch.seq_lens = verify_output.seq_lens_for_draft_extend
-            batch.seq_lens_cpu = verify_output.seq_lens_for_draft_extend_cpu
-            batch.req_pool_indices = verify_output.req_pool_indices_for_draft_extend
+            # the unfinished reqs carried by `draft_extend_input`.
+            batch.seq_lens = draft_extend_input.seq_lens
+            batch.seq_lens_cpu = draft_extend_input.seq_lens_cpu
+            batch.req_pool_indices = draft_extend_input.req_pool_indices

-            last_token_ids, last_hidden = self._select_last_verified_seed(draft_input)
+            last_token_ids, last_hidden = self._select_last_verified_seed(
+                draft_extend_input
+            )
+            # `_run_assistant_seed_step` constructs a fresh `FrozenKVMTPDraftInput`
+            # and installs it on `batch.spec_info` for next iter.
             self._run_assistant_seed_step(
                 batch,
                 last_token_ids,
                 last_hidden,
-                seq_lens_cpu=verify_output.seq_lens_for_draft_extend_cpu,
-                draft_input=draft_input,
+                seq_lens_cpu=draft_extend_input.seq_lens_cpu,
             )
         finally:
             batch.seq_lens = seq_lens_backup
@@ -687,7 +696,8 @@ def draft_forward(
             score_list, token_list, parents_list, self.speculative_num_draft_tokens
         )

-    def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput):
+    def verify(self, batch: ScheduleBatch):
+        spec_info: FrozenKVMTPVerifyInput = batch.spec_info
         seq_lens_pre_verify = batch.seq_lens.clone()
         spec_info.prepare_for_verify(batch, self.page_size)
         spec_info.num_tokens_per_req = self.speculative_num_steps + 1
@@ -697,7 +707,6 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput):
             if not batch.forward_mode.is_idle()
             else ForwardMode.IDLE
         )
-        batch.spec_info = spec_info

         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
@@ -766,7 +775,6 @@ def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput):
         batch.forward_mode = (
             ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
         )
-        batch.spec_info = res.next_draft_input
         del seq_lens_pre_verify

         return logits_output, res, can_run_cuda_graph
diff --git a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py
index 295c1009bd7d..18050f9e0c2c 100644
--- a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py
@@ -41,7 +41,7 @@
     ForwardMode,
 )
 from sglang.srt.model_executor.input_buffers import ForwardInputBuffers
-from sglang.srt.speculative.eagle_info import EagleDraftInput
+from sglang.srt.speculative.eagle_info import EagleDraftExtendInput
 from sglang.srt.speculative.multi_layer_eagle_utils import assign_new_state_triton
 from sglang.srt.speculative.spec_utils import fast_topk
 from sglang.srt.utils import (
@@ -349,7 +349,7 @@ def get_forward_batch(self, bs: int) -> ForwardBatch:
         else:
             global_dp_buffer_len = None

-        spec_info = EagleDraftInput(
+        spec_info = EagleDraftExtendInput(
             hidden_states=hidden_states,
             num_accepted_drafts=num_accepted_drafts,
             num_accepted_tokens=num_accepted_tokens,
diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py
index 8dcb6685c481..eaca6d74d0c5 100644
--- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py
+++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py
@@ -34,6 +34,7 @@
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_info import (
+    EagleDraftExtendInput,
     EagleDraftInput,
     EagleVerifyInput,
     EagleVerifyOutput,
@@ -271,22 +272,33 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         with self.draft_tp_context(
             self.mtp_model_runner(0).tp_group
         ), speculative_moe_backend_context():
-            spec_info = self.draft(batch)
-            logits_output, verify_output, can_run_cuda_graph = self.verify(
-                batch, spec_info
-            )
+            verify_input = self.draft(batch)
+            batch.spec_info = verify_input
+            logits_output, verify_output, can_run_cuda_graph = self.verify(batch)

         with self.draft_tp_context(
             self.mtp_model_runner(0).tp_group
         ), speculative_moe_backend_context():
             # NOTE: We should use `check_forward_draft_extend_after_decode`
             # when DP attention is enabled, but it is slow. Skip it for now.
+            draft_extend_input = verify_output.draft_extend_input
             if (
                 self.server_args.enable_dp_attention
-                or verify_output.unfinished_accept_tokens.shape[0] > 0
+                or draft_extend_input.input_ids.shape[0] > 0
             ):
-                # decode is not finished
-                self.forward_draft_extend_after_decode(batch, verify_output)
+                # decode is not finished: install the extend input, then the
+                # next-iter EagleDraftInput that the extend pass returns.
+                batch.spec_info = draft_extend_input
+                next_draft_input = self.forward_draft_extend_after_decode(batch)
+                batch.spec_info = next_draft_input
+            else:
+                # All reqs finished and dp_attention isn't forcing extend.
+                # Stash an empty EagleDraftInput so next iter's merge_batch
+                # short-circuits on None hidden_states (EagleVerifyInput
+                # has no merge_batch).
+                batch.spec_info = EagleDraftInput(
+                    capture_hidden_mode=CaptureHiddenMode.LAST,
+                )

         return GenerationBatchResult(
             logits_output=logits_output,
@@ -296,7 +308,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
         )

     def check_forward_draft_extend_after_decode(self, verify_output: EagleVerifyOutput):
-        local_need_forward = verify_output.unfinished_accept_tokens.shape[0] > 0
+        local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
         if not self.server_args.enable_dp_attention:
             return local_need_forward

@@ -471,7 +483,8 @@ def clear_cache_pool(self):
         # allocator and kv cache pool are shared with target worker
         pass

-    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
+    def verify(self, batch: ScheduleBatch):
+        spec_info: EagleVerifyInput = batch.spec_info
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
         batch.forward_mode = (
@@ -479,7 +492,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
             if not batch.forward_mode.is_idle()
             else ForwardMode.IDLE
         )
-        batch.spec_info = spec_info

         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
@@ -589,7 +601,6 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         batch.forward_mode = (
             ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
         )
-        batch.spec_info = res.next_draft_input

         return logits_output, res, can_run_cuda_graph

@@ -653,20 +664,19 @@ def forward_draft_extend(
         forward_batch.spec_info.topk_index = torch.cat(topk_index_list, dim=1)

     def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, verify_output: EagleVerifyOutput
-    ):
-        assert isinstance(batch.spec_info, EagleDraftInput)
+        self, batch: ScheduleBatch
+    ) -> EagleDraftInput:
+        draft_extend_input: EagleDraftExtendInput = batch.spec_info
+
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
         req_pool_indices_backup = batch.req_pool_indices
-        num_accepted_drafts_backup = batch.spec_info.num_accepted_drafts
-        num_accepted_tokens_backup = batch.spec_info.num_accepted_tokens
         return_logprob_backup = batch.return_logprob

         input_is_idle = batch.forward_mode.is_idle()
-        if not input_is_idle and verify_output.unfinished_accept_tokens.numel() == 0:
+        if not input_is_idle and draft_extend_input.input_ids.numel() == 0:
             batch = batch.copy()
             batch.prepare_for_idle()
             hidden_size = (
@@ -674,19 +684,19 @@ def forward_draft_extend_after_decode(
                 if self.speculative_algorithm.is_eagle3()
                 else self.model_config.hidden_size
             )
-            batch.spec_info = EagleDraftInput.create_idle_input(
+            draft_extend_input = EagleDraftExtendInput.create_idle_input(
                 device=self.device,
                 hidden_size=hidden_size,
                 dtype=self.model_config.dtype,
-                topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
+            batch.spec_info = draft_extend_input

-        batch.spec_info.num_tokens_per_req = self.speculative_num_steps + 1
-        batch.spec_info.num_tokens_for_logprob_per_req = 1
-        batch.spec_info.prepare_extend_after_decode(
+        # Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens})
+        draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1
+        draft_extend_input.num_tokens_for_logprob_per_req = 1
+        draft_extend_input.prepare_extend_after_decode(
             batch,
-            verify_output=verify_output,
             speculative_num_steps=self.speculative_num_steps,
         )
         batch.forward_mode = (
@@ -748,17 +758,23 @@ def forward_draft_extend_after_decode(
             )
             pt += extend_len

-        forward_batch.spec_info.topk_p = torch.cat(topk_p_list, dim=1)
-        forward_batch.spec_info.topk_index = torch.cat(topk_index_list, dim=1)
+        # Phase 3: assemble next-iter EagleDraftInput from extend output
+        next_draft_input = EagleDraftInput(
+            bonus_tokens=draft_extend_input.bonus_tokens,
+            hidden_states=logits_output.hidden_states,
+            topk_p=torch.cat(topk_p_list, dim=1),
+            topk_index=torch.cat(topk_index_list, dim=1),
+            capture_hidden_mode=CaptureHiddenMode.FULL,
+        )

-        # Restore backup.
-        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        # Restore batch fields. `seq_lens` etc. were modified by
+        # `prepare_extend_after_decode`. Caller installs `next_draft_input` as
+        # `batch.spec_info`.
         batch.forward_mode = (
             ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
         )
         batch.seq_lens = seq_lens_backup
         batch.seq_lens_cpu = seq_lens_cpu_backup
         batch.req_pool_indices = req_pool_indices_backup
-        batch.spec_info.num_accepted_drafts = num_accepted_drafts_backup
-        batch.spec_info.num_accepted_tokens = num_accepted_tokens_backup
         batch.return_logprob = return_logprob_backup
+        return next_draft_input
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index 8a2588c0c833..8a0565e32f96 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -195,8 +195,10 @@ class SpecInputType(IntEnum):
     # NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends.
     # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it
     EAGLE_DRAFT = auto()
+    EAGLE_DRAFT_EXTEND = auto()
     EAGLE_VERIFY = auto()
     FROZEN_KV_MTP_DRAFT = auto()
+    FROZEN_KV_MTP_DRAFT_EXTEND = auto()
     FROZEN_KV_MTP_VERIFY = auto()
     DFLASH_DRAFT = auto()
     DFLASH_VERIFY = auto()
@@ -212,7 +214,9 @@ def is_draft_input(self) -> bool:
         # or use another variable name like `draft_input` to substitute `spec_info`
         return self.spec_input_type in {
             SpecInputType.EAGLE_DRAFT,
+            SpecInputType.EAGLE_DRAFT_EXTEND,
             SpecInputType.FROZEN_KV_MTP_DRAFT,
+            SpecInputType.FROZEN_KV_MTP_DRAFT_EXTEND,
             SpecInputType.DFLASH_DRAFT,
         }
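
Review note (not part of the patch): a minimal runnable sketch of the spec_info handoff this diff establishes. All names below (`DraftInput`, `DraftExtendInput`, `Batch`, `verify`, `draft_extend`, `one_iteration`) are hypothetical stand-ins for `EagleDraftInput` / `EagleDraftExtendInput` / `ScheduleBatch` and the worker methods, with plain lists in place of tensors.

```python
# Hypothetical stand-ins; the real classes live in
# sglang/srt/speculative/eagle_info.py and carry tensors, not lists.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DraftInput:
    # Persists across iterations; seeds the next draft forward.
    topk_p: Optional[List[float]] = None
    bonus_tokens: Optional[List[int]] = None


@dataclass
class DraftExtendInput:
    # Produced by verify, consumed once by the draft-extend forward.
    input_ids: List[int] = field(default_factory=list)
    bonus_tokens: Optional[List[int]] = None


@dataclass
class Batch:
    spec_info: object = None


def verify(batch: Batch) -> DraftExtendInput:
    # Verify reads its own input off batch.spec_info and now *returns* the
    # extend input instead of mutating batch.spec_info as a side effect.
    return DraftExtendInput(input_ids=[11, 12, 13])


def draft_extend(batch: Batch) -> DraftInput:
    # Reads the extend input off batch.spec_info, runs the extend forward
    # (elided), and returns a fresh DraftInput for the next iteration.
    extend_input: DraftExtendInput = batch.spec_info
    extend_input.bonus_tokens = [t + 1 for t in extend_input.input_ids]
    return DraftInput(topk_p=[0.9], bonus_tokens=extend_input.bonus_tokens)


def one_iteration(batch: Batch) -> None:
    extend_input = verify(batch)
    if extend_input.input_ids:  # some reqs survived this verify round
        batch.spec_info = extend_input  # phases 1+2: draft-extend forward
        batch.spec_info = draft_extend(batch)  # phase 3: next-iter draft input
    else:  # all reqs finished; leave an empty DraftInput for merge_batch
        batch.spec_info = DraftInput()


batch = Batch()
one_iteration(batch)
assert isinstance(batch.spec_info, DraftInput)
```

The resulting invariant: `batch.spec_info` always holds the input of the next forward pass (verify input before verify, extend input before draft-extend, draft input before the next draft), rather than `verify` installing next-iter state behind the caller's back.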