Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 22 additions & 24 deletions python/sglang/srt/speculative/multi_layer_eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,13 +561,10 @@ def verify(self, batch: ScheduleBatch):
logits_output.hidden_states = logits_output.hidden_states[res.accept_indices]

if self.target_worker.model_runner.hybrid_gdn_config is not None:
num_accept_tokens = (
torch.tensor(
res.num_correct_drafts_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int64,
)
+ 1
num_correct_drafts = torch.tensor(
res.num_correct_drafts_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int64,
)

# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
Expand All @@ -577,28 +574,29 @@ def verify(self, batch: ScheduleBatch):
# first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (index of the last accepted token of each req)
# last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0)
req_start_positions = torch.cat(
[
torch.zeros(
1,
dtype=cumulative_num_accept_tokens.dtype,
device=cumulative_num_accept_tokens.device,
),
cumulative_num_accept_tokens[:-1],
]
# Equivalent to: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req.
# `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token", which sits at flattened position i * speculative_num_draft_tokens for the i-th req.
cumulative_num_accept_tokens = torch.cumsum(
num_correct_drafts + 1, dim=0
)
accepted_indices_offset = torch.arange(
0,
len(batch.seq_lens) * self.speculative_num_draft_tokens,
step=self.speculative_num_draft_tokens,
dtype=num_correct_drafts.dtype,
device=num_correct_drafts.device,
)
first_token_indices_per_req = res.accept_indices[req_start_positions]
last_token_indices_per_req = res.accept_indices[
cumulative_num_accept_tokens - 1
]
last_correct_step_indices = (
last_token_indices_per_req - first_token_indices_per_req
res.accept_indices[cumulative_num_accept_tokens - 1]
- accepted_indices_offset
)
else:
last_correct_step_indices = num_accept_tokens - 1
last_correct_step_indices = num_correct_drafts
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
last_correct_step_indices, self.target_worker.model_runner.model
last_correct_step_indices=last_correct_step_indices,
mamba_track_indices=None,
mamba_steps_to_track=None,
model=self.target_worker.model_runner.model,
)

if batch.return_logprob:
Expand Down
Loading