diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py index 03f30af8aa7a..de7cf58fa08e 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py @@ -219,7 +219,7 @@ def __init__( def update_mamba_state_after_mtp_verify( self, - accept_steps: torch.Tensor, + last_correct_step_indices: torch.Tensor, mamba_track_indices: Optional[torch.Tensor], mamba_steps_to_track: Optional[torch.Tensor], model, @@ -233,7 +233,7 @@ def update_mamba_state_after_mtp_verify( - index_select kernel launches - nonzero kernel launches """ - request_number = accept_steps.shape[0] + request_number = last_correct_step_indices.shape[0] state_indices_tensor = ( self.linear_attn_backend.forward_metadata.mamba_cache_indices[ @@ -254,7 +254,7 @@ def update_mamba_state_after_mtp_verify( device=dst_indices_tensor.device, dtype=torch.int64, ) - last_steps = accept_steps.to(torch.int64) # [N] + last_steps = last_correct_step_indices.to(torch.int64) # [N] move_intermediate_cache( ssm_states, diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 80aa16822be7..30ca32f4c78a 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -936,7 +936,7 @@ def forward( def update_mamba_state_after_mtp_verify( self, - accept_steps: torch.Tensor, + last_correct_step_indices: torch.Tensor, mamba_track_indices: Optional[torch.Tensor], mamba_steps_to_track: Optional[torch.Tensor], model, @@ -950,7 +950,7 @@ def update_mamba_state_after_mtp_verify( - index_select kernel launches - nonzero kernel launches """ - request_number = accept_steps.shape[0] + request_number = last_correct_step_indices.shape[0] state_indices_tensor = ( self.linear_attn_backend.forward_metadata.mamba_cache_indices[ @@ -973,13 +973,13 @@ def update_mamba_state_after_mtp_verify( ssm_states, intermediate_state_cache, state_indices_tensor, - accept_steps, + last_correct_step_indices, ) fused_mamba_state_scatter_with_mask( conv_states, intermediate_conv_window_cache, state_indices_tensor, - accept_steps, + last_correct_step_indices, ) # Track indices used for tracking mamba states for prefix cache diff --git a/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py b/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py index 77419c2fbc9f..28c0e4929cbc 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py +++ b/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py @@ -17,7 +17,7 @@ def _fused_mamba_state_scatter_with_mask_kernel( dst_ptr, # Raw index arrays (before index_select) dst_indices_raw_ptr, # [total_requests] - state_indices_tensor - step_indices_raw_ptr, # [total_requests] - accept_steps or mamba_steps_to_track + step_indices_raw_ptr, # [total_requests] - last_correct_step_indices or mamba_steps_to_track elem_per_entry: tl.constexpr, src_layer_stride, src_req_stride, diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py index 5c04f0d51daf..a1afa6c885d9 100644 --- a/python/sglang/srt/observability/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -106,7 +106,7 @@ def init_metrics( # Cumulative spec-decoding counters (reset every decode_log_interval). # Each update adds (num_correct_drafts + bs, bs). - # `*_accepted_tokens` = drafts + bonus; `*_accepted_drafts` = drafts-only. + # `*_accept_tokens` = drafts + bonus; `*_correct_drafts` = drafts-only. self.spec_num_accept_tokens = 0 # per-log-interval self.spec_num_forward_ct = 0 self.spec_total_num_accept_tokens = 0 # lifetime diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index a87f74c10cdb..c2db72c8a224 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -368,7 +368,7 @@ def verify( and not sampling_info.is_all_greedy and is_dflash_sampling_verify_available() ): - accept_len, bonus = compute_dflash_sampling_correct_drafts_and_bonus( + correct_len, bonus = compute_dflash_sampling_correct_drafts_and_bonus( candidates=candidates, next_token_logits=logits_output.next_token_logits, sampling_info=sampling_info, @@ -377,14 +377,14 @@ def verify( target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( bs, self.draft_token_num ) - accept_len, bonus = compute_dflash_correct_drafts_and_bonus( + correct_len, bonus = compute_dflash_correct_drafts_and_bonus( candidates=candidates, target_predict=target_predict, ) - # Single D2H transfer: candidates[1:] + accept_len + bonus + # Single D2H transfer: candidates[1:] + correct_len + bonus packed = torch.cat( - [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + [candidates[:, 1:], correct_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 ).cpu() max_acc = self.draft_token_num - 1 diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index ff997ce0973d..f1ea1d794795 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -432,8 +432,8 @@ def compute_dflash_correct_drafts_and_bonus( Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. Returns: - accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). - bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + correct_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index correct_len (the "bonus" token to append). Notes: Matches the reference implementation rule: @@ -454,9 +454,9 @@ def compute_dflash_correct_drafts_and_bonus( raise ValueError(f"block_size must be positive, got {block_size}.") matches = candidates[:, 1:] == target_predict[:, :-1] - accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) - bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] - return accept_len, bonus.to(torch.int64) + correct_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), correct_len] + return correct_len, bonus.to(torch.int64) def compute_dflash_sampling_correct_drafts_and_bonus( @@ -631,8 +631,8 @@ def compute_dflash_sampling_correct_drafts_and_bonus( deterministic=True, ) - accept_len = accept_token_num + correct_len = accept_token_num row_ids = torch.arange(bs, dtype=torch.long, device=device) - accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + accept_pos = accept_index[row_ids, correct_len.to(torch.long)].to(torch.long) bonus = predicts[accept_pos].to(torch.int64) - return accept_len, bonus + return correct_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 68821c6454f3..68fde73632f1 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1080,7 +1080,7 @@ def _update_target_mamba_state_after_verify( if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): return - accept_steps = commit_lens.to(torch.int64) - 1 + last_correct_step_indices = commit_lens.to(torch.int64) - 1 mamba_steps_to_track = None if batch.mamba_track_indices is not None: @@ -1103,7 +1103,7 @@ def _update_target_mamba_state_after_verify( ) attn_backend.update_mamba_state_after_mtp_verify( - accept_steps=accept_steps, + last_correct_step_indices=last_correct_step_indices, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 402ec3a9b2bf..aa6e7338008f 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1003,15 +1003,12 @@ def _mamba_verify_update( if batch.forward_mode.is_idle(): return - num_accept_tokens = ( - torch.tensor( - res.num_correct_drafts_per_req_cpu, - device=logits_output.hidden_states.device, - dtype=torch.int64, - ) - + 1 + num_correct_drafts = torch.tensor( + res.num_correct_drafts_per_req_cpu, + device=logits_output.hidden_states.device, + dtype=torch.int64, ) - cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) + cumulative_num_accept_tokens = torch.cumsum(num_correct_drafts + 1, dim=0) # prepend 0 to the cumulative_num_accept_tokens accepted_indices_start = torch.cat( [ @@ -1037,14 +1034,15 @@ def _mamba_verify_update( # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) - # accept_steps = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches - # first_token_indices_per_req = res.accepted_indices[accepted_indices_start] - accept_steps = ( + # last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches + # equivalent: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req; + # `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token" at logical position i * draft_token_num. + last_correct_step_indices = ( res.accepted_indices[cumulative_num_accept_tokens - 1] - accepted_indices_offset ) else: - accept_steps = num_accept_tokens - 1 + last_correct_step_indices = num_correct_drafts if batch.mamba_track_indices is not None: # If after verify, the request's seq_lens has crossed a mamba track interval, @@ -1068,7 +1066,7 @@ def _mamba_verify_update( mamba_steps_to_track = None self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - accept_steps=accept_steps, + last_correct_step_indices=last_correct_step_indices, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 1dc5bc50bfb1..6a7cda9b80d7 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -1097,7 +1097,6 @@ def _mamba_verify_update( ): """Update mamba state for hybrid GDN models after verification.""" # `accept_lens` already includes the bonus token (drafts + 1 per req). - num_accept_tokens = accept_lens if not batch.forward_mode.is_idle() and accept_index.numel() > 0: if verify_input.topk != 1: raise ValueError("Spec v2 currently only supports topk = 1.") @@ -1106,16 +1105,16 @@ def _mamba_verify_update( 0, bs * self.speculative_num_draft_tokens, step=self.speculative_num_draft_tokens, - dtype=num_accept_tokens.dtype, - device=num_accept_tokens.device, + dtype=accept_lens.dtype, + device=accept_lens.device, ) - accept_steps = num_accept_tokens - 1 + last_correct_step_indices = accept_lens - 1 if batch.mamba_track_indices is not None: # If after verify, the request's seq_lens has crossed a mamba track interval, # we need to update the mamba state for the request at the crossing point. seq_lens_pre_verify = batch.seq_lens - seq_lens_post_verify = batch.seq_lens + num_accept_tokens + seq_lens_post_verify = batch.seq_lens + accept_lens mamba_track_interval = self.server_args.mamba_track_interval to_track_mask = ( seq_lens_pre_verify // mamba_track_interval @@ -1130,7 +1129,7 @@ def _mamba_verify_update( req_idx = torch.arange( bs, dtype=torch.int64, - device=num_accept_tokens.device, + device=accept_lens.device, ) candidate_track_steps = ( accept_index[req_idx, to_track_ith] - accepted_indices_offset @@ -1144,7 +1143,7 @@ def _mamba_verify_update( mamba_steps_to_track = None self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - accept_steps=accept_steps, + last_correct_step_indices=last_correct_step_indices, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, @@ -1157,12 +1156,13 @@ def move_accepted_tokens_to_target_kvcache( num_correct_drafts: torch.Tensor, ): """ - Move accepted tokens to the target KV cache. + Move accepted tokens (drafts + bonus) to the target KV cache. Args: batch: The batch to run. - accept_index: The index of the accepted tokens. - num_correct_drafts: The length of the accepted tokens. + accept_index: The index of the accepted tokens (incl. bonus). + num_correct_drafts: Per-req count of correct drafts (excludes bonus); + seq_lens is advanced by ``num_correct_drafts + 1`` to cover the bonus slot. """ bs = len(batch.seq_lens) size = bs * self.speculative_num_draft_tokens @@ -1179,7 +1179,7 @@ def move_accepted_tokens_to_target_kvcache( batch.req_pool_indices, self.req_to_token_pool.req_to_token, batch.seq_lens, - batch.seq_lens + num_correct_drafts, + batch.seq_lens + num_correct_drafts + 1, tgt_cache_loc, self.req_to_token_pool.req_to_token.shape[1], next_power_of_2(bs), diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 62b7b118de8c..0311167acf65 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -576,7 +576,7 @@ def verify(self, batch: ScheduleBatch): # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) - # max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches + # last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) req_start_positions = torch.cat( [ @@ -592,13 +592,13 @@ def verify(self, batch: ScheduleBatch): last_token_indices_per_req = res.accepted_indices[ cumulative_num_accept_tokens - 1 ] - max_relative_indices_per_req = ( + last_correct_step_indices = ( last_token_indices_per_req - first_token_indices_per_req ) else: - max_relative_indices_per_req = num_accept_tokens - 1 + last_correct_step_indices = num_accept_tokens - 1 self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - max_relative_indices_per_req, self.target_worker.model_runner.model + last_correct_step_indices, self.target_worker.model_runner.model ) if batch.return_logprob: