From 12341f5b04ecee1e737adf4767e786001e2c7d8b Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Mon, 11 May 2026 20:20:13 -0700 Subject: [PATCH] multi-layer mamba scatter: align to eagle_worker form; fix positional bug --- .../speculative/multi_layer_eagle_worker.py | 46 +++++++++---------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 8feeb47eef57..2c3c53e734ba 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -561,13 +561,10 @@ def verify(self, batch: ScheduleBatch): logits_output.hidden_states = logits_output.hidden_states[res.accept_indices] if self.target_worker.model_runner.hybrid_gdn_config is not None: - num_accept_tokens = ( - torch.tensor( - res.num_correct_drafts_per_req_cpu, - device=logits_output.hidden_states.device, - dtype=torch.int64, - ) - + 1 + num_correct_drafts = torch.tensor( + res.num_correct_drafts_per_req_cpu, + device=logits_output.hidden_states.device, + dtype=torch.int64, ) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask @@ -577,28 +574,29 @@ def verify(self, batch: ScheduleBatch): # first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] # last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) # last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches - cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) - req_start_positions = torch.cat( - [ - torch.zeros( - 1, - dtype=cumulative_num_accept_tokens.dtype, - device=cumulative_num_accept_tokens.device, - ), - cumulative_num_accept_tokens[:-1], - ] + # equivalent: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req; + # `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token" at logical position i * draft_token_num. + cumulative_num_accept_tokens = torch.cumsum( + num_correct_drafts + 1, dim=0 + ) + accepted_indices_offset = torch.arange( + 0, + len(batch.seq_lens) * self.speculative_num_draft_tokens, + step=self.speculative_num_draft_tokens, + dtype=num_correct_drafts.dtype, + device=num_correct_drafts.device, ) - first_token_indices_per_req = res.accept_indices[req_start_positions] - last_token_indices_per_req = res.accept_indices[ - cumulative_num_accept_tokens - 1 - ] last_correct_step_indices = ( - last_token_indices_per_req - first_token_indices_per_req + res.accept_indices[cumulative_num_accept_tokens - 1] + - accepted_indices_offset ) else: - last_correct_step_indices = num_accept_tokens - 1 + last_correct_step_indices = num_correct_drafts self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - last_correct_step_indices, self.target_worker.model_runner.model + last_correct_step_indices=last_correct_step_indices, + mamba_track_indices=None, + mamba_steps_to_track=None, + model=self.target_worker.model_runner.model, ) if batch.return_logprob: