@@ -1807,7 +1807,7 @@ def init_forward_metadata_replay_cuda_graph(
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""
-        return 0
+        return 1
 
     def _init_local_attn_metadata(self, metadata: FlashAttentionMetadata, device):
         """Centralized utility to initialize local_attn_metadata if chunked attention is enabled."""
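Padded CUDA-graph slots previously advertised a sequence length of 0, but a zero-length entry is not a valid input for the attention kernels during graph replay, so each backend now reports 1 as its fill value. Below is a minimal sketch of how a runner can consume that value when padding a batch; DummyBackend and pad_seq_lens are hypothetical names for illustration, not the sglang API.

# A minimal sketch of consuming the backend's fill value when padding a
# batch for CUDA graph replay. DummyBackend and pad_seq_lens are
# hypothetical names, not the sglang API.
import torch

class DummyBackend:
    def get_cuda_graph_seq_len_fill_value(self) -> int:
        # 1 rather than 0: padded slots must still look like valid
        # (non-empty) sequences to the attention kernels during replay.
        return 1

def pad_seq_lens(backend: DummyBackend, seq_lens: torch.Tensor, bs: int) -> torch.Tensor:
    fill = backend.get_cuda_graph_seq_len_fill_value()
    padded = torch.full((bs,), fill, dtype=seq_lens.dtype)
    padded[: seq_lens.numel()] = seq_lens  # real entries keep their lengths
    return padded

print(pad_seq_lens(DummyBackend(), torch.tensor([5, 3]), bs=4))  # tensor([5, 3, 1, 1])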
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/attention/flashinfer_backend.py

@@ -440,7 +440,7 @@ def init_forward_metadata_replay_cuda_graph(
         raise ValueError("Invalid forward mode")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1
 
     def forward_extend(
         self,
@@ -364,7 +364,7 @@ def init_forward_metadata_replay_cuda_graph(
         raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
     def get_cuda_graph_seq_len_fill_value(self):
-        return 0
+        return 1
 
     def forward_extend(
         self,
6 changes: 3 additions & 3 deletions python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -612,7 +612,7 @@ def replay_prepare(
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
 
         # Common inputs
@@ -624,7 +624,7 @@ def replay_prepare(
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if pp_proxy_tensors:
@@ -652,7 +652,7 @@ def replay_prepare(
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum + (bs - raw_bs),
+            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
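With seq_len_fill_value in play, each of the bs - raw_bs padded slots contributes the fill value to the total token count, so the metadata sum must grow by (bs - raw_bs) * self.seq_len_fill_value rather than by the bare (bs - raw_bs), which silently assumed a fill value of 1. A worked check with hypothetical numbers:

# Worked check of the padded seq_lens_sum correction (numbers hypothetical):
# raw_bs=3 real requests are padded up to bs=5 captured slots, fill value 1.
seq_lens = [7, 4, 9]                                  # real sequence lengths
raw_bs, bs, fill = 3, 5, 1
padded = seq_lens + [fill] * (bs - raw_bs)            # [7, 4, 9, 1, 1]
assert sum(padded) == sum(seq_lens) + (bs - raw_bs) * fill  # 22 == 20 + 2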
@@ -187,9 +187,8 @@ def replay(self, forward_batch: ForwardBatch):
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
-            self.positions.zero_()
 
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -211,15 +210,15 @@ def replay(self, forward_batch: ForwardBatch):
         forward_batch.req_pool_indices = self.req_pool_indices[:bs]
         forward_batch.positions = self.positions[:num_tokens]
 
-        # Special handle for seq_len_cpu used when flashinfer mla is used
         if forward_batch.seq_lens_cpu is not None and bs != raw_bs:
-            self.seq_lens_cpu.fill_(1)
+            self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
             forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
         )
+        # TODO: The forward_batch.seq_len_sum might need to be updated to reflect the padding in the cuda graph
 
         # Replay
         self.graphs[bs].replay()
@@ -207,9 +207,9 @@ def replay(self, forward_batch: ForwardBatch):
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs * self.num_tokens_per_bs != num_tokens:
-            self.seq_lens.fill_(1)
-            self.accept_length.fill_(1)
+            self.seq_lens.fill_(self.seq_len_fill_value)
             self.out_cache_loc.zero_()
+            self.accept_length.fill_(1)
 
         # Common inputs
         self.input_ids[:num_tokens].copy_(forward_batch.input_ids)
@@ -223,18 +223,19 @@ def replay(self, forward_batch: ForwardBatch):
 
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
-                self.seq_lens_cpu.fill_(1)
+                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
 
         if bs != raw_bs:
-            forward_batch.spec_info.positions = self.positions[:num_tokens]
             forward_batch.spec_info.accept_length = self.accept_length[:bs]
+            forward_batch.spec_info.positions = None
 
         self.eagle_worker.draft_extend_attn_backend.init_forward_metadata_replay_cuda_graph(
             bs=bs,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
-            seq_lens_sum=forward_batch.seq_lens_sum + (bs - raw_bs),
+            seq_lens_sum=forward_batch.seq_lens_sum
+            + (bs - raw_bs) * self.seq_len_fill_value,
             encoder_lens=None,
             forward_mode=ForwardMode.DRAFT_EXTEND,
             spec_info=forward_batch.spec_info,
8 changes: 4 additions & 4 deletions python/sglang/srt/speculative/eagle_worker.py

@@ -154,6 +154,10 @@ def __init__(
 
     def init_attention_backend(self):
        # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -201,7 +205,6 @@ def init_attention_backend(self):
                     self.draft_model_runner,
                     skip_prefill=False,
                 )
-                self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -217,7 +220,6 @@ def init_attention_backend(self):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -228,8 +230,6 @@ def init_attention_backend(self):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
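Hoisting the two defaults above the backend dispatch guarantees that has_prefill_wrapper_verify and draft_extend_attn_backend exist on every path, and only the branches that actually build a prefill wrapper or a draft-extend backend overwrite them; previously each branch had to reset the flags itself, so a new backend branch could leave an attribute undefined. A minimal sketch of the pattern; Worker and the branch bodies are hypothetical stand-ins, not the sglang classes:

# Hoisted-defaults pattern: initialize once, let branches override as needed.
# Worker and the branch bodies below are hypothetical stand-ins.
class Worker:
    def init_attention_backend(self, backend: str) -> None:
        # Defaults first: no dispatch branch below can leave these undefined.
        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        if backend == "flashinfer":
            self.has_prefill_wrapper_verify = True    # set only where it differs
            self.draft_extend_attn_backend = object()
        elif backend == "flashmla":
            pass                                      # defaults already correct
        else:
            raise ValueError(f"EAGLE is not supported in attention backend {backend}")

w = Worker()
w.init_attention_backend("flashmla")
assert w.draft_extend_attn_backend is None
assert not w.has_prefill_wrapper_verify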