Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2974,6 +2974,7 @@ def run_batch(
# TODO(lsyin): delete this branch after unifying the abstraction.
worker_batch_or_batch = batch

future_indices = None
if self.enable_overlap:
model_worker_batch = worker_batch_or_batch
self.record_batch_in_overlap(model_worker_batch)
Expand Down Expand Up @@ -3002,17 +3003,6 @@ def run_batch(

# FIXME(lsyin): move this assignment elsewhere
future_indices_or_next_token_ids = -future_indices.indices

if batch.is_spec_v2:
# FIXME(lsyin): tmp code for spec v2
# We only keep future indices for next draft input

batch.spec_info = batch_result.next_draft_input
batch.spec_info.future_indices = future_indices

# The future value, usually for next batch preparation
# Current implementation strictly synchronizes the seq_lens
batch.seq_lens = batch_result.next_draft_input.new_seq_lens
elif self.enable_pdmux and batch.forward_mode.is_split_prefill():
batch_result = self.tp_worker.forward_batch_split_prefill(batch)
future_indices_or_next_token_ids = batch_result.next_token_ids
Expand All @@ -3028,6 +3018,17 @@ def run_batch(
future_indices_or_next_token_ids = batch_result.next_token_ids
self.update_cache_from_scheduler(batch, batch_result)

if batch_result.next_draft_input is not None:
batch.spec_info = batch_result.next_draft_input
if batch.is_spec_v2:
# FIXME(lsyin): tmp code for spec v2
# We only keep future indices for next draft input
batch.spec_info.future_indices = future_indices

# The future value, usually for next batch preparation
# Current implementation strictly synchronizes the seq_lens
batch.seq_lens = batch_result.next_draft_input.new_seq_lens

# NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
# which can probably be replaced by future_indices later [TODO(lsyin)].
# we shall still keep the original outputs, e.g. next_token_ids
Expand Down
12 changes: 8 additions & 4 deletions python/sglang/srt/speculative/eagle_info_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,11 @@ def prepare_for_extend_to_fill_draft_kvcache(
draft_model_runner: Any,
cuda_graph_runner: Any,
):
# Caller is responsible for `batch.spec_info = self` before calling.
assert batch.spec_info is self
seq_lens_cpu_ = batch.seq_lens_cpu
extend_num_tokens = len(batch.seq_lens) * num_draft_tokens

batch.spec_info = self
batch.input_ids = predict
batch.seq_lens = batch.seq_lens + num_draft_tokens
batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
Expand Down Expand Up @@ -323,15 +324,18 @@ def prepare_for_v2_verify(

return verify_forward_batch, can_run_cuda_graph

def sample(
def verify_v2(
self: EagleVerifyInput,
batch: ModelWorkerBatch,
logits_output: LogitsProcessorOutput,
vocab_mask: torch.Tensor = None,
) -> Tuple["EagleVerifyOutput", torch.Tensor]:
"""
Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).
V2 counterpart to `EagleVerifyInput.verify` (V1). Sample target tokens,
verify against drafts, and produce an `EagleVerifyOutput`.

Cannot be named `verify` because the V1 method is defined directly on
`EagleVerifyInput` and would shadow this mixin method via MRO.

Returns `(verify_output, predict)` where `predict` is the full
per-position sampled tokens (V2 caller uses it as `next_token_ids`).
Expand Down
154 changes: 77 additions & 77 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import torch

from sglang.srt.distributed import get_tp_group
from sglang.srt.hardware_backend.npu.graph_runner.eagle_draft_npu_graph_runner import (
EAGLEDraftNpuGraphRunner,
)
Expand Down Expand Up @@ -162,7 +161,7 @@ def __init__(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=0, # FIXME
pp_rank=0, # spec workers don't support pipeline parallelism
dp_rank=dp_rank,
moe_ep_rank=moe_ep_rank,
attn_cp_rank=attn_cp_rank,
Expand Down Expand Up @@ -462,7 +461,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
self.forward_draft_extend(
next_draft_input = self.forward_draft_extend(
batch,
logits_output.hidden_states,
next_token_ids,
Expand All @@ -474,6 +473,7 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
next_token_ids=next_token_ids,
num_accepted_drafts=0,
can_run_cuda_graph=can_run_cuda_graph,
next_draft_input=next_draft_input,
)
else:
set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True)
Expand Down Expand Up @@ -502,25 +502,49 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
with self.draft_tp_context(
self.draft_model_runner.tp_group
), speculative_moe_backend_context(), speculative_moe_a2a_backend_context():
# NOTE: We should use `check_forward_draft_extend_after_decode`
# when DP attention is enabled, but it is slow. Skip it for now.
draft_extend_input = verify_output.draft_extend_input
next_draft_input = None
if (
self.server_args.enable_dp_attention
or draft_extend_input.input_ids.shape[0] > 0
):
# decode is not finished
# Install draft_extend_input for the extend forward, then
# install the assembled next-iter EagleDraftInput it returns.
batch.spec_info = draft_extend_input
next_draft_input = self.forward_draft_extend_after_decode(batch)
batch.spec_info = next_draft_input
# decode is not finished. Install draft_extend_input as
# batch.spec_info for the extend forward; the assembled
# next-iter EagleDraftInput is returned via batch_result
# and installed by the scheduler before the next draft.
extend_batch = batch
if (
not batch.forward_mode.is_idle()
and draft_extend_input.input_ids.numel() == 0
):
# All reqs finished this verify (only reachable when
# dp_attention forces the forward). Run the extend on
# an idle batch copy + idle ExtendInput so the original
# batch is not mutated.
extend_batch = batch.copy()
extend_batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
and self.eagle_use_aux_hidden_state
else self.model_config.spec_hidden_size
)
draft_extend_input = EagleDraftExtendInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
extend_batch.spec_info = draft_extend_input
next_draft_input = self.forward_draft_extend_after_decode(
extend_batch
)
else:
# All reqs finished this verify and dp_attention is not
# forcing the forward. Install an empty EagleDraftInput so
# next iter's merge_batch short-circuits on None
# hidden_states (EagleVerifyInput has no merge_batch).
batch.spec_info = EagleDraftInput(
# All reqs finished and dp_attention is not forcing the
# forward. Use an empty EagleDraftInput so next iter's
# merge_batch short-circuits on None hidden_states
# (EagleVerifyInput has no merge_batch).
next_draft_input = EagleDraftInput(
capture_hidden_mode=CaptureHiddenMode.LAST,
)

Expand All @@ -539,28 +563,9 @@ def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResul
num_accepted_drafts=sum(verify_output.num_accepted_drafts_per_req_cpu),
num_accepted_drafts_per_req_cpu=verify_output.num_accepted_drafts_per_req_cpu,
can_run_cuda_graph=verify_output.can_run_cuda_graph,
next_draft_input=next_draft_input,
)

def check_forward_draft_extend_after_decode(
self, batch: ScheduleBatch, verify_output: EagleVerifyOutput
):
local_need_forward = verify_output.draft_extend_input.input_ids.shape[0] > 0
if not self.server_args.enable_dp_attention:
return local_need_forward

global_need_forward = torch.tensor(
[
(local_need_forward),
],
dtype=torch.int64,
)
torch.distributed.all_reduce(
global_need_forward, group=get_tp_group().cpu_group
)
global_need_forward_cnt = global_need_forward[0].item()
need_forward = global_need_forward_cnt > 0
return need_forward

def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, torch.Tensor, Optional[torch.Tensor], bool]:
Expand Down Expand Up @@ -1088,41 +1093,54 @@ def forward_draft_extend(
next_token_ids: torch.Tensor,
seq_lens_cpu: Optional[torch.Tensor],
mm_input_embeds: Optional[torch.Tensor] = None,
):
"""Run draft model extend. This API modifies the states of the batch.
) -> EagleDraftInput:
"""Run draft model extend. Returns next-iter `EagleDraftInput`;
scheduler installs it on `batch.spec_info` via `batch_result.next_draft_input`.

Args:
batch: The batch to run.
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
We mutate `batch.spec_info` transiently so the forward kernel can read
the draft input via `batch.get_model_worker_batch`, then restore on
return.
"""
batch.spec_info = EagleDraftInput(
next_draft_input = EagleDraftInput(
hidden_states=hidden_states,
bonus_tokens=next_token_ids,
num_tokens_per_req=1,
num_tokens_for_logprob_per_req=1,
)
return_hidden_states_backup = batch.return_hidden_states
spec_info_backup = batch.spec_info

batch.spec_info = next_draft_input
batch.return_hidden_states = False
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
forward_batch.return_logprob = False
if mm_input_embeds is not None:
forward_batch.mm_input_embeds = mm_input_embeds
logits_output = self.draft_model_runner.forward(forward_batch).logits_output
maybe_detect_nan(logits_output.next_token_logits, "draft_extend_for_prefill")
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)
next_draft_input.prepare_for_extend(batch)
next_draft_input.capture_hidden_mode = CaptureHiddenMode.LAST
try:
model_worker_batch = batch.get_model_worker_batch(
seq_lens_cpu_cache=seq_lens_cpu
)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
forward_batch.return_logprob = False
if mm_input_embeds is not None:
forward_batch.mm_input_embeds = mm_input_embeds
logits_output = self.draft_model_runner.forward(forward_batch).logits_output
maybe_detect_nan(
logits_output.next_token_logits, "draft_extend_for_prefill"
)
self.capture_for_decode(logits_output, next_draft_input)
finally:
batch.spec_info = spec_info_backup
batch.return_hidden_states = return_hidden_states_backup

return next_draft_input

def forward_draft_extend_after_decode(
self, batch: ScheduleBatch
) -> EagleDraftInput:
# Caller installs the EagleDraftExtendInput on `batch.spec_info`
# (using either the verify-produced one or an idle one for the
# all-finished + dp_attention edge case).
draft_extend_input: EagleDraftExtendInput = batch.spec_info

# Backup fields that will be modified in-place
Expand All @@ -1133,24 +1151,6 @@ def forward_draft_extend_after_decode(

input_is_idle = batch.forward_mode.is_idle()

if not input_is_idle and draft_extend_input.input_ids.numel() == 0:
# All reqs finished this verify; swap to an idle ExtendInput.
batch = batch.copy()
batch.prepare_for_idle()
hidden_size = (
self.model_config.hidden_size * 3
if self.speculative_algorithm.is_eagle3()
and self.eagle_use_aux_hidden_state
else self.model_config.spec_hidden_size
)
draft_extend_input = EagleDraftExtendInput.create_idle_input(
device=self.device,
hidden_size=hidden_size,
dtype=self.model_config.dtype,
capture_hidden_mode=CaptureHiddenMode.LAST,
)
batch.spec_info = draft_extend_input

# Phase 1: prepare extend (kernel writes draft_extend_input.{positions, bonus_tokens})
draft_extend_input.num_tokens_per_req = self.speculative_num_steps + 1
draft_extend_input.num_tokens_for_logprob_per_req = 1
Expand Down
9 changes: 7 additions & 2 deletions python/sglang/srt/speculative/eagle_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def __init__(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=0, # FIXME
pp_rank=0, # spec workers don't support pipeline parallelism
dp_rank=dp_rank,
moe_ep_rank=moe_ep_rank,
attn_cp_rank=attn_cp_rank,
Expand Down Expand Up @@ -590,6 +590,9 @@ def _draft_extend_for_decode(
- 1
)

# Install draft_extend_input as `batch.spec_info` for the extend forward.
batch.spec_info = draft_extend_input

# Prepare for draft extend in a separate stream
with self.plan_stream_ctx:
forward_batch = draft_extend_input.prepare_for_extend_to_fill_draft_kvcache(
Expand Down Expand Up @@ -1027,7 +1030,9 @@ def verify(self, batch: ModelWorkerBatch):

# Sample
maybe_detect_nan(logits_output.next_token_logits, "verify: target model logits")
verify_output, predict = verify_input.sample(batch, logits_output, vocab_mask)
verify_output, predict = verify_input.verify_v2(
batch, logits_output, vocab_mask
)
accept_lens = verify_output.draft_extend_input.num_accepted_tokens
new_seq_lens = batch.seq_lens + accept_lens

Expand Down
Loading
Loading