Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def prepare_gdn_inputs(
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
cache_indices = self.forward_metadata.mamba_cache_indices
self.num_accepted_tokens = torch.ones(
self.num_accept_tokens = torch.ones(
[bs], dtype=torch.int32, device=cache_indices.device
)
self.actual_seq_lengths = torch.ones(
Expand Down Expand Up @@ -237,7 +237,7 @@ def forward_extend(
seq_len = forward_batch.num_token_non_padded_cpu

mixed_qkv_reshaped = mixed_qkv.view(batch_size, draft_token_num, -1)
num_accepted_tokens = torch.full(
num_accept_tokens = torch.full(
(batch_size,),
draft_token_num,
dtype=torch.int32,
Expand All @@ -249,7 +249,7 @@ def forward_extend(
conv_states,
cache_indices,
layer.bias,
num_accepted_tokens,
num_accept_tokens,
None,
layer.activation == "silu",
self.pad_slot_id,
Expand Down Expand Up @@ -391,15 +391,15 @@ def fused_recurrent_gated_delta_rule_update(
)

if self.graph_mode:
num_accepted_tokens = torch.full(
num_accept_tokens = torch.full(
[batch_size], 1, dtype=torch.int32, device=cache_indices.device
)
actual_seq_lengths = torch.full(
[batch_size], seq_len, dtype=torch.int32, device=cache_indices.device
)
ssm_state_indices = self.forward_metadata.mamba_cache_indices_gdn
else:
num_accepted_tokens = self.num_accepted_tokens
num_accept_tokens = self.num_accept_tokens
actual_seq_lengths = self.actual_seq_lengths
ssm_state_indices = self.ssm_state_indices

Expand All @@ -414,7 +414,7 @@ def fused_recurrent_gated_delta_rule_update(
nv=num_value_heads,
intermediate_state=intermediate_state,
cache_indices=cache_indices,
num_accepted_tokens=num_accepted_tokens,
num_accept_tokens=num_accept_tokens,
g=g,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(

def update_mamba_state_after_mtp_verify(
self,
accepted_steps: torch.Tensor,
accept_steps: torch.Tensor,
mamba_track_indices: Optional[torch.Tensor],
mamba_steps_to_track: Optional[torch.Tensor],
model,
Expand All @@ -233,7 +233,7 @@ def update_mamba_state_after_mtp_verify(
- index_select kernel launches
- nonzero kernel launches
"""
request_number = accepted_steps.shape[0]
request_number = accept_steps.shape[0]

state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
Expand All @@ -254,7 +254,7 @@ def update_mamba_state_after_mtp_verify(
device=dst_indices_tensor.device,
dtype=torch.int64,
)
last_steps = accepted_steps.to(torch.int64) # [N]
last_steps = accept_steps.to(torch.int64) # [N]

move_intermediate_cache(
ssm_states,
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/attention/aiter_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,7 +1042,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
self.indices_updater_prefill.max_kv_len,
)
elif forward_batch.forward_mode.is_draft_extend():
# EAGLE V1: DRAFT_EXTEND mode - uses spec_info.num_accepted_tokens
# EAGLE V1: DRAFT_EXTEND mode - uses spec_info.num_accept_tokens
if self.use_mla:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
)
kv_indices = kv_indices.to(torch.int64)
draft_max_extend_len = torch.max(spec_info.num_accepted_tokens).item()
draft_max_extend_len = torch.max(spec_info.num_accept_tokens).item()

self.forward_metadata = ForwardMetadata(
kv_indptr,
Expand Down Expand Up @@ -2240,10 +2240,10 @@ def init_forward_metadata_replay_cuda_graph(
num_kv_splits=num_kv_splits,
)
elif forward_mode.is_draft_extend():
# EAGLE V1: Uses spec_info.num_accepted_tokens
# EAGLE V1: Uses spec_info.num_accept_tokens
num_tokens_per_bs = self.speculative_num_steps + 1
seq_lens = seq_lens[:bs]
extend_lens = spec_info.num_accepted_tokens[:bs]
extend_lens = spec_info.num_accept_tokens[:bs]
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
kv_indptr = self.kv_indptr[: bs + 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2162,9 +2162,9 @@ def init_forward_metadata_replay_cuda_graph(
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
extend_lens = spec_info.num_accepted_tokens[:bs]
if spec_info.num_accepted_tokens_cpu:
metadata.max_seq_len_q = max(spec_info.num_accepted_tokens_cpu)
extend_lens = spec_info.num_accept_tokens[:bs]
if spec_info.num_accept_tokens_cpu:
metadata.max_seq_len_q = max(spec_info.num_accept_tokens_cpu)
else:
metadata.max_seq_len_q = 1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@ def forward(

def update_mamba_state_after_mtp_verify(
self,
accepted_steps: torch.Tensor,
accept_steps: torch.Tensor,
mamba_track_indices: Optional[torch.Tensor],
mamba_steps_to_track: Optional[torch.Tensor],
model,
Expand All @@ -950,7 +950,7 @@ def update_mamba_state_after_mtp_verify(
- index_select kernel launches
- nonzero kernel launches
"""
request_number = accepted_steps.shape[0]
request_number = accept_steps.shape[0]

state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
Expand All @@ -973,13 +973,13 @@ def update_mamba_state_after_mtp_verify(
ssm_states,
intermediate_state_cache,
state_indices_tensor,
accepted_steps,
accept_steps,
)
fused_mamba_state_scatter_with_mask(
conv_states,
intermediate_conv_window_cache,
state_indices_tensor,
accepted_steps,
accept_steps,
)

# Track indices used for tracking mamba states for prefix cache
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def _causal_conv1d_update_kernel(
conv_state_ptr,
cache_seqlens_ptr, # circular buffer
conv_state_indices_ptr,
num_accepted_tokens_ptr,
num_accept_tokens_ptr,
intermediate_conv_window_ptr,
intermediate_state_indices_ptr,
retrieve_next_token_ptr,
Expand Down Expand Up @@ -667,7 +667,7 @@ def _causal_conv1d_update_kernel(
# - accept 1 tokens: [history2, ..., historyM, draft1]
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1
conv_state_token_offset = tl.load(num_accept_tokens_ptr + idx_seq) - 1
else:
conv_state_token_offset = 0

Expand Down Expand Up @@ -985,7 +985,7 @@ def causal_conv1d_update(
activation: Union[bool, str, None] = None,
cache_seqlens: Optional[torch.Tensor] = None,
conv_state_indices: Optional[torch.Tensor] = None,
num_accepted_tokens: Optional[torch.Tensor] = None,
num_accept_tokens: Optional[torch.Tensor] = None,
intermediate_conv_window: Optional[torch.Tensor] = None,
intermediate_state_indices: Optional[torch.Tensor] = None,
retrieve_next_token: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -1071,7 +1071,7 @@ def causal_conv1d_update(
if intermediate_state_indices is not None
else 0
)
if num_accepted_tokens is not None:
if num_accept_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
Expand Down Expand Up @@ -1130,7 +1130,7 @@ def grid(META):
conv_state,
cache_seqlens,
conv_state_indices,
num_accepted_tokens,
num_accept_tokens,
intermediate_conv_window if intermediate_conv_window is not None else x,
intermediate_state_indices,
retrieve_next_token,
Expand Down Expand Up @@ -1174,7 +1174,7 @@ def grid(META):
KERNEL_WIDTH=width,
SILU_ACTIVATION=activation in ["silu", "swish"],
IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
IS_SPEC_DECODING=num_accepted_tokens is not None,
IS_SPEC_DECODING=num_accept_tokens is not None,
NP2_STATELEN=np2_statelen,
NP2_SEQLEN=np2_seqlen,
USE_PAD_SLOT=pad_slot_id is not None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _fused_mamba_state_scatter_with_mask_kernel(
dst_ptr,
# Raw index arrays (before index_select)
dst_indices_raw_ptr, # [total_requests] - state_indices_tensor
step_indices_raw_ptr, # [total_requests] - accepted_steps or mamba_steps_to_track
step_indices_raw_ptr, # [total_requests] - accept_steps or mamba_steps_to_track
elem_per_entry: tl.constexpr,
src_layer_stride,
src_req_stride,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def _precompute_draft_extend_mode(
cache_seqlens = seq_lens.to(torch.int32)
cu_seqlens_k = compute_cu_seqlens(cache_seqlens)

# Extend seqlens from spec_info: num_accepted_tokens already includes
# Extend seqlens from spec_info: num_accept_tokens already includes
# the bonus token (drafts + 1).
extend_seq_lens = spec_info.num_accepted_tokens[:bs]
extend_seq_lens = spec_info.num_accept_tokens[:bs]
extend_seq_lens_cpu = extend_seq_lens.tolist()

# Page indices (repeated per accept length)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/attention/nsa_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
page_table, repeats=self.speculative_num_draft_tokens, dim=0
)
else:
# DRAFT_EXTEND (v1): V1 worker extends by (num_accepted_drafts + 1) per request
# DRAFT_EXTEND (v1): V1 worker extends by (num_correct_drafts + 1) per request
# after verification. Lengths vary per request based on how many tokens
# were accepted.
page_table = torch.repeat_interleave(
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def init_forward_metadata_replay_cuda_graph(
torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
)

extend_seq_lens = spec_info.num_accepted_tokens[:bs]
extend_seq_lens = spec_info.num_accept_tokens[:bs]
extend_seq_lens_cpu = extend_seq_lens.tolist()

page_indices = self.req_to_token[req_pool_indices, :max_seqlen_k]
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/attention/triton_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = kv_indices.to(torch.int64)
mask_indptr = None
# TODO(FIXME): This will trigger an invalid Eagle tree when using
# `max(spec_info.num_accepted_tokens_cpu)`.
# `max(spec_info.num_accept_tokens_cpu)`.
# It might have been forgotten to update somewhere.
max_extend_len = torch.max(spec_info.num_accepted_tokens).item()
max_extend_len = torch.max(spec_info.num_accept_tokens).item()
num_kv_splits = None
attn_logits = None
attn_lse = None
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/attention/trtllm_mha_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,9 +528,9 @@ def init_forward_metadata_replay_cuda_graph(
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
extend_lens = spec_info.num_accepted_tokens[:bs]
if spec_info.num_accepted_tokens_cpu:
metadata.max_seq_len_q = max(spec_info.num_accepted_tokens_cpu)
extend_lens = spec_info.num_accept_tokens[:bs]
if spec_info.num_accept_tokens_cpu:
metadata.max_seq_len_q = max(spec_info.num_accept_tokens_cpu)
else:
metadata.max_seq_len_q = 1

Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/attention/trtllm_mla_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def pad_draft_extend_query_kernel(
def unpad_draft_extend_output_kernel(
raw_out_ptr, # Input raw output tensor (batch_size, token_per_batch, tp_q_head_num, v_head_dim)
output_ptr, # Output tensor (-1, tp_q_head_num, v_head_dim)
accept_length_ptr, # Accept lengths for each sequence [batch_size]
num_accept_tokens_ptr, # Accept lengths for each sequence [batch_size]
cumsum_ptr, # Cumulative sum of accept lengths [batch_size + 1]
batch_size,
token_per_batch,
Expand All @@ -151,7 +151,7 @@ def unpad_draft_extend_output_kernel(
return

# Load accept length for this batch
accept_len = tl.load(accept_length_ptr + batch_id)
accept_len = tl.load(num_accept_tokens_ptr + batch_id)

if seq_pos >= accept_len:
return
Expand Down Expand Up @@ -745,7 +745,7 @@ def unpad_draft_extend_output(
unpad_draft_extend_output_kernel[grid](
raw_out_ptr=raw_out,
output_ptr=output,
accept_length_ptr=seq_lens_q,
num_accept_tokens_ptr=seq_lens_q,
cumsum_ptr=cu_seqlens_q,
batch_size=batch_size,
token_per_batch=token_per_batch,
Expand Down Expand Up @@ -1006,7 +1006,7 @@ def forward_extend(
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
needs_unpad = False
else:
# draft_extend: handle varying num_accepted_drafts_per_req. If total_tokens % bs == 0,
# draft_extend: handle varying num_correct_drafts_per_req. If total_tokens % bs == 0,
# we can directly reshape q; otherwise, pad to max_seq_len_q.
total_tokens = q.shape[0]
tokens_per_seq = total_tokens // bs if bs > 0 else 0
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/attention/wave_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
)
mask_indptr = None
# TODO(FIXME): This will trigger an invalid Eagle tree when using
# `max(spec_info.num_accepted_tokens_cpu)`.
# `max(spec_info.num_accept_tokens_cpu)`.
# It might have been forgotten to update somewhere.
max_extend_len = torch.max(spec_info.num_accepted_tokens).item()
max_extend_len = torch.max(spec_info.num_accept_tokens).item()
num_kv_splits = None
attn_logits = None
attn_lse = None
Expand Down
8 changes: 4 additions & 4 deletions python/sglang/srt/layers/utils/logprob.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,11 @@ def add_output_logprobs_for_spec_v1(
if logits_output is None:
logits_output = res.logits_output

if hasattr(res, "num_accepted_drafts_per_req_cpu"):
num_accepted_drafts_per_req_cpu = res.num_accepted_drafts_per_req_cpu
if hasattr(res, "num_correct_drafts_per_req_cpu"):
num_correct_drafts_per_req_cpu = res.num_correct_drafts_per_req_cpu
else:
# FIXME: Get a NgramVerifyOutput class and use that instead of this hack.
num_accepted_drafts_per_req_cpu = res.num_accepted_drafts.tolist()
num_correct_drafts_per_req_cpu = res.num_correct_drafts.tolist()

top_logprobs_nums = batch.top_logprobs_nums
token_ids_logprobs = batch.token_ids_logprobs
Expand All @@ -363,7 +363,7 @@ def add_output_logprobs_for_spec_v1(
logits_output.next_token_logits / temperatures, dim=-1
)
batch_next_token_ids = res.accept_tokens
num_tokens_per_req = [accept + 1 for accept in num_accepted_drafts_per_req_cpu]
num_tokens_per_req = [accept + 1 for accept in num_correct_drafts_per_req_cpu]

# We should repeat top_logprobs_nums to match num_tokens_per_req.
top_logprobs_nums_repeat_interleaved = [
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ def handle_batch_token_id_out(self, recv_obj: BatchTokenIDOutput):
cached_tokens=recv_obj.cached_tokens,
cached_tokens_details=recv_obj.cached_tokens_details,
spec_verify_ct=recv_obj.spec_verify_ct,
spec_accepted_drafts=recv_obj.spec_accepted_drafts,
spec_acceptance_histogram=recv_obj.spec_acceptance_histogram,
spec_num_correct_drafts=recv_obj.spec_num_correct_drafts,
spec_correct_drafts_histogram=recv_obj.spec_correct_drafts_histogram,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/managers/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ class SpeculativeDecodingMetricsMixin:

# Accepted drafts: Number of accepted draft tokens during speculative decoding
# (strict drafts-only count, excludes the bonus token).
spec_accepted_drafts: List[int]
spec_num_correct_drafts: List[int]

# Acceptance histogram: List of lists, where each inner list represents histogram counts.
# List index = number of accepted tokens in a step, List value = count of steps with that many accepted tokens.
# Example: histogram[0] = 5 means 5 steps with 0 accepted tokens, histogram[3] = 10 means 10 steps with 3 accepted tokens.
# Empty list [] when speculative decoding is disabled.
spec_acceptance_histogram: List[List[int]]
spec_correct_drafts_histogram: List[List[int]]


# Parameters for a session
Expand Down
Loading
Loading