diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py index 6abe5c2fd042..6533699b0607 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_gdn_backend.py @@ -53,7 +53,7 @@ def prepare_gdn_inputs( spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], ): cache_indices = self.forward_metadata.mamba_cache_indices - self.num_accepted_tokens = torch.ones( + self.num_accept_tokens = torch.ones( [bs], dtype=torch.int32, device=cache_indices.device ) self.actual_seq_lengths = torch.ones( @@ -237,7 +237,7 @@ def forward_extend( seq_len = forward_batch.num_token_non_padded_cpu mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1) - num_accepted_tokens = torch.full( + num_accept_tokens = torch.full( (batch_size,), draft_token_num, dtype=torch.int32, @@ -249,7 +249,7 @@ def forward_extend( conv_states, cache_indices, layer.bias, - num_accepted_tokens, + num_accept_tokens, None, layer.activation == "silu", self.pad_slot_id, @@ -391,7 +391,7 @@ def fused_recurrent_gated_delta_rule_update( ) if self.graph_mode: - num_accepted_tokens = torch.full( + num_accept_tokens = torch.full( [batch_size], 1, dtype=torch.int32, device=cache_indices.device ) actual_seq_lengths = torch.full( @@ -399,7 +399,7 @@ def fused_recurrent_gated_delta_rule_update( ) ssm_state_indices = self.forward_metadata.mamba_cache_indices_gdn else: - num_accepted_tokens = self.num_accepted_tokens + num_accept_tokens = self.num_accept_tokens actual_seq_lengths = self.actual_seq_lengths ssm_state_indices = self.ssm_state_indices @@ -414,7 +414,7 @@ def fused_recurrent_gated_delta_rule_update( nv=num_value_heads, intermediate_state=intermediate_state, cache_indices=cache_indices, - num_accepted_tokens=num_accepted_tokens, + num_accept_tokens=num_accept_tokens, g=g, ) diff --git a/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py b/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py index 171a612a9350..03f30af8aa7a 100644 --- a/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py +++ b/python/sglang/srt/hardware_backend/npu/attention/ascend_hybrid_linear_attn_backend.py @@ -219,7 +219,7 @@ def __init__( def update_mamba_state_after_mtp_verify( self, - accepted_steps: torch.Tensor, + accept_steps: torch.Tensor, mamba_track_indices: Optional[torch.Tensor], mamba_steps_to_track: Optional[torch.Tensor], model, @@ -233,7 +233,7 @@ def update_mamba_state_after_mtp_verify( - index_select kernel launches - nonzero kernel launches """ - request_number = accepted_steps.shape[0] + request_number = accept_steps.shape[0] state_indices_tensor = ( self.linear_attn_backend.forward_metadata.mamba_cache_indices[ @@ -254,7 +254,7 @@ def update_mamba_state_after_mtp_verify( device=dst_indices_tensor.device, dtype=torch.int64, ) - last_steps = accepted_steps.to(torch.int64) # [N] + last_steps = accept_steps.to(torch.int64) # [N] move_intermediate_cache( ssm_states, diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index b09f5eeef7b6..897740536321 100755 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -1042,7 +1042,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): self.indices_updater_prefill.max_kv_len, ) elif forward_batch.forward_mode.is_draft_extend(): - # EAGLE V1: DRAFT_EXTEND mode - uses spec_info.num_accepted_tokens + # EAGLE V1: DRAFT_EXTEND mode - uses spec_info.num_accept_tokens if self.use_mla: kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( @@ -1110,7 +1110,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) ) kv_indices = kv_indices.to(torch.int64) - draft_max_extend_len = torch.max(spec_info.num_accepted_tokens).item() + draft_max_extend_len = torch.max(spec_info.num_accept_tokens).item() self.forward_metadata = ForwardMetadata( kv_indptr, @@ -2240,10 +2240,10 @@ def init_forward_metadata_replay_cuda_graph( num_kv_splits=num_kv_splits, ) elif forward_mode.is_draft_extend(): - # EAGLE V1: Uses spec_info.num_accepted_tokens + # EAGLE V1: Uses spec_info.num_accept_tokens num_tokens_per_bs = self.speculative_num_steps + 1 seq_lens = seq_lens[:bs] - extend_lens = spec_info.num_accepted_tokens[:bs] + extend_lens = spec_info.num_accept_tokens[:bs] qo_indptr = self.qo_indptr[: bs + 1] qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0) kv_indptr = self.kv_indptr[: bs + 1] diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 709f0c51cc18..cc3b1ca32ca6 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -2162,9 +2162,9 @@ def init_forward_metadata_replay_cuda_graph( metadata.cu_seqlens_k[1:].copy_( torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32) ) - extend_lens = spec_info.num_accepted_tokens[:bs] - if spec_info.num_accepted_tokens_cpu: - metadata.max_seq_len_q = max(spec_info.num_accepted_tokens_cpu) + extend_lens = spec_info.num_accept_tokens[:bs] + if spec_info.num_accept_tokens_cpu: + metadata.max_seq_len_q = max(spec_info.num_accept_tokens_cpu) else: metadata.max_seq_len_q = 1 diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 45a9a4c9985c..80aa16822be7 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -936,7 +936,7 @@ def forward( def update_mamba_state_after_mtp_verify( self, - accepted_steps: torch.Tensor, + accept_steps: torch.Tensor, mamba_track_indices: Optional[torch.Tensor], mamba_steps_to_track: Optional[torch.Tensor], model, @@ -950,7 +950,7 @@ def update_mamba_state_after_mtp_verify( - index_select kernel launches - nonzero kernel launches """ - request_number = accepted_steps.shape[0] + request_number = accept_steps.shape[0] state_indices_tensor = ( self.linear_attn_backend.forward_metadata.mamba_cache_indices[ @@ -973,13 +973,13 @@ def update_mamba_state_after_mtp_verify( ssm_states, intermediate_state_cache, state_indices_tensor, - accepted_steps, + accept_steps, ) fused_mamba_state_scatter_with_mask( conv_states, intermediate_conv_window_cache, state_indices_tensor, - accepted_steps, + accept_steps, ) # Track indices used for tracking mamba states for prefix cache diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py index c82f4d730fa1..793b37e984e4 100644 --- a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py @@ -576,7 +576,7 @@ def _causal_conv1d_update_kernel( conv_state_ptr, cache_seqlens_ptr, # circular buffer conv_state_indices_ptr, - num_accepted_tokens_ptr, + num_accept_tokens_ptr, intermediate_conv_window_ptr, intermediate_state_indices_ptr, retrieve_next_token_ptr, @@ -667,7 +667,7 @@ def _causal_conv1d_update_kernel( # - accept 1 tokens: [history2, ..., historyM, draft1] # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 + conv_state_token_offset = tl.load(num_accept_tokens_ptr + idx_seq) - 1 else: conv_state_token_offset = 0 @@ -985,7 +985,7 @@ def causal_conv1d_update( activation: Union[bool, str, None] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + num_accept_tokens: Optional[torch.Tensor] = None, intermediate_conv_window: Optional[torch.Tensor] = None, intermediate_state_indices: Optional[torch.Tensor] = None, retrieve_next_token: Optional[torch.Tensor] = None, @@ -1071,7 +1071,7 @@ def causal_conv1d_update( if intermediate_state_indices is not None else 0 ) - if num_accepted_tokens is not None: + if num_accept_tokens is not None: state_len = width - 1 + (seqlen - 1) # effective state_len needed else: state_len = width - 1 @@ -1130,7 +1130,7 @@ def grid(META): conv_state, cache_seqlens, conv_state_indices, - num_accepted_tokens, + num_accept_tokens, intermediate_conv_window if intermediate_conv_window is not None else x, intermediate_state_indices, retrieve_next_token, @@ -1174,7 +1174,7 @@ def grid(META): KERNEL_WIDTH=width, SILU_ACTIVATION=activation in ["silu", "swish"], IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, + IS_SPEC_DECODING=num_accept_tokens is not None, NP2_STATELEN=np2_statelen, NP2_SEQLEN=np2_seqlen, USE_PAD_SLOT=pad_slot_id is not None, diff --git a/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py b/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py index a97320cac514..77419c2fbc9f 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py +++ b/python/sglang/srt/layers/attention/mamba/mamba_state_scatter_triton.py @@ -17,7 +17,7 @@ def _fused_mamba_state_scatter_with_mask_kernel( dst_ptr, # Raw index arrays (before index_select) dst_indices_raw_ptr, # [total_requests] - state_indices_tensor - step_indices_raw_ptr, # [total_requests] - accepted_steps or mamba_steps_to_track + step_indices_raw_ptr, # [total_requests] - accept_steps or mamba_steps_to_track elem_per_entry: tl.constexpr, src_layer_stride, src_req_stride, diff --git a/python/sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py b/python/sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py index af8af4008f27..1e61c8416232 100644 --- a/python/sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py +++ b/python/sglang/srt/layers/attention/nsa/nsa_backend_mtp_precompute.py @@ -261,9 +261,9 @@ def _precompute_draft_extend_mode( cache_seqlens = seq_lens.to(torch.int32) cu_seqlens_k = compute_cu_seqlens(cache_seqlens) - # Extend seqlens from spec_info: num_accepted_tokens already includes + # Extend seqlens from spec_info: num_accept_tokens already includes # the bonus token (drafts + 1). - extend_seq_lens = spec_info.num_accepted_tokens[:bs] + extend_seq_lens = spec_info.num_accept_tokens[:bs] extend_seq_lens_cpu = extend_seq_lens.tolist() # Page indices (repeated per accept length) diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py index cb9d9eab370d..a938f0b01e22 100644 --- a/python/sglang/srt/layers/attention/nsa_backend.py +++ b/python/sglang/srt/layers/attention/nsa_backend.py @@ -515,7 +515,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): page_table, repeats=self.speculative_num_draft_tokens, dim=0 ) else: - # DRAFT_EXTEND (v1): V1 worker extends by (num_accepted_drafts + 1) per request + # DRAFT_EXTEND (v1): V1 worker extends by (num_correct_drafts + 1) per request # after verification. Lengths vary per request based on how many tokens # were accepted. page_table = torch.repeat_interleave( @@ -1053,7 +1053,7 @@ def init_forward_metadata_replay_cuda_graph( torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32) ) - extend_seq_lens = spec_info.num_accepted_tokens[:bs] + extend_seq_lens = spec_info.num_accept_tokens[:bs] extend_seq_lens_cpu = extend_seq_lens.tolist() page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k] diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index d80a87eada5c..f6d27e9c32fe 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -424,9 +424,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): kv_indices = kv_indices.to(torch.int64) mask_indptr = None # TODO(FIXME): This will trigger an invalid Eagle tree when using - # `max(spec_info.num_accepted_tokens_cpu)`. + # `max(spec_info.num_accept_tokens_cpu)`. # It might have been forgotten to update somewhere. - max_extend_len = torch.max(spec_info.num_accepted_tokens).item() + max_extend_len = torch.max(spec_info.num_accept_tokens).item() num_kv_splits = None attn_logits = None attn_lse = None diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index 5271a421c1db..253bb92d9a9c 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -528,9 +528,9 @@ def init_forward_metadata_replay_cuda_graph( metadata.cu_seqlens_k[1:].copy_( torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32) ) - extend_lens = spec_info.num_accepted_tokens[:bs] - if spec_info.num_accepted_tokens_cpu: - metadata.max_seq_len_q = max(spec_info.num_accepted_tokens_cpu) + extend_lens = spec_info.num_accept_tokens[:bs] + if spec_info.num_accept_tokens_cpu: + metadata.max_seq_len_q = max(spec_info.num_accept_tokens_cpu) else: metadata.max_seq_len_q = 1 diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 68a102d14499..a0fe494189fe 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -131,7 +131,7 @@ def pad_draft_extend_query_kernel( def unpad_draft_extend_output_kernel( raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim) output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim) - accept_length_ptr, # Accept lengths for each sequence [batch_size] + num_accept_tokens_ptr, # Accept lengths for each sequence [batch_size] cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1] batch_size, token_per_batch, @@ -151,7 +151,7 @@ def unpad_draft_extend_output_kernel( return # Load accept length for this batch - accept_len = tl.load(accept_length_ptr + batch_id) + accept_len = tl.load(num_accept_tokens_ptr + batch_id) if seq_pos >= accept_len: return @@ -745,7 +745,7 @@ def unpad_draft_extend_output( unpad_draft_extend_output_kernel[grid]( raw_out_ptr=raw_out, output_ptr=output, - accept_length_ptr=seq_lens_q, + num_accept_tokens_ptr=seq_lens_q, cumsum_ptr=cu_seqlens_q, batch_size=batch_size, token_per_batch=token_per_batch, @@ -1006,7 +1006,7 @@ def forward_extend( q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim) needs_unpad = False else: - # draft_extend: handle varying num_accepted_drafts_per_req. If total_tokens % bs == 0, + # draft_extend: handle varying num_correct_drafts_per_req. If total_tokens % bs == 0, # we can directly reshape q; otherwise, pad to max_seq_len_q. total_tokens = q.shape[0] tokens_per_seq = total_tokens // bs if bs > 0 else 0 diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index 829877db8de5..2a759c22207d 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -293,9 +293,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ) mask_indptr = None # TODO(FIXME): This will trigger an invalid Eagle tree when using - # `max(spec_info.num_accepted_tokens_cpu)`. + # `max(spec_info.num_accept_tokens_cpu)`. # It might have been forgotten to update somewhere. - max_extend_len = torch.max(spec_info.num_accepted_tokens).item() + max_extend_len = torch.max(spec_info.num_accept_tokens).item() num_kv_splits = None attn_logits = None attn_lse = None diff --git a/python/sglang/srt/layers/utils/logprob.py b/python/sglang/srt/layers/utils/logprob.py index 76474ee79ace..ab199a86d1b8 100644 --- a/python/sglang/srt/layers/utils/logprob.py +++ b/python/sglang/srt/layers/utils/logprob.py @@ -338,11 +338,11 @@ def add_output_logprobs_for_spec_v1( if logits_output is None: logits_output = res.logits_output - if hasattr(res, "num_accepted_drafts_per_req_cpu"): - num_accepted_drafts_per_req_cpu = res.num_accepted_drafts_per_req_cpu + if hasattr(res, "num_correct_drafts_per_req_cpu"): + num_correct_drafts_per_req_cpu = res.num_correct_drafts_per_req_cpu else: # FIXME: Get a NgramVerifyOutput class and use that instead of this hack. - num_accepted_drafts_per_req_cpu = res.num_accepted_drafts.tolist() + num_correct_drafts_per_req_cpu = res.num_correct_drafts.tolist() top_logprobs_nums = batch.top_logprobs_nums token_ids_logprobs = batch.token_ids_logprobs @@ -363,7 +363,7 @@ def add_output_logprobs_for_spec_v1( logits_output.next_token_logits / temperatures, dim=-1 ) batch_next_token_ids = res.accept_tokens - num_tokens_per_req = [accept + 1 for accept in num_accepted_drafts_per_req_cpu] + num_tokens_per_req = [accept + 1 for accept in num_correct_drafts_per_req_cpu] # We should repeat top_logprobs_nums to match num_tokens_per_req. top_logprobs_nums_repeat_interleaved = [ diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index 9b507601c555..a4547bf36c24 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -362,8 +362,8 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput): cached_tokens=recv_obj.cached_tokens, cached_tokens_details=recv_obj.cached_tokens_details, spec_verify_ct=recv_obj.spec_verify_ct, - spec_accepted_drafts=recv_obj.spec_accepted_drafts, - spec_acceptance_histogram=recv_obj.spec_acceptance_histogram, + spec_num_correct_drafts=recv_obj.spec_num_correct_drafts, + spec_correct_drafts_histogram=recv_obj.spec_correct_drafts_histogram, input_token_logprobs_val=recv_obj.input_token_logprobs_val, input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, output_token_logprobs_val=recv_obj.output_token_logprobs_val, diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 33bfc1bc6051..293335f646d1 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -95,13 +95,13 @@ class SpeculativeDecodingMetricsMixin: # Accepted drafts: Number of accepted draft tokens during speculative decoding # (strict drafts-only count, excludes the bonus token). - spec_accepted_drafts: List[int] + spec_num_correct_drafts: List[int] # Acceptance histogram: List of lists, where each inner list represents histogram counts. # List index = number of accepted tokens in a step, List value = count of steps with that many accepted tokens. # Example: histogram[0] = 5 means 5 steps with 0 accepted tokens, histogram[3] = 10 means 10 steps with 3 accepted tokens. # Empty list [] when speculative decoding is disabled. - spec_acceptance_histogram: List[List[int]] + spec_correct_drafts_histogram: List[List[int]] # Parameters for a session diff --git a/python/sglang/srt/managers/multi_tokenizer_mixin.py b/python/sglang/srt/managers/multi_tokenizer_mixin.py index 45b75ed90539..a6bb630d243b 100644 --- a/python/sglang/srt/managers/multi_tokenizer_mixin.py +++ b/python/sglang/srt/managers/multi_tokenizer_mixin.py @@ -129,11 +129,11 @@ def _handle_output_by_index(output, i): new_output = BatchTokenIDOutput( rids=[output.rids[i]], spec_verify_ct=_extract_field_by_index(output, "spec_verify_ct", i), - spec_accepted_drafts=_extract_field_by_index( - output, "spec_accepted_drafts", i + spec_num_correct_drafts=_extract_field_by_index( + output, "spec_num_correct_drafts", i ), - spec_acceptance_histogram=_extract_field_by_index( - output, "spec_acceptance_histogram", i + spec_correct_drafts_histogram=_extract_field_by_index( + output, "spec_correct_drafts_histogram", i ), time_stats=_extract_field_by_index(output, "time_stats", i), finished_reasons=_extract_field_by_index(output, "finished_reasons", i), @@ -217,11 +217,11 @@ def _handle_output_by_index(output, i): new_output = BatchStrOutput( rids=[output.rids[i]], spec_verify_ct=_extract_field_by_index(output, "spec_verify_ct", i), - spec_accepted_drafts=_extract_field_by_index( - output, "spec_accepted_drafts", i + spec_num_correct_drafts=_extract_field_by_index( + output, "spec_num_correct_drafts", i ), - spec_acceptance_histogram=_extract_field_by_index( - output, "spec_acceptance_histogram", i + spec_correct_drafts_histogram=_extract_field_by_index( + output, "spec_correct_drafts_histogram", i ), time_stats=_extract_field_by_index(output, "time_stats", i), finished_reasons=_extract_field_by_index(output, "finished_reasons", i), diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 99f2744ee763..6db0f8b0f8e0 100755 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -857,12 +857,12 @@ def __init__( self.spec_verify_ct = 0 # Per-request count of accepted draft tokens (excludes the bonus token). - self.spec_accepted_drafts = 0 + self.spec_num_correct_drafts = 0 # Acceptance histogram for speculative decoding. # List index = number of accepted tokens in a step, List value = count of steps with that many accepted tokens. # Example: histogram[0] = 5 means 5 steps with 0 accepted tokens, histogram[3] = 10 means 10 steps with 3 accepted tokens. - self.spec_acceptance_histogram: List[int] = [] + self.spec_correct_drafts_histogram: List[int] = [] # The number of times this request has been retracted / preempted. self.retraction_count = 0 @@ -961,17 +961,17 @@ def pop_overallocated_kv_cache(self) -> Tuple[int, int]: self.kv_overallocated_freed = True return self._cache_commit_len(), self.kv_allocated_len - def update_spec_acceptance_histogram(self, accepted_draft_tokens: int): + def update_spec_correct_drafts_histogram(self, num_correct_drafts: int): """Update the speculative decoding acceptance histogram. Args: - accepted_draft_tokens: Number of draft tokens accepted in this step. + num_correct_drafts: Number of correct draft tokens (no bonus) in this step. """ - if len(self.spec_acceptance_histogram) <= accepted_draft_tokens: - self.spec_acceptance_histogram.extend( - [0] * (accepted_draft_tokens - len(self.spec_acceptance_histogram) + 1) + if len(self.spec_correct_drafts_histogram) <= num_correct_drafts: + self.spec_correct_drafts_histogram.extend( + [0] * (num_correct_drafts - len(self.spec_correct_drafts_histogram) + 1) ) - self.spec_acceptance_histogram[accepted_draft_tokens] += 1 + self.spec_correct_drafts_histogram[num_correct_drafts] += 1 def extend_image_inputs(self, image_inputs): if self.multimodal_inputs is None: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8f4dfe996afb..2d33b3fc0ddf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -3442,7 +3442,7 @@ def get_internal_state(self, recv_req: GetInternalStateReq): if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: ret["avg_spec_accept_length"] = ( - self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct + self.spec_total_num_accept_tokens / self.spec_total_num_forward_ct ) if RECORD_STEP_TIME: @@ -3481,10 +3481,10 @@ def set_internal_state(self, recv_req: SetInternalStateReq): if if_success: if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: avg_spec_accept_length = ( - self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct + self.spec_total_num_accept_tokens / self.spec_total_num_forward_ct ) logger.info(f"{avg_spec_accept_length=}") - self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0 + self.spec_total_num_accept_tokens = self.spec_total_num_forward_ct = 0 for k, v in server_args_dict.items(): setattr(get_global_server_args(), k, v) logger.info(f"Global server args updated! {get_global_server_args()=}") diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 997073ab8ca4..084fe7717de4 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -415,13 +415,13 @@ def _resolve_spec_overlap_token_ids( next_token_ids = result.next_token_ids.tolist() accept_lens = result.accept_lens.tolist() - result.num_accepted_drafts = sum(accept_lens) - len(batch.reqs) - result.num_accepted_drafts_per_req_cpu = [x - 1 for x in accept_lens] + result.num_correct_drafts = sum(accept_lens) - len(batch.reqs) + result.num_correct_drafts_per_req_cpu = [x - 1 for x in accept_lens] # Feed the adaptive controller now that accept_lens is on CPU, # instead of doing a synchronous GPU→CPU copy in the worker hot path. # BaseSpecWorker provides a no-op default for non-adaptive workers. - self.model_worker.on_verify_complete_cpu(result.num_accepted_drafts_per_req_cpu) + self.model_worker.on_verify_complete_cpu(result.num_correct_drafts_per_req_cpu) predict_tokens = [] # In adaptive spec-v2, the worker state may already have switched when this @@ -447,9 +447,9 @@ def _resolve_spec_overlap_token_ids( req.kv_committed_len += accept_lens[i] - 1 req.spec_verify_ct += 1 - accepted_draft_tokens = result.num_accepted_drafts_per_req_cpu[i] - req.spec_accepted_drafts += accepted_draft_tokens - req.update_spec_acceptance_histogram(accepted_draft_tokens) + num_correct_drafts = result.num_correct_drafts_per_req_cpu[i] + req.spec_num_correct_drafts += num_correct_drafts + req.update_spec_correct_drafts_histogram(num_correct_drafts) return predict_tokens @@ -513,7 +513,7 @@ def process_batch_result_decode( self.num_generated_tokens += len(batch.reqs) if not batch.spec_algorithm.is_none(): - self.update_spec_metrics(batch.batch_size(), result.num_accepted_drafts) + self.update_spec_metrics(batch.batch_size(), result.num_correct_drafts) if self.enable_metrics: self.metrics_collector.increment_decode_cuda_graph_pass( value=can_run_cuda_graph @@ -628,7 +628,7 @@ def process_batch_result_decode( self.report_decode_stats( can_run_cuda_graph, running_batch=batch, - num_accepted_drafts=result.num_accepted_drafts, + num_correct_drafts=result.num_correct_drafts, ) def _handle_finished_req( @@ -687,13 +687,13 @@ def _mamba_prefix_cache_update( req.mamba_last_track_seqlen = seq_len elif ( not batch.spec_algorithm.is_none() - and result.num_accepted_drafts_per_req_cpu is not None + and result.num_correct_drafts_per_req_cpu is not None ): # for spec decode, update mamba_last_track_seqlen if this iteration crosses a track interval actual_seq_len = req.seqlen - 1 if ( actual_seq_len // mamba_track_interval - != (actual_seq_len - result.num_accepted_drafts_per_req_cpu[i] - 1) + != (actual_seq_len - result.num_correct_drafts_per_req_cpu[i] - 1) // mamba_track_interval ): req.mamba_next_track_idx = ( @@ -1046,8 +1046,8 @@ def stream_output_generation( cached_tokens = [] cached_tokens_details = [] # Detailed breakdown by cache source spec_verify_ct = [] - spec_accepted_drafts = [] - spec_acceptance_histogram = [] + spec_num_correct_drafts = [] + spec_correct_drafts_histogram = [] retraction_counts = [] output_hidden_states = None load = self.get_loads(GetLoadsReqInput(include=["core"])) @@ -1156,8 +1156,10 @@ def stream_output_generation( if not self.spec_algorithm.is_none(): spec_verify_ct.append(req.spec_verify_ct) - spec_accepted_drafts.append(req.spec_accepted_drafts) - spec_acceptance_histogram.append(req.spec_acceptance_histogram) + spec_num_correct_drafts.append(req.spec_num_correct_drafts) + spec_correct_drafts_histogram.append( + req.spec_correct_drafts_histogram + ) if return_logprob: if ( @@ -1265,8 +1267,8 @@ def stream_output_generation( rids=rids, http_worker_ipcs=http_worker_ipcs, spec_verify_ct=spec_verify_ct, - spec_accepted_drafts=spec_accepted_drafts, - spec_acceptance_histogram=spec_acceptance_histogram, + spec_num_correct_drafts=spec_num_correct_drafts, + spec_correct_drafts_histogram=spec_correct_drafts_histogram, time_stats=time_stats, finished_reasons=finished_reasons, decoded_texts=decoded_texts, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 0c368367bdf9..fc5fb4c6e91f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -2107,37 +2107,37 @@ def _calculate_spec_decoding_metrics( if ( hasattr(recv_obj, "spec_verify_ct") and recv_obj.spec_verify_ct[i] > 0 - and hasattr(recv_obj, "spec_accepted_drafts") - and len(recv_obj.spec_accepted_drafts) > i + and hasattr(recv_obj, "spec_num_correct_drafts") + and len(recv_obj.spec_num_correct_drafts) > i ): # Total number of proposed draft tokens per request. - all_drafts = recv_obj.spec_verify_ct[i] * ( + num_proposed_drafts = recv_obj.spec_verify_ct[i] * ( self.server_args.speculative_num_draft_tokens - 1 ) - accepted_drafts = recv_obj.spec_accepted_drafts[i] + num_correct_drafts = recv_obj.spec_num_correct_drafts[i] # Calculate per-request acceptance rate and average acceptance length. - if all_drafts > 0: - # accept_rate: accepted_drafts / total_proposed_drafts (strict count, no bonus). - meta_info["spec_accept_rate"] = accepted_drafts / all_drafts + if num_proposed_drafts > 0: + # accept_rate: num_correct_drafts / num_proposed_drafts (strict count, no bonus). + meta_info["spec_accept_rate"] = num_correct_drafts / num_proposed_drafts # accept_length: completion_tokens / verify_ct (includes bonus token). meta_info["spec_accept_length"] = ( recv_obj.completion_tokens[i] / recv_obj.spec_verify_ct[i] ) - meta_info["spec_accepted_drafts"] = accepted_drafts - meta_info["spec_proposed_drafts"] = all_drafts + meta_info["spec_accepted_drafts"] = num_correct_drafts + meta_info["spec_proposed_drafts"] = num_proposed_drafts meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] # Acceptance histogram: tracks how many decoding steps accepted a certain number of draft tokens. if ( - recv_obj.spec_acceptance_histogram - and len(recv_obj.spec_acceptance_histogram) > i - and recv_obj.spec_acceptance_histogram[i] + recv_obj.spec_correct_drafts_histogram + and len(recv_obj.spec_correct_drafts_histogram) > i + and recv_obj.spec_correct_drafts_histogram[i] ): - meta_info["spec_accept_histogram"] = recv_obj.spec_acceptance_histogram[ - i - ] + meta_info["spec_accept_histogram"] = ( + recv_obj.spec_correct_drafts_histogram[i] + ) def _request_has_grammar(self, obj: GenerateReqInput) -> bool: return ( diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py index b404f15aebf5..6ac057572f57 100644 --- a/python/sglang/srt/managers/utils.py +++ b/python/sglang/srt/managers/utils.py @@ -27,8 +27,8 @@ class GenerationBatchResult: logits_output: Optional[LogitsProcessorOutput] = None pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None next_token_ids: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None - num_accepted_drafts: int = 0 # no bonus included - num_accepted_drafts_per_req_cpu: Optional[List[int]] = None + num_correct_drafts: int = 0 # no bonus included + num_correct_drafts_per_req_cpu: Optional[List[int]] = None can_run_cuda_graph: bool = False # For output processing diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 0b7e3b14fd4f..1671c5e42846 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -1000,12 +1000,12 @@ def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs): spec_info.topk_index = self._pad_tensor_to_size( spec_info.topk_index, bs ) - if getattr(spec_info, "num_accepted_drafts", None) is not None: - spec_info.num_accepted_drafts = self._pad_tensor_to_size( - spec_info.num_accepted_drafts, bs + if getattr(spec_info, "num_correct_drafts", None) is not None: + spec_info.num_correct_drafts = self._pad_tensor_to_size( + spec_info.num_correct_drafts, bs ) - spec_info.num_accepted_tokens = self._pad_tensor_to_size( - spec_info.num_accepted_tokens, bs + spec_info.num_accept_tokens = self._pad_tensor_to_size( + spec_info.num_accept_tokens, bs ) spec_info.hidden_states = self._pad_tensor_to_size( spec_info.hidden_states, num_tokens @@ -1049,12 +1049,10 @@ def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput): ] logits_output.hidden_states = logits_output.hidden_states[:num_tokens] elif self.forward_mode.is_draft_extend(): # draft extend - self.spec_info.num_accepted_drafts = self.spec_info.num_accepted_drafts[ - :bs - ] - self.spec_info.num_accepted_tokens = self.spec_info.num_accepted_tokens[ + self.spec_info.num_correct_drafts = self.spec_info.num_correct_drafts[ :bs ] + self.spec_info.num_accept_tokens = self.spec_info.num_accept_tokens[:bs] logits_output.next_token_logits = logits_output.next_token_logits[:bs] logits_output.hidden_states = logits_output.hidden_states[:bs] elif self.forward_mode.is_draft_extend_v2(): # draft extend_v2 diff --git a/python/sglang/srt/observability/scheduler_metrics_mixin.py b/python/sglang/srt/observability/scheduler_metrics_mixin.py index 5703434dd5ec..5c04f0d51daf 100644 --- a/python/sglang/srt/observability/scheduler_metrics_mixin.py +++ b/python/sglang/srt/observability/scheduler_metrics_mixin.py @@ -105,11 +105,11 @@ def init_metrics( }.get(getattr(self, "device", ""), "cuda graph") # Cumulative spec-decoding counters (reset every decode_log_interval). - # Each update adds (num_accepted_drafts + bs, bs). + # Each update adds (num_correct_drafts + bs, bs). # `*_accepted_tokens` = drafts + bonus; `*_accepted_drafts` = drafts-only. - self.spec_num_accepted_tokens = 0 # per-log-interval + self.spec_num_accept_tokens = 0 # per-log-interval self.spec_num_forward_ct = 0 - self.spec_total_num_accepted_tokens = 0 # lifetime + self.spec_total_num_accept_tokens = 0 # lifetime self.spec_total_num_forward_ct = 0 # For PD disaggregation @@ -202,12 +202,12 @@ def init_kv_events(self: Scheduler, kv_events_config: Optional[str]): kv_events_config, self.attn_dp_rank ) - def update_spec_metrics(self: Scheduler, bs: int, num_accepted_drafts: int): - self.spec_num_accepted_tokens += num_accepted_drafts + bs + def update_spec_metrics(self: Scheduler, bs: int, num_correct_drafts: int): + self.spec_num_accept_tokens += num_correct_drafts + bs self.spec_num_forward_ct += bs # Bonus tokens updated elsewhere - self.num_generated_tokens += num_accepted_drafts + self.num_generated_tokens += num_correct_drafts def _init_estimated_perf_constants(self: Scheduler) -> None: model_config = self.model_config @@ -345,9 +345,9 @@ def _estimate_decode_perf( def reset_metrics(self: Scheduler): self.forward_ct_decode = 0 self.num_generated_tokens = 0 - self.spec_num_accepted_tokens = 0 + self.spec_num_accept_tokens = 0 self.spec_num_forward_ct = 0 - self.spec_total_num_accepted_tokens = 0 + self.spec_total_num_accept_tokens = 0 self.spec_total_num_forward_ct = 0 def report_prefill_stats( @@ -487,13 +487,13 @@ def report_decode_stats( self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None, - num_accepted_drafts: int = 0, + num_correct_drafts: int = 0, ): batch = running_batch or self.running_batch # Every-iteration work: realtime token counting + status logger if self.current_scheduler_metrics_enabled: - decode_tokens = batch.batch_size() + num_accepted_drafts + decode_tokens = batch.batch_size() + num_correct_drafts self.metrics_collector.increment_realtime_tokens( # TODO unify this w/ the bumping logic in `Scheduler.num_generated_tokens` accumulator decode_tokens=decode_tokens, @@ -551,25 +551,19 @@ def report_decode_stats( spec_accept_length = 0 spec_accept_rate = 0 else: - spec_accept_length = ( - self.spec_num_accepted_tokens / self.spec_num_forward_ct - ) - num_accepted_drafts = ( - self.spec_num_accepted_tokens - self.spec_num_forward_ct - ) + spec_accept_length = self.spec_num_accept_tokens / self.spec_num_forward_ct + num_correct_drafts = self.spec_num_accept_tokens - self.spec_num_forward_ct if self.server_args.speculative_num_draft_tokens: draft_per_round = self.server_args.speculative_num_draft_tokens - 1 else: draft_per_round = self.server_args.speculative_num_steps or 0 total_draft_tokens = self.spec_num_forward_ct * draft_per_round spec_accept_rate = ( - num_accepted_drafts / total_draft_tokens - if total_draft_tokens > 0 - else 0 + num_correct_drafts / total_draft_tokens if total_draft_tokens > 0 else 0 ) - self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens + self.spec_total_num_accept_tokens += self.spec_num_accept_tokens self.spec_total_num_forward_ct += self.spec_num_forward_ct - self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0 + self.spec_num_accept_tokens = self.spec_num_forward_ct = 0 msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, " cache_hit_rate = 0.0 @@ -870,7 +864,7 @@ def get_loads(self: Scheduler, req: GetLoadsReqInput = None) -> GetLoadsReqOutpu if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0: speculative = SpeculativeMetrics( accept_length=( - self.spec_total_num_accepted_tokens + self.spec_total_num_accept_tokens / self.spec_total_num_forward_ct ), accept_rate=self.stats.spec_accept_rate, diff --git a/python/sglang/srt/speculative/adaptive_runtime_state.py b/python/sglang/srt/speculative/adaptive_runtime_state.py index fc469797b84b..4653fafb6295 100644 --- a/python/sglang/srt/speculative/adaptive_runtime_state.py +++ b/python/sglang/srt/speculative/adaptive_runtime_state.py @@ -71,7 +71,7 @@ class AdaptiveController: The worker only needs to: 1. Call ``register()`` for the initial state, then ``init_states()`` once during startup. - 2. Call ``on_verify_complete(num_accepted_drafts_per_req)`` after each decode verify. + 2. Call ``on_verify_complete(num_correct_drafts_per_req)`` after each decode verify. """ def __init__(self, worker: AdaptiveSpecWorker, config_path: str | None = None): @@ -107,9 +107,9 @@ def init_states(self) -> None: self._states[steps] = state self._activate(self.params.current_steps) - def on_verify_complete(self, num_accepted_drafts_per_req: list[int]) -> None: + def on_verify_complete(self, num_correct_drafts_per_req: list[int]) -> None: """Feed verify results; switch runtime state if EMA warrants it.""" - if self.params.update(num_accepted_drafts_per_req): + if self.params.update(num_correct_drafts_per_req): self._activate(self.params.current_steps) def _activate(self, speculative_num_steps: int) -> None: diff --git a/python/sglang/srt/speculative/adaptive_spec_params.py b/python/sglang/srt/speculative/adaptive_spec_params.py index e7bbb1862724..37bae4e4d688 100644 --- a/python/sglang/srt/speculative/adaptive_spec_params.py +++ b/python/sglang/srt/speculative/adaptive_spec_params.py @@ -132,16 +132,16 @@ def __init__( f"steps={self.current_steps}, candidate_steps={self.candidate_steps}", ) - def update(self, num_accepted_drafts_per_req: list[int]) -> bool: + def update(self, num_correct_drafts_per_req: list[int]) -> bool: """Update EMA with observed accept lengths. Returns True if params changed. Args: - num_accepted_drafts_per_req: Per-request accepted draft token counts from last verify. + num_correct_drafts_per_req: Per-request accepted draft token counts from last verify. """ - if not num_accepted_drafts_per_req: + if not num_correct_drafts_per_req: return False - batch_avg = sum(num_accepted_drafts_per_req) / len(num_accepted_drafts_per_req) + batch_avg = sum(num_correct_drafts_per_req) / len(num_correct_drafts_per_req) self.ema_accept_len = ( 1 - self.ema_alpha ) * self.ema_accept_len + self.ema_alpha * batch_avg diff --git a/python/sglang/srt/speculative/base_spec_worker.py b/python/sglang/srt/speculative/base_spec_worker.py index 566e723e3c67..9faee6b0eed5 100644 --- a/python/sglang/srt/speculative/base_spec_worker.py +++ b/python/sglang/srt/speculative/base_spec_worker.py @@ -33,7 +33,7 @@ def clear_cache_pool(self): # TODO: move this abstract method to BaseTpWorker and call through self.model_runner pass - def on_verify_complete_cpu(self, num_accepted_drafts_per_req: list[int]) -> None: + def on_verify_complete_cpu(self, num_correct_drafts_per_req: list[int]) -> None: """Hook called after verify finishes and accept counts are on CPU. Default no-op. Adaptive-aware workers override this to feed the diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 9cbba1faa61d..a87f74c10cdb 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -16,8 +16,8 @@ ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.dflash_utils import ( - compute_dflash_accept_len_and_bonus, - compute_dflash_sampling_accept_len_and_bonus, + compute_dflash_correct_drafts_and_bonus, + compute_dflash_sampling_correct_drafts_and_bonus, is_dflash_sampling_verify_available, ) from sglang.srt.speculative.spec_info import SpecInput, SpecInputType @@ -323,7 +323,7 @@ def verify( new_bonus_tokens: int64 tensor [bs] (the new current token per request) commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) next_target_hidden: tensor [sum(commit_lens), feature_dim] - num_accepted_drafts_per_req_cpu: list[int] (accepted draft tokens per request) + num_correct_drafts_per_req_cpu: list[int] (accepted draft tokens per request) """ if batch.forward_mode.is_idle(): empty = torch.empty((0,), dtype=torch.int64, device=batch.device) @@ -368,7 +368,7 @@ def verify( and not sampling_info.is_all_greedy and is_dflash_sampling_verify_available() ): - accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + accept_len, bonus = compute_dflash_sampling_correct_drafts_and_bonus( candidates=candidates, next_token_logits=logits_output.next_token_logits, sampling_info=sampling_info, @@ -377,7 +377,7 @@ def verify( target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( bs, self.draft_token_num ) - accept_len, bonus = compute_dflash_accept_len_and_bonus( + accept_len, bonus = compute_dflash_correct_drafts_and_bonus( candidates=candidates, target_predict=target_predict, ) @@ -388,7 +388,7 @@ def verify( ).cpu() max_acc = self.draft_token_num - 1 - num_accepted_drafts_per_req_cpu: List[int] = [] + num_correct_drafts_per_req_cpu: List[int] = [] commit_lens_cpu: List[int] = [] new_bonus_tokens_list: List[int] = [] @@ -421,9 +421,9 @@ def verify( commit_lens_cpu.append(appended) new_bonus_tokens_list.append(new_bonus_token) - num_accepted_drafts_per_req_cpu.append(max(0, appended - 1)) + num_correct_drafts_per_req_cpu.append(max(0, appended - 1)) req.spec_verify_ct += 1 - req.spec_accepted_drafts += num_accepted_drafts_per_req_cpu[-1] + req.spec_num_correct_drafts += num_correct_drafts_per_req_cpu[-1] commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) new_bonus_tokens = torch.tensor( @@ -498,5 +498,5 @@ def verify( new_bonus_tokens, commit_lens, next_target_hidden, - num_accepted_drafts_per_req_cpu, + num_correct_drafts_per_req_cpu, ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 2d7963532654..ff997ce0973d 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -418,7 +418,7 @@ def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: return True, "" -def compute_dflash_accept_len_and_bonus( +def compute_dflash_correct_drafts_and_bonus( *, candidates: torch.Tensor, target_predict: torch.Tensor, @@ -459,7 +459,7 @@ def compute_dflash_accept_len_and_bonus( return accept_len, bonus.to(torch.int64) -def compute_dflash_sampling_accept_len_and_bonus( +def compute_dflash_sampling_correct_drafts_and_bonus( *, candidates: torch.Tensor, next_token_logits: torch.Tensor, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8d34db1748a5..57f549dd3e23 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1080,7 +1080,7 @@ def _update_target_mamba_state_after_verify( if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): return - accepted_steps = commit_lens.to(torch.int64) - 1 + accept_steps = commit_lens.to(torch.int64) - 1 mamba_steps_to_track = None if batch.mamba_track_indices is not None: @@ -1103,7 +1103,7 @@ def _update_target_mamba_state_after_verify( ) attn_backend.update_mamba_state_after_mtp_verify( - accepted_steps=accepted_steps, + accept_steps=accept_steps, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, @@ -1178,7 +1178,7 @@ def _to_int32_device_tensor(x, *, device=device): return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, - num_accepted_drafts=0, + num_correct_drafts=0, can_run_cuda_graph=batch_result.can_run_cuda_graph, ) @@ -1216,7 +1216,7 @@ def _to_int32_device_tensor(x, *, device=device): new_bonus_tokens, commit_lens, next_target_hidden, - num_accepted_drafts_per_req_cpu, + num_correct_drafts_per_req_cpu, ) = verify_input.verify( batch=batch, logits_output=logits_output, @@ -1239,18 +1239,18 @@ def _to_int32_device_tensor(x, *, device=device): batch.spec_info = draft_input batch.forward_mode = ForwardMode.DECODE - num_accepted_drafts = sum(num_accepted_drafts_per_req_cpu) + num_correct_drafts = sum(num_correct_drafts_per_req_cpu) if not self._logged_first_verify and self.tp_rank == 0: logger.info( - "DFLASH verify completed. num_accepted_drafts_per_req=%s", - num_accepted_drafts_per_req_cpu, + "DFLASH verify completed. num_correct_drafts_per_req=%s", + num_correct_drafts_per_req_cpu, ) self._logged_first_verify = True return GenerationBatchResult( logits_output=logits_output, next_token_ids=new_bonus_tokens, - num_accepted_drafts=num_accepted_drafts, - num_accepted_drafts_per_req_cpu=num_accepted_drafts_per_req_cpu, + num_correct_drafts=num_correct_drafts, + num_correct_drafts_per_req_cpu=num_correct_drafts_per_req_cpu, can_run_cuda_graph=can_run_cuda_graph, ) diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index ca3d65a4a0a8..5a22d3260243 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -50,8 +50,8 @@ class EagleDraftExtendInputBuffers(ForwardInputBuffers): seq_lens: torch.Tensor seq_lens_cpu: torch.Tensor extend_seq_lens: torch.Tensor - num_accepted_drafts: torch.Tensor - num_accepted_tokens: torch.Tensor + num_correct_drafts: torch.Tensor + num_accept_tokens: torch.Tensor next_token_logits_buffer: torch.Tensor global_num_tokens_gpu: Optional[torch.Tensor] global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] @@ -148,10 +148,10 @@ def __init__( extend_seq_lens = torch.full( (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32 ) - num_accepted_drafts = torch.full( + num_correct_drafts = torch.full( (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32 ) - num_accepted_tokens = torch.full( + num_accept_tokens = torch.full( (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32 ) @@ -206,8 +206,8 @@ def __init__( seq_lens=seq_lens, seq_lens_cpu=seq_lens_cpu, extend_seq_lens=extend_seq_lens, - num_accepted_drafts=num_accepted_drafts, - num_accepted_tokens=num_accepted_tokens, + num_correct_drafts=num_correct_drafts, + num_accept_tokens=num_accept_tokens, next_token_logits_buffer=next_token_logits_buffer, global_num_tokens_gpu=global_num_tokens_gpu, global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu, @@ -293,8 +293,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0 positions = buffers.positions[:num_tokens] mrope_positions = buffers.mrope_positions[:, :num_tokens] hidden_states = buffers.hidden_states[:num_tokens] - num_accepted_drafts = buffers.num_accepted_drafts[:bs] - num_accepted_tokens = buffers.num_accepted_tokens[:bs] + num_correct_drafts = buffers.num_correct_drafts[:bs] + num_accept_tokens = buffers.num_accept_tokens[:bs] next_token_logits_buffer = buffers.next_token_logits_buffer[ : bs if self.forward_mode == ForwardMode.DRAFT_EXTEND else num_tokens ] @@ -342,8 +342,8 @@ def capture_one_batch_size(self, bs: int, forward: Callable, stream_idx: int = 0 spec_info = EagleDraftExtendInput( hidden_states=hidden_states, - num_accepted_drafts=num_accepted_drafts, - num_accepted_tokens=num_accepted_tokens, + num_correct_drafts=num_correct_drafts, + num_accept_tokens=num_accept_tokens, ) self.deepep_adapter.capture(is_extend_in_batch=True) @@ -448,8 +448,8 @@ def replay(self, forward_batch: ForwardBatch): buffers.seq_lens.fill_(self.seq_len_fill_value) buffers.out_cache_loc.zero_() buffers.positions.zero_() - buffers.num_accepted_drafts.fill_(self.num_tokens_per_bs) - buffers.num_accepted_tokens.fill_(self.num_tokens_per_bs) + buffers.num_correct_drafts.fill_(self.num_tokens_per_bs) + buffers.num_accept_tokens.fill_(self.num_tokens_per_bs) buffers.extend_seq_lens.fill_(self.num_tokens_per_bs) # Common inputs @@ -468,12 +468,12 @@ def replay(self, forward_batch: ForwardBatch): buffers.hidden_states[:num_tokens].copy_( forward_batch.spec_info.hidden_states ) - if forward_batch.spec_info.num_accepted_drafts is not None: - buffers.num_accepted_drafts[:raw_bs].copy_( - forward_batch.spec_info.num_accepted_drafts + if forward_batch.spec_info.num_correct_drafts is not None: + buffers.num_correct_drafts[:raw_bs].copy_( + forward_batch.spec_info.num_correct_drafts ) - buffers.num_accepted_tokens[:raw_bs].copy_( - forward_batch.spec_info.num_accepted_tokens + buffers.num_accept_tokens[:raw_bs].copy_( + forward_batch.spec_info.num_accept_tokens ) buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) @@ -508,12 +508,8 @@ def replay(self, forward_batch: ForwardBatch): if bs != raw_bs: forward_batch.spec_info.positions = buffers.positions[:num_tokens] - forward_batch.spec_info.num_accepted_drafts = buffers.num_accepted_drafts[ - :bs - ] - forward_batch.spec_info.num_accepted_tokens = buffers.num_accepted_tokens[ - :bs - ] + forward_batch.spec_info.num_correct_drafts = buffers.num_correct_drafts[:bs] + forward_batch.spec_info.num_accept_tokens = buffers.num_accept_tokens[:bs] self.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph( bs=bs, @@ -537,10 +533,10 @@ def replay(self, forward_batch: ForwardBatch): # DRAFT_EXTEND_V2: all tokens calculations whether accepted or not. unpadding_bs = num_tokens elif bs != raw_bs: - forward_batch.spec_info.num_accepted_drafts = buffers.num_accepted_drafts[ + forward_batch.spec_info.num_correct_drafts = buffers.num_correct_drafts[ :raw_bs ] - forward_batch.spec_info.num_accepted_tokens = buffers.num_accepted_tokens[ + forward_batch.spec_info.num_accept_tokens = buffers.num_accept_tokens[ :raw_bs ] unpadding_bs = raw_bs diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index 14c17c9eae65..0cbe5448352e 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -38,7 +38,7 @@ align_evict_mask_to_page_size, assign_req_to_token_pool_func, create_extend_after_decode_spec_info, - create_num_accepted_drafts_filter, + create_num_accept_tokens_filter, filter_finished_cache_loc_kernel, generate_simulated_accept_index, get_src_tgt_cache_loc, @@ -275,7 +275,7 @@ def verify( accept_index = torch.full( (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=batch.device ) - num_accepted_drafts = torch.empty((bs,), dtype=torch.int32, device=batch.device) + num_correct_drafts = torch.empty((bs,), dtype=torch.int32, device=batch.device) if bs != len(sampling_info): sampling_info = copy.deepcopy(sampling_info) @@ -326,10 +326,10 @@ def verify( if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE: target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) target_predict = target_predict.reshape(bs, self.draft_token_num) - predict, accept_index, num_accepted_drafts = verify_tree_greedy_func( + predict, accept_index, num_correct_drafts = verify_tree_greedy_func( predicts=predict, # mutable accept_index=accept_index, # mutable - accept_token_num=num_accepted_drafts, # mutable + accept_token_num=num_correct_drafts, # mutable candidates=candidates, retrieve_index=self.retrieve_index, retrieve_next_token=self.retrieve_next_token, @@ -377,7 +377,7 @@ def verify( tree_speculative_sampling_target_only( predicts=predict, # mutable accept_index=accept_index, # mutable - accept_token_num=num_accepted_drafts, # mutable + accept_token_num=num_correct_drafts, # mutable candidates=candidates, # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. retrive_index=self.retrieve_index, @@ -404,14 +404,14 @@ def verify( if tp_group.world_size > 1: tp_group.broadcast(predict, src=0) tp_group.broadcast(accept_index, src=0) - tp_group.broadcast(num_accepted_drafts, src=0) + tp_group.broadcast(num_correct_drafts, src=0) if SIMULATE_ACC_LEN > 0.0: # Do simulation accept_index = generate_simulated_accept_index( accept_index=accept_index, predict=predict, # mutable - num_accepted_drafts=num_accepted_drafts, # mutable + num_correct_drafts=num_correct_drafts, # mutable bs=bs, spec_steps=self.spec_steps, ) @@ -460,12 +460,14 @@ def verify( else: unfinished_accept_index.append(accept_index[i]) req.spec_verify_ct += 1 - accepted_draft_tokens = sum(1 for idx in accept_index_row if idx != -1) - 1 - req.spec_accepted_drafts += accepted_draft_tokens - req.update_spec_acceptance_histogram(accepted_draft_tokens) + num_correct_drafts_this_req = ( + sum(1 for idx in accept_index_row if idx != -1) - 1 + ) + req.spec_num_correct_drafts += num_correct_drafts_this_req + req.update_spec_correct_drafts_histogram(num_correct_drafts_this_req) if has_finished: - num_accepted_drafts = (accept_index != -1).sum(dim=1) - 1 + num_correct_drafts = (accept_index != -1).sum(dim=1) - 1 # Free the KV cache for unaccepted tokens # TODO: fuse them @@ -473,12 +475,12 @@ def verify( accept_tokens = predict[accept_index] evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False - num_accepted_drafts_cpu = num_accepted_drafts.cpu() - num_accepted_tokens_cpu = num_accepted_drafts_cpu + 1 + num_correct_drafts_cpu = num_correct_drafts.cpu() + num_accept_tokens_cpu = num_correct_drafts_cpu + 1 # FIXME: this `tolist()` fixes the numerical calculation consistency # try to unify the tensor representation and list representation - num_accepted_drafts_list = num_accepted_drafts_cpu.tolist() - num_accepted_tokens_list = num_accepted_tokens_cpu.tolist() + num_correct_drafts_list = num_correct_drafts_cpu.tolist() + num_accept_tokens_list = num_accept_tokens_cpu.tolist() if page_size == 1: # TODO: boolean array index leads to a device sync. Remove it. @@ -501,7 +503,7 @@ def verify( batch.seq_lens, batch.out_cache_loc, accept_index, - num_accepted_drafts, + num_correct_drafts, self.draft_token_num, page_size, ) @@ -518,12 +520,12 @@ def verify( # to_free_slots also needs to be page-aligned without the first partial page # # split each row of out_cache_loc into two parts. - # 1. the first part goes to tgt_cache_loc. length = num_accepted_drafts[i] + 1 + # 1. the first part goes to tgt_cache_loc. length = num_correct_drafts[i] + 1 # 2. the second part goes to to_free_slots. get_target_cache_loc[(bs,)]( tgt_cache_loc, to_free_slots, - num_accepted_drafts, + num_correct_drafts, to_free_num_slots, batch.out_cache_loc, self.draft_token_num, @@ -547,20 +549,20 @@ def verify( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, - batch.seq_lens + num_accepted_drafts + 1, + batch.seq_lens + num_correct_drafts + 1, batch.out_cache_loc, bs, ) else: batch.out_cache_loc = tgt_cache_loc - batch.seq_lens.add_(num_accepted_drafts + 1) - batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) + batch.seq_lens.add_(num_correct_drafts + 1) + batch.seq_lens_cpu.add_(num_accept_tokens_cpu) draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[accept_index], - num_accepted_drafts=num_accepted_drafts, - num_accepted_tokens=num_accepted_drafts + 1, - num_accepted_tokens_cpu=num_accepted_tokens_list, + num_correct_drafts=num_correct_drafts, + num_accept_tokens=num_correct_drafts + 1, + num_accept_tokens_cpu=num_accept_tokens_list, input_ids=accept_tokens, seq_lens=batch.seq_lens, seq_lens_cpu=batch.seq_lens_cpu, @@ -571,7 +573,7 @@ def verify( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - num_accepted_drafts_per_req_cpu=num_accepted_drafts_list, + num_correct_drafts_per_req_cpu=num_correct_drafts_list, accepted_indices=accept_index, ) else: @@ -580,58 +582,57 @@ def verify( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, - batch.seq_lens + num_accepted_drafts + 1, + batch.seq_lens + num_correct_drafts + 1, batch.out_cache_loc[accept_index], bs, ) - batch.seq_lens.add_(num_accepted_drafts + 1) - batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) + batch.seq_lens.add_(num_correct_drafts + 1) + batch.seq_lens_cpu.add_(num_accept_tokens_cpu) if len(unfinished_accept_index) > 0: unfinished_accept_index = torch.cat(unfinished_accept_index) unfinished_index_device = torch.tensor( unfinished_index, dtype=torch.int64, device=predict.device ) - draft_input_num_accepted_drafts_cpu = [ - num_accepted_drafts_list[i] for i in unfinished_index + draft_input_num_correct_drafts_cpu = [ + num_correct_drafts_list[i] for i in unfinished_index ] - draft_input_num_accepted_tokens_cpu = [ - num_accepted_tokens_list[i] for i in unfinished_index + draft_input_num_accept_tokens_cpu = [ + num_accept_tokens_list[i] for i in unfinished_index ] if page_size == 1 or self.topk == 1: batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index] else: batch.out_cache_loc = torch.empty( - len(unfinished_index) - + sum(draft_input_num_accepted_drafts_cpu), + len(unfinished_index) + sum(draft_input_num_correct_drafts_cpu), dtype=torch.int64, device=predict.device, ) - num_accepted_drafts_filter = create_num_accepted_drafts_filter( - num_accepted_drafts, + num_accept_tokens_filter = create_num_accept_tokens_filter( + num_correct_drafts, unfinished_index_device, batch.seq_lens, ) - batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) + batch.seq_lens_cpu.add_(num_accept_tokens_cpu) filter_finished_cache_loc_kernel[(bs,)]( batch.out_cache_loc, tgt_cache_loc, - num_accepted_drafts, - num_accepted_drafts_filter, + num_correct_drafts, + num_accept_tokens_filter, next_power_of_2(bs), next_power_of_2(self.draft_token_num), ) - unfinished_num_accepted_drafts = num_accepted_drafts[ + unfinished_num_correct_drafts = num_correct_drafts[ unfinished_index_device ] draft_extend_input = EagleDraftExtendInput( hidden_states=batch.spec_info.hidden_states[ unfinished_accept_index ], - num_accepted_tokens_cpu=draft_input_num_accepted_tokens_cpu, - num_accepted_drafts=unfinished_num_accepted_drafts, - num_accepted_tokens=unfinished_num_accepted_drafts + 1, + num_accept_tokens_cpu=draft_input_num_accept_tokens_cpu, + num_correct_drafts=unfinished_num_correct_drafts, + num_accept_tokens=unfinished_num_correct_drafts + 1, input_ids=predict[unfinished_accept_index], seq_lens=batch.seq_lens[unfinished_index_device], seq_lens_cpu=batch.seq_lens_cpu[unfinished_index], @@ -649,7 +650,7 @@ def verify( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=accept_tokens, - num_accepted_drafts_per_req_cpu=num_accepted_drafts_list, + num_correct_drafts_per_req_cpu=num_correct_drafts_list, accepted_indices=accept_index, ) @@ -685,8 +686,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin): verify_done: Optional[torch.cuda.Event] = None # V2 reuses `EagleDraftInput` across phases (V1 has a separate # `EagleDraftExtendInput` for these). Set during V2's draft-extend. - num_accepted_drafts: Optional[torch.Tensor] = None - num_accepted_tokens: Optional[torch.Tensor] = None + num_correct_drafts: Optional[torch.Tensor] = None + num_accept_tokens: Optional[torch.Tensor] = None def __post_init__(self): super().__init__(SpecInputType.EAGLE_DRAFT) @@ -807,13 +808,13 @@ class EagleDraftExtendInput(SpecInput): # by accept_index; consumed by the draft-extend forward. hidden_states: torch.Tensor = None - # Per-req accept counts. `num_accepted_tokens = num_accepted_drafts + 1`. + # Per-req accept counts. `num_accept_tokens = num_correct_drafts + 1`. # Both kept for cuda-graph buffer indexing and the # `create_extend_after_decode_spec_info` kernel. - num_accepted_drafts: torch.Tensor = None - num_accepted_tokens: torch.Tensor = None + num_correct_drafts: torch.Tensor = None + num_accept_tokens: torch.Tensor = None # CPU view, read by attention backends during the extend forward. - num_accepted_tokens_cpu: List[int] = None + num_accept_tokens_cpu: List[int] = None # Batch-state slices for the draft-extend forward. Set by verify (sliced to # reqs continuing into next iter). `prepare_extend_after_decode` copies @@ -871,9 +872,9 @@ def create_idle_input( ) -> "EagleDraftExtendInput": return cls( hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), - num_accepted_drafts=torch.empty((0,), device=device, dtype=torch.int32), - num_accepted_tokens=torch.empty((0,), device=device, dtype=torch.int32), - num_accepted_tokens_cpu=[], + num_correct_drafts=torch.empty((0,), device=device, dtype=torch.int32), + num_accept_tokens=torch.empty((0,), device=device, dtype=torch.int32), + num_accept_tokens_cpu=[], input_ids=torch.empty((0,), device=device, dtype=torch.long), seq_lens=torch.empty((0,), device=device, dtype=torch.int32), seq_lens_cpu=torch.empty((0,), dtype=torch.int32), @@ -895,7 +896,7 @@ def prepare_extend_after_decode( # the worker reads `self.bonus_tokens` to construct next iter's # `EagleDraftInput`. batch.input_ids = self.input_ids - batch.extend_lens = self.num_accepted_tokens_cpu + batch.extend_lens = self.num_accept_tokens_cpu batch.extend_num_tokens = sum(batch.extend_lens) batch.seq_lens = self.seq_lens batch.seq_lens_cpu = self.seq_lens_cpu @@ -905,14 +906,12 @@ def prepare_extend_after_decode( self.capture_hidden_mode = CaptureHiddenMode.LAST self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) - self.bonus_tokens = torch.empty_like( - self.num_accepted_tokens, dtype=torch.int32 - ) + self.bonus_tokens = torch.empty_like(self.num_accept_tokens, dtype=torch.int32) create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( batch.input_ids, batch.seq_lens, - self.num_accepted_tokens, + self.num_accept_tokens, self.positions, self.bonus_tokens, next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), @@ -926,9 +925,9 @@ def generate_attn_arg_prefill( req_to_token: torch.Tensor, ): device = req_pool_indices.device - bs = self.num_accepted_drafts.numel() + bs = self.num_correct_drafts.numel() qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=device) - qo_indptr[1:] = torch.cumsum(self.num_accepted_tokens, dim=0) + qo_indptr[1:] = torch.cumsum(self.num_accept_tokens, dim=0) cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) @@ -962,7 +961,7 @@ class EagleVerifyOutput: # step. Includes the bonus token. Used for output processing. accept_tokens: torch.Tensor # Accepted token length per sequence in a batch in CPU (full set). - num_accepted_drafts_per_req_cpu: List[int] + num_correct_drafts_per_req_cpu: List[int] # Accepted indices from logits_output.next_token_logits accepted_indices: torch.Tensor @@ -979,7 +978,7 @@ def create_idle( draft_extend_input=draft_extend_input, logits_output=logits_output, accept_tokens=torch.empty(0, dtype=torch.long, device=device), - num_accepted_drafts_per_req_cpu=[], + num_correct_drafts_per_req_cpu=[], accepted_indices=torch.full( (0, spec_steps + 1), -1, dtype=torch.int32, device=device ), diff --git a/python/sglang/srt/speculative/eagle_info_v2.py b/python/sglang/srt/speculative/eagle_info_v2.py index f88af9402c4e..fe5aa3214228 100644 --- a/python/sglang/srt/speculative/eagle_info_v2.py +++ b/python/sglang/srt/speculative/eagle_info_v2.py @@ -327,13 +327,13 @@ def sample( """ if batch.forward_mode.is_idle(): predict = torch.empty(0, dtype=torch.int32, device=batch.input_ids.device) - num_accepted_drafts = torch.empty( + num_correct_drafts = torch.empty( 0, dtype=torch.int32, device=batch.input_ids.device ) accept_index = torch.empty( 0, dtype=torch.int32, device=batch.input_ids.device ) - return predict, num_accepted_drafts, accept_index + return predict, num_correct_drafts, accept_index bs = len(batch.seq_lens) sampling_info = batch.sampling_info @@ -375,16 +375,16 @@ def sample( accept_index = torch.full( (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device ) - num_accepted_drafts = torch.empty((bs,), dtype=torch.int32, device=device) + num_correct_drafts = torch.empty((bs,), dtype=torch.int32, device=device) # Sample tokens if sampling_info.is_all_greedy or _is_npu or _is_hip: target_predict = torch.argmax(next_token_logits, dim=-1) target_predict = target_predict.reshape(bs, self.draft_token_num) - predict, accept_index, num_accepted_drafts = verify_tree_greedy_func( + predict, accept_index, num_correct_drafts = verify_tree_greedy_func( predicts=predict, # mutable accept_index=accept_index, # mutable - accept_token_num=num_accepted_drafts, # mutable + accept_token_num=num_correct_drafts, # mutable candidates=candidates, retrieve_index=self.retrieve_index, retrieve_next_token=self.retrieve_next_token, @@ -426,7 +426,7 @@ def sample( tree_speculative_sampling_target_only( predicts=predict, # mutable accept_index=accept_index, # mutable - accept_token_num=num_accepted_drafts, # mutable + accept_token_num=num_correct_drafts, # mutable candidates=candidates, # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. retrive_index=self.retrieve_index, @@ -453,23 +453,23 @@ def sample( if tp_group.world_size > 1: tp_group.broadcast(predict, src=0) tp_group.broadcast(accept_index, src=0) - tp_group.broadcast(num_accepted_drafts, src=0) + tp_group.broadcast(num_correct_drafts, src=0) if SIMULATE_ACC_LEN > 0: # Do simulation accept_index = generate_simulated_accept_index( accept_index=accept_index, predict=predict, # mutable - num_accepted_drafts=num_accepted_drafts, # mutable + num_correct_drafts=num_correct_drafts, # mutable simulate_acc_len=SIMULATE_ACC_LEN, bs=bs, spec_steps=self.spec_steps, ) - # `num_accepted_drafts` stays drafts-only inside this function; the returned + # `num_correct_drafts` stays drafts-only inside this function; the returned # tensor includes the trailing/bonus token via out-of-place +1 so the # name no longer flips semantics mid-function (naming doc C2). - return predict, num_accepted_drafts + 1, accept_index + return predict, num_correct_drafts + 1, accept_index @triton.jit diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 010831fd3f24..402ec3a9b2bf 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -472,7 +472,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, - num_accepted_drafts=0, + num_correct_drafts=0, can_run_cuda_graph=can_run_cuda_graph, ) else: @@ -491,7 +491,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): - accepted = verify_output.num_accepted_drafts_per_req_cpu[idx] + accepted = verify_output.num_correct_drafts_per_req_cpu[idx] req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) set_time_batch( @@ -526,14 +526,14 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul if self.adaptive_controller is not None: self.adaptive_controller.on_verify_complete( - verify_output.num_accepted_drafts_per_req_cpu + verify_output.num_correct_drafts_per_req_cpu ) return GenerationBatchResult( logits_output=logits_output, next_token_ids=verify_output.accept_tokens, - num_accepted_drafts=sum(verify_output.num_accepted_drafts_per_req_cpu), - num_accepted_drafts_per_req_cpu=verify_output.num_accepted_drafts_per_req_cpu, + num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu), + num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu, can_run_cuda_graph=can_run_cuda_graph, ) @@ -1003,24 +1003,24 @@ def _mamba_verify_update( if batch.forward_mode.is_idle(): return - accepted_length = ( + num_accept_tokens = ( torch.tensor( - res.num_accepted_drafts_per_req_cpu, + res.num_correct_drafts_per_req_cpu, device=logits_output.hidden_states.device, dtype=torch.int64, ) + 1 ) - cumulative_accepted_lengths = torch.cumsum(accepted_length, dim=0) - # prepend 0 to the cumulative_accepted_lengths + cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) + # prepend 0 to the cumulative_num_accept_tokens accepted_indices_start = torch.cat( [ torch.zeros( 1, - dtype=cumulative_accepted_lengths.dtype, - device=cumulative_accepted_lengths.device, + dtype=cumulative_num_accept_tokens.dtype, + device=cumulative_num_accept_tokens.device, ), - cumulative_accepted_lengths[:-1], + cumulative_num_accept_tokens[:-1], ] ) accepted_indices_offset = torch.arange( @@ -1034,17 +1034,17 @@ def _mamba_verify_update( # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask # res.accepted_indices.shape[0] > 0 skips DP attn idle batch if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0: - # accepted_indices=[0,2,3,4,5,7,9,10,11], accepted_length=[4, 3, 2], cumulative_accepted_lengths=[4, 7, 9] - # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_accepted_lengths[:-1]]) = [0, 5, 10] - # last_token_indices_per_req=accepted_indices[cumulative_accepted_lengths - 1] = [4, 9, 11] (last token ID of each req) - # max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches + # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] + # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] + # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) + # accept_steps = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches # first_token_indices_per_req = res.accepted_indices[accepted_indices_start] - accepted_steps = ( - res.accepted_indices[cumulative_accepted_lengths - 1] + accept_steps = ( + res.accepted_indices[cumulative_num_accept_tokens - 1] - accepted_indices_offset ) else: - accepted_steps = accepted_length - 1 + accept_steps = num_accept_tokens - 1 if batch.mamba_track_indices is not None: # If after verify, the request's seq_lens has crossed a mamba track interval, @@ -1068,7 +1068,7 @@ def _mamba_verify_update( mamba_steps_to_track = None self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - accepted_steps=accepted_steps, + accept_steps=accept_steps, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, diff --git a/python/sglang/srt/speculative/eagle_worker_v2.py b/python/sglang/srt/speculative/eagle_worker_v2.py index 27c9a14ed445..88f052aa0c2d 100644 --- a/python/sglang/srt/speculative/eagle_worker_v2.py +++ b/python/sglang/srt/speculative/eagle_worker_v2.py @@ -593,11 +593,11 @@ def _draft_extend_for_decode( self.plan_stream ) - if forward_batch.spec_info.num_accepted_drafts is None: + if forward_batch.spec_info.num_correct_drafts is None: # `batch_result.accept_lens` already includes the bonus token, so use it - # directly for `num_accepted_tokens` and subtract 1 for `num_accepted_drafts`. - forward_batch.spec_info.num_accepted_drafts = batch_result.accept_lens - 1 - forward_batch.spec_info.num_accepted_tokens = batch_result.accept_lens + # directly for `num_accept_tokens` and subtract 1 for `num_correct_drafts`. + forward_batch.spec_info.num_correct_drafts = batch_result.accept_lens - 1 + forward_batch.spec_info.num_accept_tokens = batch_result.accept_lens # Run draft extend batch in the main compute stream can_cuda_graph = ( @@ -793,9 +793,9 @@ def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): return batch_output - def on_verify_complete_cpu(self, accepted_draft_tokens: list[int]) -> None: + def on_verify_complete_cpu(self, num_correct_drafts_per_req: list[int]) -> None: if self.adaptive_controller is not None: - self.adaptive_controller.on_verify_complete(accepted_draft_tokens) + self.adaptive_controller.on_verify_complete(num_correct_drafts_per_req) # -- Adaptive speculative decoding protocol -- @@ -1097,7 +1097,7 @@ def _mamba_verify_update( ): """Update mamba state for hybrid GDN models after verification.""" # `accept_lens` already includes the bonus token (drafts + 1 per req). - accepted_length_with_bonus = accept_lens + num_accept_tokens = accept_lens if not batch.forward_mode.is_idle() and accept_index.numel() > 0: if verify_input.topk != 1: raise ValueError("Spec v2 currently only supports topk = 1.") @@ -1106,16 +1106,16 @@ def _mamba_verify_update( 0, bs * self.speculative_num_draft_tokens, step=self.speculative_num_draft_tokens, - dtype=accepted_length_with_bonus.dtype, - device=accepted_length_with_bonus.device, + dtype=num_accept_tokens.dtype, + device=num_accept_tokens.device, ) - accepted_steps = accepted_length_with_bonus - 1 + accept_steps = num_accept_tokens - 1 if batch.mamba_track_indices is not None: # If after verify, the request's seq_lens has crossed a mamba track interval, # we need to update the mamba state for the request at the crossing point. seq_lens_pre_verify = batch.seq_lens - seq_lens_post_verify = batch.seq_lens + accepted_length_with_bonus + seq_lens_post_verify = batch.seq_lens + num_accept_tokens mamba_track_interval = self.server_args.mamba_track_interval to_track_mask = ( seq_lens_pre_verify // mamba_track_interval @@ -1130,7 +1130,7 @@ def _mamba_verify_update( req_idx = torch.arange( bs, dtype=torch.int64, - device=accepted_length_with_bonus.device, + device=num_accept_tokens.device, ) candidate_track_steps = ( accept_index[req_idx, to_track_ith] - accepted_indices_offset @@ -1144,7 +1144,7 @@ def _mamba_verify_update( mamba_steps_to_track = None self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - accepted_steps=accepted_steps, + accept_steps=accept_steps, mamba_track_indices=batch.mamba_track_indices, mamba_steps_to_track=mamba_steps_to_track, model=self.target_worker.model_runner.model, @@ -1154,7 +1154,7 @@ def move_accepted_tokens_to_target_kvcache( self, batch: ModelWorkerBatch, accept_index: torch.Tensor, - num_accepted_drafts: torch.Tensor, + num_correct_drafts: torch.Tensor, ): """ Move accepted tokens to the target KV cache. @@ -1162,7 +1162,7 @@ def move_accepted_tokens_to_target_kvcache( Args: batch: The batch to run. accept_index: The index of the accepted tokens. - num_accepted_drafts: The length of the accepted tokens. + num_correct_drafts: The length of the accepted tokens. """ bs = len(batch.seq_lens) size = bs * self.speculative_num_draft_tokens @@ -1179,7 +1179,7 @@ def move_accepted_tokens_to_target_kvcache( batch.req_pool_indices, self.req_to_token_pool.req_to_token, batch.seq_lens, - batch.seq_lens + num_accepted_drafts, + batch.seq_lens + num_correct_drafts, tgt_cache_loc, self.req_to_token_pool.req_to_token.shape[1], next_power_of_2(bs), diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py index 74ff0ef7ee70..71ebc4c996be 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -137,7 +137,7 @@ def select_last_extend_hidden( def select_last_verified_seed( draft_input: FrozenKVMTPDraftExtendInput, ) -> Tuple[torch.Tensor, torch.Tensor]: - counts = draft_input.num_accepted_tokens.to(torch.long) + counts = draft_input.num_accept_tokens.to(torch.long) last_indices = torch.cumsum(counts, dim=0) - 1 return ( draft_input.bonus_tokens[last_indices], diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index d73da6fc5454..e2e477fb3581 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -435,7 +435,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, - num_accepted_drafts=0, + num_correct_drafts=0, can_run_cuda_graph=can_run_cuda_graph, ) @@ -452,7 +452,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): - accepted = verify_output.num_accepted_drafts_per_req_cpu[idx] + accepted = verify_output.num_correct_drafts_per_req_cpu[idx] req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) set_time_batch(batch.reqs, "set_spec_draft_extend_start_time", trace_only=True) @@ -473,8 +473,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=verify_output.accept_tokens, - num_accepted_drafts=sum(verify_output.num_accepted_drafts_per_req_cpu), - num_accepted_drafts_per_req_cpu=verify_output.num_accepted_drafts_per_req_cpu, + num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu), + num_correct_drafts_per_req_cpu=verify_output.num_correct_drafts_per_req_cpu, can_run_cuda_graph=can_run_cuda_graph, ) diff --git a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py index 18050f9e0c2c..f9b07ecb15aa 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_draft_extend_cuda_graph_runner.py @@ -72,8 +72,8 @@ class MultiLayerEagleDraftExtendInputBuffers(ForwardInputBuffers): seq_lens: torch.Tensor seq_lens_cpu: torch.Tensor req_pool_indices: torch.Tensor - num_accepted_drafts: torch.Tensor - num_accepted_tokens: torch.Tensor + num_correct_drafts: torch.Tensor + num_accept_tokens: torch.Tensor # Per-step buffers extend_seq_lens: torch.Tensor extend_start_loc: torch.Tensor @@ -160,8 +160,8 @@ def init_buffers_and_capture( # shared states seq_lens = cuda_graph_buffers["seq_lens"] req_pool_indices = cuda_graph_buffers["req_pool_indices"] - num_accepted_drafts = cuda_graph_buffers["num_accepted_drafts"] - num_accepted_tokens = cuda_graph_buffers["num_accepted_tokens"] + num_correct_drafts = cuda_graph_buffers["num_correct_drafts"] + num_accept_tokens = cuda_graph_buffers["num_accept_tokens"] extend_seq_lens = torch.full( (self.max_bs,), @@ -231,8 +231,8 @@ def init_buffers_and_capture( seq_lens=seq_lens, seq_lens_cpu=seq_lens_cpu, req_pool_indices=req_pool_indices, - num_accepted_drafts=num_accepted_drafts, - num_accepted_tokens=num_accepted_tokens, + num_correct_drafts=num_correct_drafts, + num_accept_tokens=num_accept_tokens, extend_seq_lens=extend_seq_lens, extend_start_loc=extend_start_loc, mrope_positions=mrope_positions, @@ -304,8 +304,8 @@ def get_forward_batch(self, bs: int) -> ForwardBatch: extend_seq_lens = buffers.extend_seq_lens[:bs] extend_seq_lens_cpu = self.extend_seq_lens_cpu[:bs] extend_start_loc = buffers.extend_start_loc[:bs] - num_accepted_drafts = buffers.num_accepted_drafts[:bs] - num_accepted_tokens = buffers.num_accepted_tokens[:bs] + num_correct_drafts = buffers.num_correct_drafts[:bs] + num_accept_tokens = buffers.num_accept_tokens[:bs] out_cache_loc = buffers.out_cache_loc[:num_tokens] positions = buffers.positions[:num_tokens] mrope_positions = buffers.mrope_positions[:, :num_tokens] @@ -351,8 +351,8 @@ def get_forward_batch(self, bs: int) -> ForwardBatch: spec_info = EagleDraftExtendInput( hidden_states=hidden_states, - num_accepted_drafts=num_accepted_drafts, - num_accepted_tokens=num_accepted_tokens, + num_correct_drafts=num_correct_drafts, + num_accept_tokens=num_accept_tokens, ) spec_info.positions = None @@ -444,12 +444,12 @@ def run_once(): ): buffers.hidden_states[:num_tokens].copy_(ret.hidden_states[:num_tokens]) - # num_accepted_drafts is drafts-only; the last accepted draft sits at index - # `num_accepted_drafts` within the (current_token + drafts) slot range. + # num_correct_drafts is drafts-only; the last accepted draft sits at index + # `num_correct_drafts` within the (current_token + drafts) slot range. select_index = ( torch.arange(bs, device=self.model_runner.device) * (self.speculative_num_draft_tokens + self.step) - + buffers.num_accepted_drafts[:bs] + + buffers.num_correct_drafts[:bs] + self.step ) @@ -462,7 +462,7 @@ def run_once(): # speculative_num_draft_tokens includes the current-token slot, so -1. padding_lens = ( self.speculative_num_draft_tokens - 1 - ) - buffers.num_accepted_drafts[:bs] + ) - buffers.num_correct_drafts[:bs] assign_new_state_triton( ret.topk_index, buffers.input_ids, @@ -523,12 +523,12 @@ def init_replay_state( buffers.hidden_states[:num_tokens].copy_( forward_batch.spec_info.hidden_states ) - if forward_batch.spec_info.num_accepted_drafts is not None: - buffers.num_accepted_drafts[:raw_bs].copy_( - forward_batch.spec_info.num_accepted_drafts + if forward_batch.spec_info.num_correct_drafts is not None: + buffers.num_correct_drafts[:raw_bs].copy_( + forward_batch.spec_info.num_correct_drafts ) - buffers.num_accepted_tokens[:raw_bs].copy_( - forward_batch.spec_info.num_accepted_tokens + buffers.num_accept_tokens[:raw_bs].copy_( + forward_batch.spec_info.num_accept_tokens ) buffers.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) @@ -566,8 +566,8 @@ def replay(self, forward_batch: ForwardBatch, init_state: bool = True): buffers.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) forward_batch.spec_info.hidden_states = buffers.hidden_states[:num_tokens] - forward_batch.spec_info.num_accepted_drafts = buffers.num_accepted_drafts[:bs] - forward_batch.spec_info.num_accepted_tokens = buffers.num_accepted_tokens[:bs] + forward_batch.spec_info.num_correct_drafts = buffers.num_correct_drafts[:bs] + forward_batch.spec_info.num_accept_tokens = buffers.num_accept_tokens[:bs] forward_batch.spec_info.num_tokens_per_req = self.num_tokens_per_bs forward_batch.spec_info.num_tokens_for_logprob_per_req = 1 forward_batch.spec_info.positions = buffers.positions[:num_tokens] @@ -597,10 +597,10 @@ def replay(self, forward_batch: ForwardBatch, init_state: bool = True): # DRAFT_EXTEND_V2: all tokens calculations whether accepted or not. unpadding_bs = num_tokens elif bs != raw_bs: - forward_batch.spec_info.num_accepted_drafts = buffers.num_accepted_drafts[ + forward_batch.spec_info.num_correct_drafts = buffers.num_correct_drafts[ :raw_bs ] - forward_batch.spec_info.num_accepted_tokens = buffers.num_accepted_tokens[ + forward_batch.spec_info.num_accept_tokens = buffers.num_accept_tokens[ :raw_bs ] unpadding_bs = raw_bs @@ -690,10 +690,10 @@ def _init_and_capture(self): self.cuda_graph_buffers["req_pool_indices"] = torch.zeros( (self.max_bs,), dtype=torch.int64 ) - self.cuda_graph_buffers["num_accepted_drafts"] = torch.full( + self.cuda_graph_buffers["num_correct_drafts"] = torch.full( (self.max_bs,), 1, dtype=torch.int32 ) - self.cuda_graph_buffers["num_accepted_tokens"] = torch.full( + self.cuda_graph_buffers["num_accept_tokens"] = torch.full( (self.max_bs,), 1, dtype=torch.int32 ) @@ -728,10 +728,10 @@ def reset_buffers(self, forward_batch, batch_result): self.cuda_graph_buffers["positions"].zero_() # `batch_result.accept_lens` is drafts + bonus. bs = forward_batch.batch_size - self.cuda_graph_buffers["num_accepted_drafts"][:bs].copy_( + self.cuda_graph_buffers["num_correct_drafts"][:bs].copy_( batch_result.accept_lens - 1 ) - self.cuda_graph_buffers["num_accepted_tokens"][:bs].copy_( + self.cuda_graph_buffers["num_accept_tokens"][:bs].copy_( batch_result.accept_lens ) diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index c268168ba9f2..71da801e5f8f 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -265,7 +265,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, - num_accepted_drafts=0, + num_correct_drafts=0, can_run_cuda_graph=can_run_cuda_graph, ) else: @@ -301,7 +301,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=verify_output.accept_tokens, - num_accepted_drafts=sum(verify_output.num_accepted_drafts_per_req_cpu), + num_correct_drafts=sum(verify_output.num_correct_drafts_per_req_cpu), can_run_cuda_graph=can_run_cuda_graph, ) @@ -552,9 +552,9 @@ def verify(self, batch: ScheduleBatch): logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] if self.target_worker.model_runner.hybrid_gdn_config is not None: - accepted_length = ( + num_accept_tokens = ( torch.tensor( - res.num_accepted_drafts_per_req_cpu, + res.num_correct_drafts_per_req_cpu, device=logits_output.hidden_states.device, dtype=torch.int64, ) @@ -564,30 +564,30 @@ def verify(self, batch: ScheduleBatch): # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask # res.accepted_indices.shape[0] > 0 skips DP attn idle batch if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0: - # accepted_indices=[0,2,3,4,5,7,9,10,11], accepted_length=[4, 3, 2], cumulative_accepted_lengths=[4, 7, 9] - # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_accepted_lengths[:-1]]) = [0, 5, 10] - # last_token_indices_per_req=accepted_indices[cumulative_accepted_lengths - 1] = [4, 9, 11] (last token ID of each req) + # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] + # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] + # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) # max_relative_indices_per_req = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches - cumulative_accepted_lengths = torch.cumsum(accepted_length, dim=0) + cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) req_start_positions = torch.cat( [ torch.zeros( 1, - dtype=cumulative_accepted_lengths.dtype, - device=cumulative_accepted_lengths.device, + dtype=cumulative_num_accept_tokens.dtype, + device=cumulative_num_accept_tokens.device, ), - cumulative_accepted_lengths[:-1], + cumulative_num_accept_tokens[:-1], ] ) first_token_indices_per_req = res.accepted_indices[req_start_positions] last_token_indices_per_req = res.accepted_indices[ - cumulative_accepted_lengths - 1 + cumulative_num_accept_tokens - 1 ] max_relative_indices_per_req = ( last_token_indices_per_req - first_token_indices_per_req ) else: - max_relative_indices_per_req = accepted_length - 1 + max_relative_indices_per_req = num_accept_tokens - 1 self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( max_relative_indices_per_req, self.target_worker.model_runner.model ) diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index 0b60cfe868c3..777aea1aa3ae 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -190,12 +190,14 @@ def _fill_requests( ) raise e req.spec_verify_ct += 1 - accepted_draft_tokens = sum(1 for idx in accept_index_row if idx != -1) - 1 - req.spec_accepted_drafts += accepted_draft_tokens - req.update_spec_acceptance_histogram(accepted_draft_tokens) + num_correct_drafts_this_req = ( + sum(1 for idx in accept_index_row if idx != -1) - 1 + ) + req.spec_num_correct_drafts += num_correct_drafts_this_req + req.update_spec_correct_drafts_histogram(num_correct_drafts_this_req) if has_finished: - self.num_accepted_drafts = (self.accepted_indices != -1).sum(dim=1) - 1 + self.num_correct_drafts = (self.accepted_indices != -1).sum(dim=1) - 1 self.accepted_indices = self.accepted_indices[self.accepted_indices != -1] logits_output.next_token_logits = logits_output.next_token_logits[ @@ -211,7 +213,7 @@ def _free_cache( self, batch: ScheduleBatch, page_size: int, - num_accepted_drafts_cpu: torch.Tensor, + num_correct_drafts_cpu: torch.Tensor, ): bs = batch.batch_size() # Free the KV cache for unaccepted tokens @@ -228,7 +230,7 @@ def _free_cache( batch.seq_lens, batch.out_cache_loc, self.accepted_indices, - self.num_accepted_drafts, + self.num_correct_drafts, self.draft_token_num, page_size, ) @@ -245,12 +247,12 @@ def _free_cache( # to_free_slots also needs to be page-aligned without the first partial page # # split each row of out_cache_loc into two parts. - # 1. the first part goes to tgt_cache_loc. length = num_accepted_drafts[i] + 1 + # 1. the first part goes to tgt_cache_loc. length = num_correct_drafts[i] + 1 # 2. the second part goes to to_free_slots. get_target_cache_loc[(bs,)]( tgt_cache_loc, to_free_slots, - self.num_accepted_drafts, + self.num_correct_drafts, to_free_num_slots, batch.out_cache_loc, self.draft_token_num, @@ -267,16 +269,16 @@ def _free_cache( ) batch.out_cache_loc = tgt_cache_loc - num_accepted_drafts_list = num_accepted_drafts_cpu.tolist() + num_correct_drafts_list = num_correct_drafts_cpu.tolist() for i, req in enumerate(batch.reqs): - req.kv_committed_len += num_accepted_drafts_list[i] + 1 + req.kv_committed_len += num_correct_drafts_list[i] + 1 req.kv_allocated_len = req.kv_committed_len assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, batch.seq_lens, - batch.seq_lens + self.num_accepted_tokens, + batch.seq_lens + self.num_accept_tokens, batch.out_cache_loc, batch.req_to_token_pool.req_to_token.shape[1], triton.next_power_of_2(bs), @@ -298,14 +300,14 @@ def _greedy_verify( self.accepted_indices = torch.full( (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device ) - self.num_accepted_drafts = torch.empty( + self.num_correct_drafts = torch.empty( (bs,), dtype=torch.int32, device=self.device ) verify_tree_greedy( predicts=self.predict, # mutable accept_index=self.accepted_indices, # mutable - accept_token_num=self.num_accepted_drafts, # mutable + accept_token_num=self.num_correct_drafts, # mutable candidates=candidates, # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. retrive_index=self.retrieve_index, @@ -328,7 +330,7 @@ def _sampling_verify( self.accepted_indices = torch.full( (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device ) - self.num_accepted_drafts = torch.empty( + self.num_correct_drafts = torch.empty( (bs,), dtype=torch.int32, device=self.device ) # apply temperature and get target probs @@ -370,7 +372,7 @@ def _sampling_verify( tree_speculative_sampling_target_only( predicts=self.predict, # mutable accept_index=self.accepted_indices, # mutable - accept_token_num=self.num_accepted_drafts, # mutable + accept_token_num=self.num_correct_drafts, # mutable candidates=candidates.to(torch.int64), # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. retrive_index=self.retrieve_index.to(torch.int64), @@ -452,19 +454,19 @@ def verify( self._fill_requests(batch, logits_output) # Sync the bonus-included view after the kernel + `_fill_requests` - # finalize `num_accepted_drafts`. - self.num_accepted_tokens = self.num_accepted_drafts + 1 + # finalize `num_correct_drafts`. + self.num_accept_tokens = self.num_correct_drafts + 1 - num_accepted_drafts_cpu = self.num_accepted_drafts.cpu() - num_accepted_tokens_cpu = num_accepted_drafts_cpu + 1 - num_accepted_drafts = num_accepted_drafts_cpu.sum().item() + num_correct_drafts_cpu = self.num_correct_drafts.cpu() + num_accept_tokens_cpu = num_correct_drafts_cpu + 1 + num_correct_drafts = num_correct_drafts_cpu.sum().item() - self._free_cache(batch, page_size, num_accepted_drafts_cpu) + self._free_cache(batch, page_size, num_correct_drafts_cpu) - batch.seq_lens.add_(self.num_accepted_tokens) - batch.seq_lens_cpu.add_(num_accepted_tokens_cpu) + batch.seq_lens.add_(self.num_accept_tokens) + batch.seq_lens_cpu.add_(num_accept_tokens_cpu) - return logits_output, self.accept_tokens, num_accepted_drafts + return logits_output, self.accept_tokens, num_correct_drafts def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): pass diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index 82aa4a18753e..19bd77885f4b 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -269,9 +269,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul model_worker_batch = batch.get_model_worker_batch() spec_info = model_worker_batch.spec_info - num_accepted_drafts = 0 + num_correct_drafts = 0 accept_lens = None - num_accepted_drafts_per_req_cpu = None + num_correct_drafts_per_req_cpu = None if model_worker_batch.forward_mode.is_target_verify(): if batch.has_grammar: @@ -312,25 +312,25 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul # and will be applied to produce wrong results batch.sampling_info.vocab_mask = None - logits_output, next_token_ids, num_accepted_drafts = verify_input.verify( + logits_output, next_token_ids, num_correct_drafts = verify_input.verify( batch, logits_output, self.page_size, vocab_mask ) - num_accepted_drafts_per_req_cpu = ( - verify_input.num_accepted_drafts.cpu().tolist() + num_correct_drafts_per_req_cpu = ( + verify_input.num_correct_drafts.cpu().tolist() ) if get_global_tracing_enabled(): for idx, req in enumerate(batch.reqs): accepted = ( - verify_input.num_accepted_drafts[idx].item() - if verify_input.num_accepted_drafts is not None + verify_input.num_correct_drafts[idx].item() + if verify_input.num_correct_drafts is not None else 0 ) req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) # Store accept_lens (with bonus) for per-request metrics; downstream # subtracts 1 to recover drafts-only counts. - accept_lens = verify_input.num_accepted_tokens + accept_lens = verify_input.num_accept_tokens if batch.return_logprob: add_output_logprobs_for_spec_v1(batch, verify_input, logits_output) self._update_ngram_corpus(batch) @@ -359,8 +359,8 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul return GenerationBatchResult( logits_output=logits_output, next_token_ids=next_token_ids, - num_accepted_drafts=num_accepted_drafts, - num_accepted_drafts_per_req_cpu=num_accepted_drafts_per_req_cpu, + num_correct_drafts=num_correct_drafts, + num_correct_drafts_per_req_cpu=num_correct_drafts_per_req_cpu, can_run_cuda_graph=can_run_cuda_graph, accept_lens=accept_lens, ) diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py index f604bfc2ad11..5c879a978ad7 100644 --- a/python/sglang/srt/speculative/spec_utils.py +++ b/python/sglang/srt/speculative/spec_utils.py @@ -361,7 +361,7 @@ def align_evict_mask_to_page_size( def get_target_cache_loc( tgt_cache_loc, to_free_slots, - num_accepted_drafts, + num_correct_drafts, to_free_num_slots, out_cache_loc, num_verify_tokens: tl.constexpr, @@ -373,9 +373,9 @@ def get_target_cache_loc( bs_offset = tl.arange(0, bs_upper) # write the first part to tgt_cache_loc - accept_len_all = tl.load(num_accepted_drafts + bs_offset, mask=bs_offset < bid) + accept_len_all = tl.load(num_correct_drafts + bs_offset, mask=bs_offset < bid) tgt_cache_loc_start = tl.sum(accept_len_all) + bid - copy_len = tl.load(num_accepted_drafts + bid) + 1 + copy_len = tl.load(num_correct_drafts + bid) + 1 out_cache_loc_row = tl.load( out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len ) @@ -408,7 +408,7 @@ def get_src_tgt_cache_loc( seq_lens: torch.Tensor, out_cache_loc: torch.Tensor, accept_index: torch.Tensor, - num_accepted_drafts: torch.Tensor, + num_correct_drafts: torch.Tensor, draft_token_num: int, page_size: int, ): @@ -416,7 +416,7 @@ def get_src_tgt_cache_loc( tgt_cache_loc = torch.empty_like(src_cache_loc) extended_len = seq_lens + draft_token_num keep_len = torch.minimum( - (seq_lens + num_accepted_drafts + 1 + page_size - 1) // page_size * page_size, + (seq_lens + num_correct_drafts + 1 + page_size - 1) // page_size * page_size, extended_len, ) to_free_num_slots = extended_len - keep_len @@ -427,25 +427,25 @@ def get_src_tgt_cache_loc( def filter_finished_cache_loc_kernel( out_cache_loc, tgt_cache_loc, - num_accepted_drafts, - num_accepted_drafts_filter, + num_correct_drafts, + num_accept_tokens_filter, bs_upper: tl.constexpr, num_verify_tokens_upper: tl.constexpr, ): bid = tl.program_id(0) bs_offset = tl.arange(0, bs_upper) - num_accepted_drafts_all = tl.load( - num_accepted_drafts + bs_offset, mask=bs_offset < bid + num_correct_drafts_all = tl.load( + num_correct_drafts + bs_offset, mask=bs_offset < bid ) - old_start = tl.sum(num_accepted_drafts_all) + bid + old_start = tl.sum(num_correct_drafts_all) + bid - num_accepted_drafts_filter_all = tl.load( - num_accepted_drafts_filter + bs_offset, mask=bs_offset < bid + num_accept_tokens_filter_all = tl.load( + num_accept_tokens_filter + bs_offset, mask=bs_offset < bid ) - new_start = tl.sum(num_accepted_drafts_filter_all) + new_start = tl.sum(num_accept_tokens_filter_all) - copy_len = tl.load(num_accepted_drafts_filter + bid) + copy_len = tl.load(num_accept_tokens_filter + bid) copy_offset = tl.arange(0, num_verify_tokens_upper) value = tl.load( tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len @@ -456,17 +456,17 @@ def filter_finished_cache_loc_kernel( @torch.compile(dynamic=True, disable=_is_npu) -def create_num_accepted_drafts_filter( - num_accepted_drafts: torch.Tensor, +def create_num_accept_tokens_filter( + num_correct_drafts: torch.Tensor, unfinished_index_device: torch.Tensor, seq_lens: torch.Tensor, ): - num_accepted_drafts_filter = torch.zeros_like(num_accepted_drafts) - num_accepted_drafts_filter[unfinished_index_device] = ( - num_accepted_drafts[unfinished_index_device] + 1 + num_accept_tokens_filter = torch.zeros_like(num_correct_drafts) + num_accept_tokens_filter[unfinished_index_device] = ( + num_correct_drafts[unfinished_index_device] + 1 ) - seq_lens.add_(num_accepted_drafts + 1) - return num_accepted_drafts_filter + seq_lens.add_(num_correct_drafts + 1) + return num_accept_tokens_filter def _select_top_k_tokens_first( @@ -544,7 +544,7 @@ def select_top_k_tokens( def generate_simulated_accept_index( accept_index, predict, - num_accepted_drafts, + num_correct_drafts, bs, spec_steps, simulate_acc_len: float = SIMULATE_ACC_LEN, @@ -589,7 +589,7 @@ def generate_simulated_accept_index( sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange( simulate_acc_len, device=accept_index.device ) - num_accepted_drafts.fill_(simulate_acc_len - 1) + num_correct_drafts.fill_(simulate_acc_len - 1) predict.fill_(100) # some legit token id return sim_accept_index diff --git a/python/sglang/test/attention/test_trtllm_mla_backend.py b/python/sglang/test/attention/test_trtllm_mla_backend.py index 69784b84fa29..6ba9a14c05d1 100755 --- a/python/sglang/test/attention/test_trtllm_mla_backend.py +++ b/python/sglang/test/attention/test_trtllm_mla_backend.py @@ -1308,7 +1308,7 @@ def _create_test_output_data( device = torch.device("cuda") # Create accept lengths (varying lengths for each batch) - num_accepted_drafts_per_req = torch.randint( + num_accept_tokens_per_req = torch.randint( 1, token_per_batch + 1, (batch_size,), device=device, dtype=torch.int32 ) @@ -1316,7 +1316,7 @@ def _create_test_output_data( cum_accept_lengths = torch.zeros( batch_size + 1, device=device, dtype=torch.int32 ) - cum_accept_lengths[1:] = torch.cumsum(num_accepted_drafts_per_req, dim=0) + cum_accept_lengths[1:] = torch.cumsum(num_accept_tokens_per_req, dim=0) # Create raw output tensor (batch format) raw_out = torch.randn( @@ -1334,7 +1334,7 @@ def _create_test_output_data( total_tokens, tp_q_head_num, v_head_dim, device=device, dtype=dtype ) - return raw_out, output, num_accepted_drafts_per_req, cum_accept_lengths + return raw_out, output, num_accept_tokens_per_req, cum_accept_lengths # Test 1: pad_draft_extend_query_kernel basic functionality with self.subTest(test="pad_kernel_basic"): @@ -1395,7 +1395,7 @@ def _create_test_output_data( tp_q_head_num = 16 v_head_dim = 64 - raw_out, output, num_accepted_drafts_per_req, cum_accept_lengths = ( + raw_out, output, num_accept_tokens_per_req, cum_accept_lengths = ( _create_test_output_data( self, batch_size, token_per_batch, tp_q_head_num, v_head_dim ) @@ -1408,7 +1408,7 @@ def _create_test_output_data( unpad_draft_extend_output_kernel[grid]( raw_out_ptr=raw_out, output_ptr=output, - accept_length_ptr=num_accepted_drafts_per_req, + num_accept_tokens_ptr=num_accept_tokens_per_req, cumsum_ptr=cum_accept_lengths, batch_size=batch_size, token_per_batch=token_per_batch, @@ -1419,7 +1419,7 @@ def _create_test_output_data( # Verify the unpadding worked correctly for i in range(batch_size): - accept_len = num_accepted_drafts_per_req[i].item() + accept_len = num_accept_tokens_per_req[i].item() output_start = cum_accept_lengths[i].item() # Check that valid positions are copied correctly diff --git a/python/sglang/test/kits/spec_decoding_kit.py b/python/sglang/test/kits/spec_decoding_kit.py index 7ad509bb6ad9..4262743de740 100644 --- a/python/sglang/test/kits/spec_decoding_kit.py +++ b/python/sglang/test/kits/spec_decoding_kit.py @@ -4,7 +4,7 @@ class SpecDecodingMixin: bs_1_speed_thres: float - num_accepted_drafts_thres: float + accept_length_thres: float def test_bs_1_speed(self): args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048) @@ -19,5 +19,5 @@ def test_bs_1_speed(self): f"{speed=:.2f} token/s\n" ) - self.assertGreater(acc_length, self.num_accepted_drafts_thres) + self.assertGreater(acc_length, self.accept_length_thres) self.assertGreater(speed, self.bs_1_speed_thres) diff --git a/test/registered/8-gpu-models/test_mimo_models.py b/test/registered/8-gpu-models/test_mimo_models.py index 217b4d4046c9..cc266bbe20e3 100644 --- a/test/registered/8-gpu-models/test_mimo_models.py +++ b/test/registered/8-gpu-models/test_mimo_models.py @@ -46,7 +46,7 @@ class TestMiMoV2Flash(GSM8KMixin, SpecDecodingMixin, DefaultServerBase): ] bs_1_speed_thres = 170 - num_accepted_drafts_thres = 3.2 + accept_length_thres = 3.2 MIMO_V2_MODEL = "XiaomiMiMo/MiMo-V2.5"