Commit da795d2

Merge branch 'main' into fix-cutlass-prepare-moe-input
2 parents efa9dfb + ed0a0b6

File tree

5 files changed: +51 additions, −50 deletions

python/sglang/srt/layers/attention/flashattention_backend.py

Lines changed: 15 additions & 5 deletions
@@ -1704,14 +1704,15 @@ def init_forward_metadata_replay_cuda_graph(
 
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
             metadata_expand = self.target_verify_metadata_topk_expand[bs]
+
             # metadata_expand.max_seq_len_q = 1, already set in capture
             # metadata_expand.cu_seqlens_q already set in capture
-
             offsets = torch.arange(
                 self.speculative_num_draft_tokens, device=device
             ).unsqueeze(
                 0
             )  # shape: (1, self.speculative_num_draft_tokens)
+
             cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
             cum_len = torch.nn.functional.pad(
                 torch.cumsum(
@@ -1728,17 +1729,20 @@ def init_forward_metadata_replay_cuda_graph(
             ).view(1, -1)
             # avoid extracting padded seq indices which will be out of boundary
             mask_extraction_indices[
-                :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
             ].fill_(0)
-
             mask = spec_info.custom_mask[mask_extraction_indices].view(
                 -1, self.speculative_num_draft_tokens
             )  # (bsz * draft_num, draft_num)
+
             col_indices = offsets.expand(
                 mask.shape[0], self.speculative_num_draft_tokens
             )
             keys = torch.where(
-                mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
             )
             _, sort_order = torch.sort(keys, dim=1)
 
@@ -1747,6 +1751,7 @@ def init_forward_metadata_replay_cuda_graph(
                 .gather(1, cols)
                 .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
             )  # (bsz, draft_num)
+
             metadata_expand.page_table.copy_(
                 non_masked_page_table.gather(1, sort_order)
             )
@@ -1758,6 +1763,7 @@ def init_forward_metadata_replay_cuda_graph(
                     dtype=torch.int32,
                 )
             )
+
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1773,11 @@ def init_forward_metadata_replay_cuda_graph(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
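
Note on the `max_seq_len_q` guard: for an idle draft-extend batch, `spec_info.accept_length_cpu` is now an empty list (see `eagle_utils.py` below), and Python's `max()` raises `ValueError` on an empty sequence. A minimal standalone sketch of the fallback, with a hypothetical helper name:

def max_seq_len_q_from(accept_length_cpu: list) -> int:
    # max([]) raises ValueError, so an idle batch (no accepted tokens on this
    # rank) falls back to a query length of 1.
    if accept_length_cpu:
        return max(accept_length_cpu) + 1
    return 1

assert max_seq_len_q_from([]) == 1
assert max_seq_len_q_from([2, 4, 3]) == 5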

python/sglang/srt/managers/scheduler.py

Lines changed: 0 additions & 5 deletions
@@ -1821,11 +1821,6 @@ def prepare_mlp_sync_batch_raw(
         else:
             can_cuda_graph = 0
 
-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
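
The deleted block forced `can_cuda_graph = 0` whenever speculative decoding was enabled and the local batch was missing or idle; the idle-capable replay path added in `cuda_graph_runner.py` and the idle draft inputs in `eagle_utils.py` below make that downgrade unnecessary, resolving the TODO. A toy before/after of just this part of the decision, with hypothetical stand-in flags (the real function weighs other conditions too):

def can_cuda_graph_old(spec_on: bool, batch_idle: bool) -> int:
    # Old behavior: speculative decoding disabled graph replay for idle batches.
    return 0 if (spec_on and batch_idle) else 1

def can_cuda_graph_new(spec_on: bool, batch_idle: bool) -> int:
    # New behavior: idle batches replay the graph captured for the spec mode.
    return 1

assert can_cuda_graph_old(True, True) == 0 and can_cuda_graph_new(True, True) == 1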

python/sglang/srt/model_executor/cuda_graph_runner.py

Lines changed: 21 additions & 22 deletions
@@ -306,28 +306,30 @@ def __init__(self, model_runner: ModelRunner):
             self.encoder_lens = None
 
         if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
             if self.require_mlp_tp_gather:
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
             else:
                 assert self.require_attn_tp_gather
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
+
         # Capture
         try:
             with model_capture_mode():
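
Two allocations are consolidated in `__init__`: `gathered_buffer` is now sized once from `self.max_num_token` instead of per gather mode, and a boolean `custom_mask` large enough for any captured replay is created up front and shared by `replay_prepare` and `get_spec_info` below. A sizing sketch of the invariant the unified buffer relies on (illustrative numbers; the exact definition of `max_num_token` is assumed, not shown in this diff):

import torch

max_bs, dp_size, num_tokens_per_bs, hidden = 128, 4, 2, 4096
# Assumption: max_num_token upper-bounds both old branch sizes,
# max_bs * dp_size * num_tokens_per_bs (MLP TP gather) and
# max_bs * num_tokens_per_bs (attn TP gather).
max_num_token = max_bs * dp_size * num_tokens_per_bs
assert max_num_token >= max_bs * num_tokens_per_bs
gathered_buffer = torch.zeros((max_num_token, hidden), dtype=torch.bfloat16)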
@@ -674,19 +676,20 @@ def replay_prepare(
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
-
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
             self.req_pool_indices[:bs],
             self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
-            forward_batch.forward_mode,
+            self.capture_forward_mode,
             forward_batch.spec_info,
             seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
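
Both replay call sites now pass `self.capture_forward_mode` instead of the live batch's mode: a captured graph must be replayed with the metadata layout it was recorded under, so an idle batch is replayed as if it were the captured speculative mode, and its `spec_info` is handed the preallocated `custom_mask` so the attention backend has a real mask tensor to index. A toy statement of the rule, with hypothetical enum values:

from enum import Enum

class Mode(Enum):
    IDLE = 0
    TARGET_VERIFY = 1

def mode_for_replay(capture_mode: Mode, batch_mode: Mode) -> Mode:
    # The recorded kernels assume the capture-time layout; an IDLE batch only
    # changes how inputs are padded, never which graph is replayed.
    return capture_mode

assert mode_for_replay(Mode.TARGET_VERIFY, Mode.IDLE) is Mode.TARGET_VERIFY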
@@ -736,11 +739,7 @@ def get_spec_info(self, num_tokens: int):
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.ones(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
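
Reusing `self.custom_mask` removes a per-call allocation that scaled with the model's context length. A back-of-envelope with illustrative numbers (not from this diff; PyTorch bool tensors use one byte per element):

num_tokens, context_len = 2048, 32768
per_call_bytes = num_tokens * context_len  # old: a fresh ones() of this size per call
print(f"{per_call_bytes / 2**20:.0f} MiB")  # prints "64 MiB"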

python/sglang/srt/speculative/eagle_utils.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,8 @@ def create_idle_input(
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
             capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
         )
 
     def prepare_extend_after_decode(
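
The idle draft input now carries explicitly empty `accept_length` fields, so downstream consumers (e.g. the `max_seq_len_q` guard in `flashattention_backend.py` above) see a well-defined zero-request shape rather than a missing attribute. A minimal sketch of the shape contract:

import torch

accept_length = torch.empty((0,), dtype=torch.int32)  # zero requests on device
accept_length_cpu: list = []                          # empty CPU-side mirror
assert accept_length.numel() == len(accept_length_cpu) == 0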

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 13 additions & 18 deletions
@@ -322,13 +322,11 @@ def forward_batch_speculative_generation(
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
-            need_forward, can_run_draft_extend_cuda_graph = (
-                self.check_forward_draft_extend_after_decode(batch)
-            )
-            if need_forward:
+
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
                     self.forward_draft_extend_after_decode(
-                        batch, can_run_draft_extend_cuda_graph
+                        batch,
                     )
             return (
                 logits_output,
@@ -344,7 +342,7 @@ def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
             and batch.spec_info.verified_id.shape[0] > 0
         )
         if not self.server_args.enable_dp_attention:
-            return local_need_forward, True
+            return local_need_forward
 
         global_need_forward = torch.tensor(
             [
@@ -357,10 +355,7 @@ def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         )
         global_need_forward_cnt = global_need_forward[0].item()
         need_forward = global_need_forward_cnt > 0
-        can_run_draft_extend_cuda_graph = (
-            global_need_forward_cnt == get_tensor_model_parallel_world_size()
-        )
-        return need_forward, can_run_draft_extend_cuda_graph
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
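
`check_forward_draft_extend_after_decode` now returns a single bool. The dropped second value gated the draft-extend CUDA graph on every rank needing a forward; with the idle-input path, ranks that have nothing to extend can run the same graph on an idle batch, so the per-rank gating is unnecessary. A sketch of the remaining aggregation pattern, assuming a torch.distributed all-reduce over the DP/attention group (the collective itself sits in the elided lines between the two hunks above, and this sketch requires an initialized process group):

import torch
import torch.distributed as dist

def any_rank_needs_forward(local_need_forward: bool, group=None) -> bool:
    flag = torch.tensor([int(local_need_forward)], dtype=torch.int64)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=group)  # count of ranks needing it
    return flag.item() > 0  # forward if any rank has verified tokens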
@@ -816,15 +811,12 @@ def forward_draft_extend(
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
-    def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
-    ):
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
         input_is_idle = batch.forward_mode.is_idle()
         if not input_is_idle:
             # Prepare metadata
@@ -836,14 +828,18 @@ def forward_draft_extend_after_decode(
         else:
             batch = batch.copy()
             batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
             batch.spec_info = EagleDraftInput.create_idle_input(
                 device=self.device,
-                hidden_size=self.model_config.hidden_size,
+                hidden_size=hidden_size,
                 dtype=self.model_config.dtype,
                 topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
-
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
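
The idle placeholder's hidden width must match what the draft model consumes: EAGLE3 feeds the draft a concatenation of hidden states from three target layers, hence the 3x factor. A minimal sketch, with hypothetical standalone names:

import torch

def idle_hidden_width(hidden_size: int, is_eagle3: bool) -> int:
    # EAGLE3 concatenates three target-layer hidden states feature-wise.
    return hidden_size * 3 if is_eagle3 else hidden_size

placeholder = torch.empty((0, idle_hidden_width(4096, True)))  # shape (0, 12288)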
@@ -858,8 +854,7 @@ def forward_draft_extend_after_decode(
 
         # Run
         can_cuda_graph = (
-            can_run_draft_extend_cuda_graph
-            and self.cuda_graph_runner_for_draft_extend
+            self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
         )
         if can_cuda_graph:
