@@ -79,7 +79,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
             torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
         # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
         # Precompute page table
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
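For context on why the maximum is now read from a host-side copy: a minimal sketch, not the sglang code path; the tensor values, variable names, and device handling below are illustrative only.

import torch

# Calling .item() on a CUDA tensor forces a device-to-host sync, while reading a
# CPU copy that was prepared when the batch was built does not.
device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 17, 9], device=device)   # per-request KV lengths (toy values)
seq_lens_cpu = seq_lens.cpu()                         # done once on the scheduler side
max_seq_len_k = seq_lens_cpu.max().item()             # plain CPU work at attention setup time
print(max_seq_len_k)                                  # 17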
@@ -797,7 +797,7 @@ def call_fn(i, forward_batch):
                 encoder_lens=None,
                 forward_mode=ForwardMode.DECODE,
                 spec_info=forward_batch.spec_info,
-                seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
             )

         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
python/sglang/srt/layers/attention/flashmla_backend.py (2 changes: 1 addition & 1 deletion)
@@ -92,7 +92,7 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
         if forward_batch.forward_mode.is_decode_or_idle():
             if spec_info is None:
                 max_seqlen_pad = triton.cdiv(
-                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                    forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
                )
                 block_kv_indices = torch.full(
                     (bs, max_seqlen_pad),
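The padded block-table width here is just a ceiling division of the longest sequence by the KV page size. A small sketch with made-up numbers; PAGE_SIZE below is an assumed value, not the backend's real constant.

import math

PAGE_SIZE = 64                                        # assumed page size, for illustration only
max_seq_len = 1000                                    # e.g. forward_batch.seq_lens_cpu.max().item()
max_seqlen_pad = math.ceil(max_seq_len / PAGE_SIZE)   # what triton.cdiv computes
print(max_seqlen_pad)                                 # 16 pages per request in the block table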
python/sglang/srt/managers/schedule_batch.py (25 changes: 12 additions & 13 deletions)
@@ -1373,21 +1373,22 @@ def merge_batch(self, other: "ScheduleBatch"):

     def get_model_worker_batch(self) -> ModelWorkerBatch:
         if self.forward_mode.is_decode_or_idle():
-            if (
-                global_server_args_dict["enable_flashinfer_mla"]
-                or global_server_args_dict["enable_flashmla"]
-                or global_server_args_dict["attention_backend"] == "fa3"
-            ):
-                decode_seq_lens = self.seq_lens.cpu()
-            else:
-                decode_seq_lens = None
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
-            decode_seq_lens = None
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens

+        # Create seq_lens_cpu when needed
+        if (
+            global_server_args_dict["enable_flashinfer_mla"]
+            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
+        ):
+            seq_lens_cpu = self.seq_lens.cpu()
+        else:
+            seq_lens_cpu = None
+
         if self.sampling_info:
             if self.has_grammar:
                 self.sampling_info.grammars = [req.grammar for req in self.reqs]
@@ -1410,7 +1411,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
-            decode_seq_lens=decode_seq_lens,
+            seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1471,6 +1472,7 @@ class ModelWorkerBatch:
     req_pool_indices: torch.Tensor
     # The sequence length
     seq_lens: torch.Tensor
+    seq_lens_cpu: Optional[torch.Tensor]
     # The indices of output tokens in the token_to_kv_pool_allocator
     out_cache_loc: torch.Tensor

@@ -1487,9 +1489,6 @@ class ModelWorkerBatch:
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool

-    # For decode
-    decode_seq_lens: Optional[torch.Tensor]
-
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
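A toy dataclass mirroring the new field layout, to show how a consumer can fall back when the host copy was not requested; the class, names, and values below are illustrative, not the real ModelWorkerBatch.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class TinyWorkerBatch:
    seq_lens: torch.Tensor                          # device-side lengths, always present
    seq_lens_cpu: Optional[torch.Tensor] = None     # host copy, only built for MLA/FA3 backends

batch = TinyWorkerBatch(seq_lens=torch.tensor([4, 9, 2]))
lens_cpu = batch.seq_lens_cpu if batch.seq_lens_cpu is not None else batch.seq_lens.cpu()
print(lens_cpu.max().item())                        # 9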
python/sglang/srt/model_executor/cuda_graph_runner.py (4 changes: 2 additions & 2 deletions)
@@ -489,10 +489,10 @@ def replay_prepare(self, forward_batch: ForwardBatch):
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        if forward_batch.decode_seq_lens_cpu is not None:
+        if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
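The fill-then-copy above is the usual CUDA-graph padding trick: padded slots pretend to be length-1 requests and real requests overwrite the prefix of the static buffer. A sketch with invented sizes; the buffer names here are illustrative.

import torch

bs, raw_bs = 8, 5                                    # graph batch size vs. real batch size
seq_lens_cpu_buf = torch.empty(bs, dtype=torch.int64)
real_seq_lens_cpu = torch.tensor([12, 3, 44, 7, 20])
seq_lens_cpu_buf.fill_(1)                            # padded slots get length 1
seq_lens_cpu_buf[:raw_bs].copy_(real_seq_lens_cpu)   # real requests fill the prefix
print(seq_lens_cpu_buf)                              # tensor([12,  3, 44,  7, 20,  1,  1,  1])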
python/sglang/srt/model_executor/forward_batch_info.py (20 changes: 8 additions & 12 deletions)
@@ -39,7 +39,6 @@
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -148,6 +147,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # Optional seq_lens on cpu
+    seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -162,9 +164,6 @@
     # Position information
     positions: torch.Tensor = None

-    # For decode
-    decode_seq_lens_cpu: Optional[torch.Tensor] = None
-
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -293,12 +292,14 @@ def init_new(
         ):
             ret.positions = ret.spec_info.positions

+        # Get seq_lens_cpu if needed
+        if ret.seq_lens_cpu is None:
+            ret.seq_lens_cpu = batch.seq_lens_cpu
+
         # Init position information
         if ret.forward_mode.is_decode():
             if ret.positions is None:
-                ret.positions = clamp_position(batch.seq_lens)
-            if ret.decode_seq_lens_cpu is None:
-                ret.decode_seq_lens_cpu = batch.decode_seq_lens
+                ret.positions = torch.clamp((batch.seq_lens - 1), min=0).to(torch.int64)
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -518,8 +519,3 @@ def compute_position_torch(
     extend_start_loc = torch.zeros_like(extend_seq_lens)
     extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
     return positions.to(torch.int64), extend_start_loc
-
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def clamp_position(seq_lens):
-    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
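With the torch.compile'd clamp_position helper removed, the decode position rule is the inline expression shown in init_new above. In isolation it does the following; the lengths are toy values.

import torch

seq_lens = torch.tensor([1, 8, 0])                            # toy per-request lengths
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)  # position of the token being decoded
print(positions)                                              # tensor([0, 7, 0])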
python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py (10 changes: 5 additions & 5 deletions)
@@ -214,10 +214,10 @@ def replay(self, forward_batch: ForwardBatch):
         forward_batch.positions = self.positions[:num_tokens]

         # Special handle for seq_len_cpu used when flashinfer mla is used
-        if (forward_batch.decode_seq_lens_cpu is not None) and (bs != raw_bs):
+        if (forward_batch.seq_lens_cpu is not None) and (bs != raw_bs):
             self.seq_lens_cpu.fill_(1)
-            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.decode_seq_lens_cpu)
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:bs]
+            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:bs]

         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
             forward_batch, bs
@@ -233,7 +233,7 @@ def replay(self, forward_batch: ForwardBatch):
         forward_batch.positions = self.positions[:raw_num_token]
         forward_batch.seq_lens = self.seq_lens[:raw_bs]
         forward_batch.req_pool_indices = self.req_pool_indices[:raw_bs]
-        if forward_batch.decode_seq_lens_cpu is not None:
-            forward_batch.decode_seq_lens_cpu = self.seq_lens_cpu[:raw_bs]
+        if forward_batch.seq_lens_cpu is not None:
+            forward_batch.seq_lens_cpu = self.seq_lens_cpu[:raw_bs]

         return out