From 15ecaae801b8f1eb535089a095750368eab6d4c7 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Mon, 11 May 2026 21:50:01 -0700 Subject: [PATCH 1/2] rename accepted_indices -> accept_indices; drop _token_id suffix per Rule 5 --- python/sglang/srt/layers/utils/logprob.py | 6 ++-- .../scheduler_output_processor_mixin.py | 4 +-- .../sglang/srt/speculative/dflash_worker.py | 14 ++++----- python/sglang/srt/speculative/eagle_info.py | 8 ++--- python/sglang/srt/speculative/eagle_worker.py | 18 +++++------ .../srt/speculative/frozen_kv_mtp_worker.py | 4 +-- .../speculative/multi_layer_eagle_worker.py | 18 +++++------ python/sglang/srt/speculative/ngram_info.py | 30 +++++++++---------- python/sglang/srt/speculative/spec_utils.py | 6 ++-- 9 files changed, 54 insertions(+), 54 deletions(-) diff --git a/python/sglang/srt/layers/utils/logprob.py b/python/sglang/srt/layers/utils/logprob.py index ab199a86d1b8..3b14b2510650 100644 --- a/python/sglang/srt/layers/utils/logprob.py +++ b/python/sglang/srt/layers/utils/logprob.py @@ -346,14 +346,14 @@ def add_output_logprobs_for_spec_v1( top_logprobs_nums = batch.top_logprobs_nums token_ids_logprobs = batch.token_ids_logprobs - accepted_indices = res.accepted_indices - assert len(accepted_indices) == len(logits_output.next_token_logits) + accept_indices = res.accept_indices + assert len(accept_indices) == len(logits_output.next_token_logits) temperatures = batch.sampling_info.temperatures num_draft_tokens = batch.spec_info.draft_token_num # acceptance indices are the indices in a "flattened" batch. # dividing it to num_draft_tokens will yield the actual batch index. - temperatures = temperatures[accepted_indices // num_draft_tokens] + temperatures = temperatures[accept_indices // num_draft_tokens] if envs.SGLANG_RETURN_ORIGINAL_LOGPROB.get(): logprobs = torch.nn.functional.log_softmax( logits_output.next_token_logits, dim=-1 diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 084fe7717de4..ae6f732fe934 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -406,7 +406,7 @@ def process_batch_result_prefill( dp_cooperation_info=batch.dp_cooperation_info, ) - def _resolve_spec_overlap_token_ids( + def _resolve_spec_overlap_tokens( self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch ) -> List[List[int]]: """Resolve the padding next token ids for speculative decoding with overlap.""" @@ -487,7 +487,7 @@ def process_batch_result_decode( if batch.spec_algorithm.is_none() or batch.is_spec_v2: if batch.is_spec_v2: - next_token_ids = self._resolve_spec_overlap_token_ids(result, batch) + next_token_ids = self._resolve_spec_overlap_tokens(result, batch) elif isinstance(next_token_ids, list): pass # MLX path: already a list[int], skip torch round-trip else: diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 68fde73632f1..4b9e4fe9b440 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -738,7 +738,7 @@ def _greedy_sample_from_vocab_parallel_head( added_vocab_start = int(shard.added_vocab_start_index) num_tokens = int(hidden_states.shape[0]) - out_token_ids = torch.empty( + out_tokens = torch.empty( (num_tokens,), dtype=torch.long, device=hidden_states.device ) @@ -753,13 +753,13 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: hs = _cast_hs(hidden_states[start:end]) if num_org > 0: base_logits = torch.matmul(hs, weight[:num_org].T) - out_token_ids[start:end] = ( + out_tokens[start:end] = ( torch.argmax(base_logits, dim=-1).to(torch.long) + org_vocab_start ) else: - out_token_ids[start:end] = 0 - return out_token_ids + out_tokens[start:end] = 0 + return out_tokens for start in range(0, num_tokens, int(chunk_size)): end = min(num_tokens, start + int(chunk_size)) @@ -812,7 +812,7 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: ) if tp_size == 1: - out_token_ids[start:end] = global_ids.to(torch.long) + out_tokens[start:end] = global_ids.to(torch.long) continue # Gather per-rank maxima and associated global ids, then select the global max. @@ -869,9 +869,9 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: rank_index[0].copy_(best_rank) selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len] torch.gather(gathered_ids, 0, rank_index, out=selected_ids) - out_token_ids[start:end].copy_(selected_ids.view(-1)) + out_tokens[start:end].copy_(selected_ids.view(-1)) - return out_token_ids + return out_tokens def _append_target_hidden_to_draft_kv( self, diff --git a/python/sglang/srt/speculative/eagle_info.py b/python/sglang/srt/speculative/eagle_info.py index d92a99dc31e2..2ca9d1f1eca1 100644 --- a/python/sglang/srt/speculative/eagle_info.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -574,7 +574,7 @@ def verify( logits_output=logits_output, accept_tokens=accept_tokens, num_correct_drafts_per_req_cpu=num_correct_drafts_list, - accepted_indices=accept_index, + accept_indices=accept_index, ) else: if page_size == 1 or self.topk == 1: @@ -651,7 +651,7 @@ def verify( logits_output=logits_output, accept_tokens=accept_tokens, num_correct_drafts_per_req_cpu=num_correct_drafts_list, - accepted_indices=accept_index, + accept_indices=accept_index, ) @@ -972,7 +972,7 @@ class EagleVerifyOutput: # Accepted token length per sequence in a batch in CPU (full set). num_correct_drafts_per_req_cpu: List[int] # Accepted indices from logits_output.next_token_logits - accepted_indices: torch.Tensor + accept_indices: torch.Tensor @classmethod def create_idle( @@ -988,7 +988,7 @@ def create_idle( logits_output=logits_output, accept_tokens=torch.empty(0, dtype=torch.long, device=device), num_correct_drafts_per_req_cpu=[], - accepted_indices=torch.full( + accept_indices=torch.full( (0, spec_steps + 1), -1, dtype=torch.int32, device=device ), ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index aa6e7338008f..7495ae5ad180 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -967,9 +967,9 @@ def verify(self, batch: ScheduleBatch): # Post process based on verified outputs. # Pick indices that we care (accepted) logits_output.next_token_logits = logits_output.next_token_logits[ - res.accepted_indices + res.accept_indices ] - logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + logits_output.hidden_states = logits_output.hidden_states[res.accept_indices] if ( self.target_worker.model_runner.hybrid_gdn_config is not None @@ -1029,16 +1029,16 @@ def _mamba_verify_update( ) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask - # res.accepted_indices.shape[0] > 0 skips DP attn idle batch - if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0: - # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] - # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] - # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) + # res.accept_indices.shape[0] > 0 skips DP attn idle batch + if spec_info.topk > 1 and res.accept_indices.shape[0] > 0: + # accept_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] + # first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] + # last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) # last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches # equivalent: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req; # `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token" at logical position i * draft_token_num. last_correct_step_indices = ( - res.accepted_indices[cumulative_num_accept_tokens - 1] + res.accept_indices[cumulative_num_accept_tokens - 1] - accepted_indices_offset ) else: @@ -1058,7 +1058,7 @@ def _mamba_verify_update( to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0) mamba_steps_to_track = torch.where( to_track_mask, - res.accepted_indices[to_track_ith + accepted_indices_start] + res.accept_indices[to_track_ith + accepted_indices_start] - accepted_indices_offset, -1, ) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py index e2e477fb3581..d1de77ed41d8 100644 --- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -755,9 +755,9 @@ def verify(self, batch: ScheduleBatch): ) logits_output.next_token_logits = logits_output.next_token_logits[ - res.accepted_indices + res.accept_indices ] - logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + logits_output.hidden_states = logits_output.hidden_states[res.accept_indices] if ( self.target_worker.model_runner.hybrid_gdn_config is not None diff --git a/python/sglang/srt/speculative/multi_layer_eagle_worker.py b/python/sglang/srt/speculative/multi_layer_eagle_worker.py index 0311167acf65..8feeb47eef57 100644 --- a/python/sglang/srt/speculative/multi_layer_eagle_worker.py +++ b/python/sglang/srt/speculative/multi_layer_eagle_worker.py @@ -556,9 +556,9 @@ def verify(self, batch: ScheduleBatch): # Post process based on verified outputs. # Pick indices that we care (accepted) logits_output.next_token_logits = logits_output.next_token_logits[ - res.accepted_indices + res.accept_indices ] - logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + logits_output.hidden_states = logits_output.hidden_states[res.accept_indices] if self.target_worker.model_runner.hybrid_gdn_config is not None: num_accept_tokens = ( @@ -571,11 +571,11 @@ def verify(self, batch: ScheduleBatch): ) # If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask - # res.accepted_indices.shape[0] > 0 skips DP attn idle batch - if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0: - # accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] - # first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] - # last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) + # res.accept_indices.shape[0] > 0 skips DP attn idle batch + if spec_info.topk > 1 and res.accept_indices.shape[0] > 0: + # accept_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9] + # first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10] + # last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req) # last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0) req_start_positions = torch.cat( @@ -588,8 +588,8 @@ def verify(self, batch: ScheduleBatch): cumulative_num_accept_tokens[:-1], ] ) - first_token_indices_per_req = res.accepted_indices[req_start_positions] - last_token_indices_per_req = res.accepted_indices[ + first_token_indices_per_req = res.accept_indices[req_start_positions] + last_token_indices_per_req = res.accept_indices[ cumulative_num_accept_tokens - 1 ] last_correct_step_indices = ( diff --git a/python/sglang/srt/speculative/ngram_info.py b/python/sglang/srt/speculative/ngram_info.py index 777aea1aa3ae..d760cfc0f56a 100644 --- a/python/sglang/srt/speculative/ngram_info.py +++ b/python/sglang/srt/speculative/ngram_info.py @@ -157,7 +157,7 @@ def _fill_requests( batch: ScheduleBatch, logits_output: torch.Tensor, ): - accept_index_cpu = self.accepted_indices.tolist() + accept_index_cpu = self.accept_indices.tolist() predict_cpu = self.predict.tolist() has_finished = False think_end_id = batch.model_config.think_end_id @@ -176,7 +176,7 @@ def _fill_requests( if req.finished(): has_finished = True # set all tokens after finished token to -1 and break - self.accepted_indices[i, j + 1 :] = -1 + self.accept_indices[i, j + 1 :] = -1 break else: if req.grammar is not None: @@ -185,7 +185,7 @@ def _fill_requests( except ValueError as e: logger.info( f"{i=}, {req=}\n" - f"{self.accepted_indices=}\n" + f"{self.accept_indices=}\n" f"{self.predict=}\n" ) raise e @@ -197,17 +197,17 @@ def _fill_requests( req.update_spec_correct_drafts_histogram(num_correct_drafts_this_req) if has_finished: - self.num_correct_drafts = (self.accepted_indices != -1).sum(dim=1) - 1 - self.accepted_indices = self.accepted_indices[self.accepted_indices != -1] + self.num_correct_drafts = (self.accept_indices != -1).sum(dim=1) - 1 + self.accept_indices = self.accept_indices[self.accept_indices != -1] logits_output.next_token_logits = logits_output.next_token_logits[ - self.accepted_indices + self.accept_indices ] if logits_output.hidden_states: logits_output.hidden_states = logits_output.hidden_states[ - self.accepted_indices + self.accept_indices ] - self.accept_tokens = self.predict[self.accepted_indices] + self.accept_tokens = self.predict[self.accept_indices] def _free_cache( self, @@ -220,16 +220,16 @@ def _free_cache( if page_size == 1: # TODO: boolean array index leads to a device sync. Remove it. evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) - evict_mask[self.accepted_indices] = False + evict_mask[self.accept_indices] = False batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask]) - batch.out_cache_loc = batch.out_cache_loc[self.accepted_indices] + batch.out_cache_loc = batch.out_cache_loc[self.accept_indices] else: # Shift the accepted tokens to the beginning. # Only evict the last part src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc( batch.seq_lens, batch.out_cache_loc, - self.accepted_indices, + self.accept_indices, self.num_correct_drafts, self.draft_token_num, page_size, @@ -297,7 +297,7 @@ def _greedy_verify( predict_shape = list(logits_output.next_token_logits.shape)[:-1] predict_shape[-1] += 1 self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device) - self.accepted_indices = torch.full( + self.accept_indices = torch.full( (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device ) self.num_correct_drafts = torch.empty( @@ -306,7 +306,7 @@ def _greedy_verify( verify_tree_greedy( predicts=self.predict, # mutable - accept_index=self.accepted_indices, # mutable + accept_index=self.accept_indices, # mutable accept_token_num=self.num_correct_drafts, # mutable candidates=candidates, # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. @@ -327,7 +327,7 @@ def _sampling_verify( predict_shape = list(logits_output.next_token_logits.shape)[:-1] predict_shape[-1] += 1 self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device) - self.accepted_indices = torch.full( + self.accept_indices = torch.full( (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device ) self.num_correct_drafts = torch.empty( @@ -371,7 +371,7 @@ def _sampling_verify( ) tree_speculative_sampling_target_only( predicts=self.predict, # mutable - accept_index=self.accepted_indices, # mutable + accept_index=self.accept_indices, # mutable accept_token_num=self.num_correct_drafts, # mutable candidates=candidates.to(torch.int64), # kwarg LHS retained as `retrive_*` to match sgl_kernel op schema. diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py index 5c879a978ad7..49ed43fad092 100644 --- a/python/sglang/srt/speculative/spec_utils.py +++ b/python/sglang/srt/speculative/spec_utils.py @@ -621,13 +621,13 @@ def dfs( accepted = True else: parent_bitmask = allocate_token_bitmask[parent_pos] - curr_token_id = draft_tokens[curr] - if vocab_size and curr_token_id >= vocab_size: + current_token = draft_tokens[curr] + if vocab_size and current_token >= vocab_size: accepted = False else: # 32 boolean bitmask values are packed into 32-bit integers accepted = ( - parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32)) + parent_bitmask[current_token // 32] & (1 << (current_token % 32)) ) != 0 if accepted: From 181c06671025044f80f144add730c45176765e3f Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Mon, 11 May 2026 22:01:06 -0700 Subject: [PATCH 2/2] keep curr_token_id (scalar) as-is; Rule 5 only for tensor/list --- python/sglang/srt/speculative/spec_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py index 49ed43fad092..5c879a978ad7 100644 --- a/python/sglang/srt/speculative/spec_utils.py +++ b/python/sglang/srt/speculative/spec_utils.py @@ -621,13 +621,13 @@ def dfs( accepted = True else: parent_bitmask = allocate_token_bitmask[parent_pos] - current_token = draft_tokens[curr] - if vocab_size and current_token >= vocab_size: + curr_token_id = draft_tokens[curr] + if vocab_size and curr_token_id >= vocab_size: accepted = False else: # 32 boolean bitmask values are packed into 32-bit integers accepted = ( - parent_bitmask[current_token // 32] & (1 << (current_token % 32)) + parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32)) ) != 0 if accepted: