Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(

def update_mamba_state_after_mtp_verify(
self,
accept_steps: torch.Tensor,
last_correct_step_indices: torch.Tensor,
mamba_track_indices: Optional[torch.Tensor],
mamba_steps_to_track: Optional[torch.Tensor],
model,
Expand All @@ -233,7 +233,7 @@ def update_mamba_state_after_mtp_verify(
- index_select kernel launches
- nonzero kernel launches
"""
request_number = accept_steps.shape[0]
request_number = last_correct_step_indices.shape[0]

state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
Expand All @@ -254,7 +254,7 @@ def update_mamba_state_after_mtp_verify(
device=dst_indices_tensor.device,
dtype=torch.int64,
)
last_steps = accept_steps.to(torch.int64) # [N]
last_steps = last_correct_step_indices.to(torch.int64) # [N]

move_intermediate_cache(
ssm_states,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def forward(

def update_mamba_state_after_mtp_verify(
self,
accept_steps: torch.Tensor,
last_correct_step_indices: torch.Tensor,
mamba_track_indices: Optional[torch.Tensor],
mamba_steps_to_track: Optional[torch.Tensor],
model,
Expand All @@ -950,7 +950,7 @@ def update_mamba_state_after_mtp_verify(
- index_select kernel launches
- nonzero kernel launches
"""
request_number = accept_steps.shape[0]
request_number = last_correct_step_indices.shape[0]

state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
Expand All @@ -973,13 +973,13 @@ def update_mamba_state_after_mtp_verify(
ssm_states,
intermediate_state_cache,
state_indices_tensor,
accept_steps,
last_correct_step_indices,
)
fused_mamba_state_scatter_with_mask(
conv_states,
intermediate_conv_window_cache,
state_indices_tensor,
accept_steps,
last_correct_step_indices,
)

# Track indices used for tracking mamba states for prefix cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _fused_mamba_state_scatter_with_mask_kernel(
dst_ptr,
# Raw index arrays (before index_select)
dst_indices_raw_ptr, # [total_requests] - state_indices_tensor
step_indices_raw_ptr, # [total_requests] - accept_steps or mamba_steps_to_track
step_indices_raw_ptr, # [total_requests] - last_correct_step_indices or mamba_steps_to_track
elem_per_entry: tl.constexpr,
src_layer_stride,
src_req_stride,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def init_metrics(

# Cumulative spec-decoding counters (reset every decode_log_interval).
# Each update adds (num_correct_drafts + bs, bs).
# `*_accepted_tokens` = drafts + bonus; `*_accepted_drafts` = drafts-only.
# `*_accept_tokens` = drafts + bonus; `*_correct_drafts` = drafts-only.
self.spec_num_accept_tokens = 0 # per-log-interval
self.spec_num_forward_ct = 0
self.spec_total_num_accept_tokens = 0 # lifetime
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/speculative/dflash_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def verify(
and not sampling_info.is_all_greedy
and is_dflash_sampling_verify_available()
):
accept_len, bonus = compute_dflash_sampling_correct_drafts_and_bonus(
correct_len, bonus = compute_dflash_sampling_correct_drafts_and_bonus(
candidates=candidates,
next_token_logits=logits_output.next_token_logits,
sampling_info=sampling_info,
Expand All @@ -377,14 +377,14 @@ def verify(
target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view(
bs, self.draft_token_num
)
accept_len, bonus = compute_dflash_correct_drafts_and_bonus(
correct_len, bonus = compute_dflash_correct_drafts_and_bonus(
candidates=candidates,
target_predict=target_predict,
)

# Single D2H transfer: candidates[1:] + accept_len + bonus
# Single D2H transfer: candidates[1:] + correct_len + bonus
packed = torch.cat(
[candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1
[candidates[:, 1:], correct_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1
).cpu()

max_acc = self.draft_token_num - 1
Expand Down
16 changes: 8 additions & 8 deletions python/sglang/srt/speculative/dflash_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,8 @@ def compute_dflash_correct_drafts_and_bonus(
Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t.

Returns:
accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token).
bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append).
correct_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token).
bonus: int64 tensor [bs], the target-predicted token at index correct_len (the "bonus" token to append).

Notes:
Matches the reference implementation rule:
Expand All @@ -454,9 +454,9 @@ def compute_dflash_correct_drafts_and_bonus(
raise ValueError(f"block_size must be positive, got {block_size}.")

matches = candidates[:, 1:] == target_predict[:, :-1]
accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1)
bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len]
return accept_len, bonus.to(torch.int64)
correct_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1)
bonus = target_predict[torch.arange(bs, device=target_predict.device), correct_len]
return correct_len, bonus.to(torch.int64)


def compute_dflash_sampling_correct_drafts_and_bonus(
Expand Down Expand Up @@ -631,8 +631,8 @@ def compute_dflash_sampling_correct_drafts_and_bonus(
deterministic=True,
)

accept_len = accept_token_num
correct_len = accept_token_num
row_ids = torch.arange(bs, dtype=torch.long, device=device)
accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long)
accept_pos = accept_index[row_ids, correct_len.to(torch.long)].to(torch.long)
bonus = predicts[accept_pos].to(torch.int64)
return accept_len, bonus
return correct_len, bonus
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/dflash_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,7 @@ def _update_target_mamba_state_after_verify(
if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"):
return

accept_steps = commit_lens.to(torch.int64) - 1
last_correct_step_indices = commit_lens.to(torch.int64) - 1
mamba_steps_to_track = None

if batch.mamba_track_indices is not None:
Expand All @@ -1103,7 +1103,7 @@ def _update_target_mamba_state_after_verify(
)

attn_backend.update_mamba_state_after_mtp_verify(
accept_steps=accept_steps,
last_correct_step_indices=last_correct_step_indices,
mamba_track_indices=batch.mamba_track_indices,
mamba_steps_to_track=mamba_steps_to_track,
model=self.target_worker.model_runner.model,
Expand Down
24 changes: 11 additions & 13 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,15 +1003,12 @@ def _mamba_verify_update(
if batch.forward_mode.is_idle():
return

num_accept_tokens = (
torch.tensor(
res.num_correct_drafts_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int64,
)
+ 1
num_correct_drafts = torch.tensor(
res.num_correct_drafts_per_req_cpu,
device=logits_output.hidden_states.device,
dtype=torch.int64,
)
cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0)
cumulative_num_accept_tokens = torch.cumsum(num_correct_drafts + 1, dim=0)
# prepend 0 to the cumulative_num_accept_tokens
accepted_indices_start = torch.cat(
[
Expand All @@ -1037,14 +1034,15 @@ def _mamba_verify_update(
# accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# accept_steps = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
# first_token_indices_per_req = res.accepted_indices[accepted_indices_start]
accept_steps = (
# last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
# equivalent: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req;
# `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token" at logical position i * draft_token_num.
last_correct_step_indices = (
res.accepted_indices[cumulative_num_accept_tokens - 1]
- accepted_indices_offset
)
else:
accept_steps = num_accept_tokens - 1
last_correct_step_indices = num_correct_drafts

if batch.mamba_track_indices is not None:
# If after verify, the request's seq_lens has crossed a mamba track interval,
Expand All @@ -1068,7 +1066,7 @@ def _mamba_verify_update(
mamba_steps_to_track = None

self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accept_steps=accept_steps,
last_correct_step_indices=last_correct_step_indices,
mamba_track_indices=batch.mamba_track_indices,
mamba_steps_to_track=mamba_steps_to_track,
model=self.target_worker.model_runner.model,
Expand Down
22 changes: 11 additions & 11 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,7 +1097,6 @@ def _mamba_verify_update(
):
"""Update mamba state for hybrid GDN models after verification."""
# `accept_lens` already includes the bonus token (drafts + 1 per req).
num_accept_tokens = accept_lens
if not batch.forward_mode.is_idle() and accept_index.numel() > 0:
if verify_input.topk != 1:
raise ValueError("Spec v2 currently only supports topk = 1.")
Expand All @@ -1106,16 +1105,16 @@ def _mamba_verify_update(
0,
bs * self.speculative_num_draft_tokens,
step=self.speculative_num_draft_tokens,
dtype=num_accept_tokens.dtype,
device=num_accept_tokens.device,
dtype=accept_lens.dtype,
device=accept_lens.device,
)
accept_steps = num_accept_tokens - 1
last_correct_step_indices = accept_lens - 1

if batch.mamba_track_indices is not None:
# If after verify, the request's seq_lens has crossed a mamba track interval,
# we need to update the mamba state for the request at the crossing point.
seq_lens_pre_verify = batch.seq_lens
seq_lens_post_verify = batch.seq_lens + num_accept_tokens
seq_lens_post_verify = batch.seq_lens + accept_lens
mamba_track_interval = self.server_args.mamba_track_interval
to_track_mask = (
seq_lens_pre_verify // mamba_track_interval
Expand All @@ -1130,7 +1129,7 @@ def _mamba_verify_update(
req_idx = torch.arange(
bs,
dtype=torch.int64,
device=num_accept_tokens.device,
device=accept_lens.device,
)
candidate_track_steps = (
accept_index[req_idx, to_track_ith] - accepted_indices_offset
Expand All @@ -1144,7 +1143,7 @@ def _mamba_verify_update(
mamba_steps_to_track = None

self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
accept_steps=accept_steps,
last_correct_step_indices=last_correct_step_indices,
mamba_track_indices=batch.mamba_track_indices,
mamba_steps_to_track=mamba_steps_to_track,
model=self.target_worker.model_runner.model,
Expand All @@ -1157,12 +1156,13 @@ def move_accepted_tokens_to_target_kvcache(
num_correct_drafts: torch.Tensor,
):
"""
Move accepted tokens to the target KV cache.
Move accepted tokens (drafts + bonus) to the target KV cache.

Args:
batch: The batch to run.
accept_index: The index of the accepted tokens.
num_correct_drafts: The length of the accepted tokens.
accept_index: The index of the accepted tokens (incl. bonus).
num_correct_drafts: Per-req count of correct drafts (excludes bonus);
seq_lens is advanced by ``num_correct_drafts + 1`` to cover the bonus slot.
"""
bs = len(batch.seq_lens)
size = bs * self.speculative_num_draft_tokens
Expand All @@ -1179,7 +1179,7 @@ def move_accepted_tokens_to_target_kvcache(
batch.req_pool_indices,
self.req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + num_correct_drafts,
batch.seq_lens + num_correct_drafts + 1,
tgt_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/speculative/multi_layer_eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def verify(self, batch: ScheduleBatch):
# accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
# last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0)
req_start_positions = torch.cat(
[
Expand All @@ -592,13 +592,13 @@ def verify(self, batch: ScheduleBatch):
last_token_indices_per_req = res.accepted_indices[
cumulative_num_accept_tokens - 1
]
max_relative_indices_per_req = (
last_correct_step_indices = (
last_token_indices_per_req - first_token_indices_per_req
)
else:
max_relative_indices_per_req = num_accept_tokens - 1
last_correct_step_indices = num_accept_tokens - 1
self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
max_relative_indices_per_req, self.target_worker.model_runner.model
last_correct_step_indices, self.target_worker.model_runner.model
)

if batch.return_logprob:
Expand Down
Loading