Skip to content
Closed
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions python/sglang/srt/layers/attention/flashattention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,9 +767,7 @@ def init_cuda_graph_state(self, max_bs: int):
"cu_seqlens_q": torch.arange(
0, max_bs + 1, dtype=torch.int32, device=self.device
),
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# cu_seqlens_k will be computed in capture_cuda_graph
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
Expand All @@ -789,9 +787,7 @@ def init_cuda_graph_state(self, max_bs: int):

self.target_verify_metadata = {
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
"cu_seqlens_q": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# cu_seqlens_q will be computed in capture_cuda_graph
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
Expand Down Expand Up @@ -958,11 +954,15 @@ def init_forward_metadata_replay_cuda_graph(
)
)

page_table = self.req_to_token[
req_pool_indices, : metadata.max_seq_len_k
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
:,
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]

metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
page_indices = page_indices[req_pool_indices] // self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
else:
# Normal Decode
max_len = seq_lens_cpu.max().item()
Expand All @@ -984,7 +984,6 @@ def init_forward_metadata_replay_cuda_graph(
]
page_indices //= self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
metadata.page_table[:, max_seq_pages:].fill_(0)

elif forward_mode.is_target_verify():
metadata = self.target_verify_metadata[bs]
Expand All @@ -1003,8 +1002,16 @@ def init_forward_metadata_replay_cuda_graph(
(1, 0),
)
)
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
page_indices = self.req_to_token[
:,
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
]
page_indices = page_indices[req_pool_indices] // self.page_size
metadata.page_table[:, :max_seq_pages].copy_(page_indices)

if encoder_lens is not None:
# Only support encoder size 1 for now
Expand Down
Loading