Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/utils/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,14 @@ def add_output_logprobs_for_spec_v1(

top_logprobs_nums = batch.top_logprobs_nums
token_ids_logprobs = batch.token_ids_logprobs
accepted_indices = res.accepted_indices
assert len(accepted_indices) == len(logits_output.next_token_logits)
accept_indices = res.accept_indices
assert len(accept_indices) == len(logits_output.next_token_logits)

temperatures = batch.sampling_info.temperatures
num_draft_tokens = batch.spec_info.draft_token_num
# acceptance indices are the indices in a "flattened" batch.
# dividing it to num_draft_tokens will yield the actual batch index.
temperatures = temperatures[accepted_indices // num_draft_tokens]
temperatures = temperatures[accept_indices // num_draft_tokens]
if envs.SGLANG_RETURN_ORIGINAL_LOGPROB.get():
logprobs = torch.nn.functional.log_softmax(
logits_output.next_token_logits, dim=-1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def process_batch_result_prefill(
dp_cooperation_info=batch.dp_cooperation_info,
)

def _resolve_spec_overlap_token_ids(
def _resolve_spec_overlap_tokens(
self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
) -> List[List[int]]:
"""Resolve the padding next token ids for speculative decoding with overlap."""
Expand Down Expand Up @@ -487,7 +487,7 @@ def process_batch_result_decode(

if batch.spec_algorithm.is_none() or batch.is_spec_v2:
if batch.is_spec_v2:
next_token_ids = self._resolve_spec_overlap_token_ids(result, batch)
next_token_ids = self._resolve_spec_overlap_tokens(result, batch)
elif isinstance(next_token_ids, list):
pass # MLX path: already a list[int], skip torch round-trip
else:
Expand Down
14 changes: 7 additions & 7 deletions python/sglang/srt/speculative/dflash_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def _greedy_sample_from_vocab_parallel_head(
added_vocab_start = int(shard.added_vocab_start_index)

num_tokens = int(hidden_states.shape[0])
out_token_ids = torch.empty(
out_tokens = torch.empty(
(num_tokens,), dtype=torch.long, device=hidden_states.device
)

Expand All @@ -753,13 +753,13 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor:
hs = _cast_hs(hidden_states[start:end])
if num_org > 0:
base_logits = torch.matmul(hs, weight[:num_org].T)
out_token_ids[start:end] = (
out_tokens[start:end] = (
torch.argmax(base_logits, dim=-1).to(torch.long)
+ org_vocab_start
)
else:
out_token_ids[start:end] = 0
return out_token_ids
out_tokens[start:end] = 0
return out_tokens

for start in range(0, num_tokens, int(chunk_size)):
end = min(num_tokens, start + int(chunk_size))
Expand Down Expand Up @@ -812,7 +812,7 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor:
)

if tp_size == 1:
out_token_ids[start:end] = global_ids.to(torch.long)
out_tokens[start:end] = global_ids.to(torch.long)
continue

# Gather per-rank maxima and associated global ids, then select the global max.
Expand Down Expand Up @@ -869,9 +869,9 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor:
rank_index[0].copy_(best_rank)
selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len]
torch.gather(gathered_ids, 0, rank_index, out=selected_ids)
out_token_ids[start:end].copy_(selected_ids.view(-1))
out_tokens[start:end].copy_(selected_ids.view(-1))

return out_token_ids
return out_tokens

def _append_target_hidden_to_draft_kv(
self,
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/speculative/eagle_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ def verify(
logits_output=logits_output,
accept_tokens=accept_tokens,
num_correct_drafts_per_req_cpu=num_correct_drafts_list,
accepted_indices=accept_index,
accept_indices=accept_index,
)
else:
if page_size == 1 or self.topk == 1:
Expand Down Expand Up @@ -651,7 +651,7 @@ def verify(
logits_output=logits_output,
accept_tokens=accept_tokens,
num_correct_drafts_per_req_cpu=num_correct_drafts_list,
accepted_indices=accept_index,
accept_indices=accept_index,
)


Expand Down Expand Up @@ -972,7 +972,7 @@ class EagleVerifyOutput:
# Accepted token length per sequence in a batch in CPU (full set).
num_correct_drafts_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepted_indices: torch.Tensor
accept_indices: torch.Tensor

@classmethod
def create_idle(
Expand All @@ -988,7 +988,7 @@ def create_idle(
logits_output=logits_output,
accept_tokens=torch.empty(0, dtype=torch.long, device=device),
num_correct_drafts_per_req_cpu=[],
accepted_indices=torch.full(
accept_indices=torch.full(
(0, spec_steps + 1), -1, dtype=torch.int32, device=device
),
)
18 changes: 9 additions & 9 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,9 @@ def verify(self, batch: ScheduleBatch):
# Post process based on verified outputs.
# Pick indices that we care (accepted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepted_indices
res.accept_indices
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
logits_output.hidden_states = logits_output.hidden_states[res.accept_indices]

if (
self.target_worker.model_runner.hybrid_gdn_config is not None
Expand Down Expand Up @@ -1029,16 +1029,16 @@ def _mamba_verify_update(
)

# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
# res.accepted_indices.shape[0] > 0 skips DP attn idle batch
if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0:
# accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# res.accept_indices.shape[0] > 0 skips DP attn idle batch
if spec_info.topk > 1 and res.accept_indices.shape[0] > 0:
# accept_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
# equivalent: last_correct_step_indices = last_token_indices_per_req - first_token_indices_per_req;
# `accepted_indices_offset` equals `first_token_indices_per_req` because the first accepted slot of each req is its "current token" at logical position i * draft_token_num.
last_correct_step_indices = (
res.accepted_indices[cumulative_num_accept_tokens - 1]
res.accept_indices[cumulative_num_accept_tokens - 1]
- accepted_indices_offset
)
else:
Expand All @@ -1058,7 +1058,7 @@ def _mamba_verify_update(
to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0)
mamba_steps_to_track = torch.where(
to_track_mask,
res.accepted_indices[to_track_ith + accepted_indices_start]
res.accept_indices[to_track_ith + accepted_indices_start]
- accepted_indices_offset,
-1,
)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/frozen_kv_mtp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,9 +755,9 @@ def verify(self, batch: ScheduleBatch):
)

logits_output.next_token_logits = logits_output.next_token_logits[
res.accepted_indices
res.accept_indices
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
logits_output.hidden_states = logits_output.hidden_states[res.accept_indices]

if (
self.target_worker.model_runner.hybrid_gdn_config is not None
Expand Down
18 changes: 9 additions & 9 deletions python/sglang/srt/speculative/multi_layer_eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,9 +556,9 @@ def verify(self, batch: ScheduleBatch):
# Post process based on verified outputs.
# Pick indices that we care (accepted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepted_indices
res.accept_indices
]
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
logits_output.hidden_states = logits_output.hidden_states[res.accept_indices]

if self.target_worker.model_runner.hybrid_gdn_config is not None:
num_accept_tokens = (
Expand All @@ -571,11 +571,11 @@ def verify(self, batch: ScheduleBatch):
)

# If topk > 1, we need to use retrieve_next_token and retrieve_next_sibling to handle the eagle tree custom attention mask
# res.accepted_indices.shape[0] > 0 skips DP attn idle batch
if spec_info.topk > 1 and res.accepted_indices.shape[0] > 0:
# accepted_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accepted_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accepted_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# res.accept_indices.shape[0] > 0 skips DP attn idle batch
if spec_info.topk > 1 and res.accept_indices.shape[0] > 0:
# accept_indices=[0,2,3,4,5,7,9,10,11], num_accept_tokens=[4, 3, 2], cumulative_num_accept_tokens=[4, 7, 9]
# first_token_indices_per_req=prepend(0, accept_indices[cumulative_num_accept_tokens[:-1]]) = [0, 5, 10]
# last_token_indices_per_req=accept_indices[cumulative_num_accept_tokens - 1] = [4, 9, 11] (last token ID of each req)
# last_correct_step_indices = [4,4,1]; those are the per-req spec-decoding step offsets that contain the correct mamba caches
cumulative_num_accept_tokens = torch.cumsum(num_accept_tokens, dim=0)
req_start_positions = torch.cat(
Expand All @@ -588,8 +588,8 @@ def verify(self, batch: ScheduleBatch):
cumulative_num_accept_tokens[:-1],
]
)
first_token_indices_per_req = res.accepted_indices[req_start_positions]
last_token_indices_per_req = res.accepted_indices[
first_token_indices_per_req = res.accept_indices[req_start_positions]
last_token_indices_per_req = res.accept_indices[
cumulative_num_accept_tokens - 1
]
last_correct_step_indices = (
Expand Down
30 changes: 15 additions & 15 deletions python/sglang/srt/speculative/ngram_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _fill_requests(
batch: ScheduleBatch,
logits_output: torch.Tensor,
):
accept_index_cpu = self.accepted_indices.tolist()
accept_index_cpu = self.accept_indices.tolist()
predict_cpu = self.predict.tolist()
has_finished = False
think_end_id = batch.model_config.think_end_id
Expand All @@ -176,7 +176,7 @@ def _fill_requests(
if req.finished():
has_finished = True
# set all tokens after finished token to -1 and break
self.accepted_indices[i, j + 1 :] = -1
self.accept_indices[i, j + 1 :] = -1
break
else:
if req.grammar is not None:
Expand All @@ -185,7 +185,7 @@ def _fill_requests(
except ValueError as e:
logger.info(
f"{i=}, {req=}\n"
f"{self.accepted_indices=}\n"
f"{self.accept_indices=}\n"
f"{self.predict=}\n"
)
raise e
Expand All @@ -197,17 +197,17 @@ def _fill_requests(
req.update_spec_correct_drafts_histogram(num_correct_drafts_this_req)

if has_finished:
self.num_correct_drafts = (self.accepted_indices != -1).sum(dim=1) - 1
self.accepted_indices = self.accepted_indices[self.accepted_indices != -1]
self.num_correct_drafts = (self.accept_indices != -1).sum(dim=1) - 1
self.accept_indices = self.accept_indices[self.accept_indices != -1]

logits_output.next_token_logits = logits_output.next_token_logits[
self.accepted_indices
self.accept_indices
]
if logits_output.hidden_states:
logits_output.hidden_states = logits_output.hidden_states[
self.accepted_indices
self.accept_indices
]
self.accept_tokens = self.predict[self.accepted_indices]
self.accept_tokens = self.predict[self.accept_indices]

def _free_cache(
self,
Expand All @@ -220,16 +220,16 @@ def _free_cache(
if page_size == 1:
# TODO: boolean array index leads to a device sync. Remove it.
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[self.accepted_indices] = False
evict_mask[self.accept_indices] = False
batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
batch.out_cache_loc = batch.out_cache_loc[self.accepted_indices]
batch.out_cache_loc = batch.out_cache_loc[self.accept_indices]
else:
# Shift the accepted tokens to the beginning.
# Only evict the last part
src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
batch.seq_lens,
batch.out_cache_loc,
self.accepted_indices,
self.accept_indices,
self.num_correct_drafts,
self.draft_token_num,
page_size,
Expand Down Expand Up @@ -297,7 +297,7 @@ def _greedy_verify(
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
self.accepted_indices = torch.full(
self.accept_indices = torch.full(
(bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
)
self.num_correct_drafts = torch.empty(
Expand All @@ -306,7 +306,7 @@ def _greedy_verify(

verify_tree_greedy(
predicts=self.predict, # mutable
accept_index=self.accepted_indices, # mutable
accept_index=self.accept_indices, # mutable
accept_token_num=self.num_correct_drafts, # mutable
candidates=candidates,
# kwarg LHS retained as `retrive_*` to match sgl_kernel op schema.
Expand All @@ -327,7 +327,7 @@ def _sampling_verify(
predict_shape = list(logits_output.next_token_logits.shape)[:-1]
predict_shape[-1] += 1
self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
self.accepted_indices = torch.full(
self.accept_indices = torch.full(
(bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
)
self.num_correct_drafts = torch.empty(
Expand Down Expand Up @@ -371,7 +371,7 @@ def _sampling_verify(
)
tree_speculative_sampling_target_only(
predicts=self.predict, # mutable
accept_index=self.accepted_indices, # mutable
accept_index=self.accept_indices, # mutable
accept_token_num=self.num_correct_drafts, # mutable
candidates=candidates.to(torch.int64),
# kwarg LHS retained as `retrive_*` to match sgl_kernel op schema.
Expand Down
Loading