Merged
Changes from 1 commit
65 commits
4cc5501
Add Speculative Decoding Eagle3 topk > 1
qingquansong Apr 11, 2025
6d5247c
Merge remote-tracking branch 'upstream/main' into qsong/sdtopk
qingquansong Apr 12, 2025
2cabcde
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 12, 2025
a431024
Support Cuda Graph for Draft Decode when topk > 1
hebiao064 Apr 12, 2025
855755a
Support CUDA Graph for Target Verify
hebiao064 Apr 12, 2025
d29608a
set metadata expand
hebiao064 Apr 12, 2025
121021f
fix
hebiao064 Apr 13, 2025
2afa5fb
update cuda graph
qingquansong Apr 13, 2025
202867a
Fix problem which break normal path
hebiao064 Apr 13, 2025
e1eb605
switch to vllm merge state
qingquansong Apr 13, 2025
2122a0a
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 14, 2025
d6e9cc9
clean up
qingquansong Apr 14, 2025
02a1a15
support deepseek
hebiao064 Apr 14, 2025
3bf8e77
remove verify expand attention mask pad
qingquansong Apr 14, 2025
0bdb3b8
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
64f1c0f
switch to merge_state v2
qingquansong Apr 15, 2025
e73c9c4
update to triton
qingquansong Apr 15, 2025
e140f6d
add mode
qingquansong Apr 15, 2025
6381207
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
6c67661
rebase
qingquansong Apr 15, 2025
36668c2
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
0ec0422
remove comment
hebiao064 Apr 15, 2025
e45368a
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 15, 2025
e78a96f
cleanup
qingquansong Apr 15, 2025
57cec66
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 15, 2025
ab16678
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
f133545
fix
hebiao064 Apr 16, 2025
5f9a0d6
fix
hebiao064 Apr 16, 2025
336009f
fix test
hebiao064 Apr 16, 2025
beb981f
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 16, 2025
190dcfa
fix
hebiao064 Apr 17, 2025
0a2606a
Revert "fix"
hebiao064 Apr 17, 2025
1452022
Revert "fix test"
hebiao064 Apr 17, 2025
de0bb77
fix
hebiao064 Apr 17, 2025
43377ab
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 17, 2025
669ae0d
fix
hebiao064 Apr 17, 2025
d7ecff3
Merge remote-tracking branch 'upstream/main' into qsong/sdtopk
qingquansong Apr 17, 2025
deb083c
format
qingquansong Apr 17, 2025
30bfa5c
remove submodule
qingquansong Apr 17, 2025
039d5b8
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 17, 2025
48caaaf
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 17, 2025
e16c2e2
fix
hebiao064 Apr 18, 2025
5909701
fix rebase
qingquansong Apr 18, 2025
3fca5dd
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 18, 2025
92a9307
address comment about return_softmax_lse
hebiao064 Apr 18, 2025
595dd70
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 18, 2025
352a83e
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 18, 2025
506653a
fix
hebiao064 Apr 19, 2025
7a40ffe
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 19, 2025
68d43d2
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 19, 2025
ef01c2d
enable fa3 for broader use case
qingquansong Apr 19, 2025
cf7cb69
fix format and remove is_no_spec_infer_or_topk_one
hebiao064 Apr 19, 2025
67691f2
support page size > 1 for top k = 1
hebiao064 Apr 19, 2025
788051f
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 19, 2025
c1b6f87
fix
hebiao064 Apr 19, 2025
6fe3ec0
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 19, 2025
f648298
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 19, 2025
beee6a0
Update model_runner.py typo
hebiao064 Apr 19, 2025
cac34ec
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 20, 2025
6469e62
auto adjust draft_tokens = num_steps + 1
hebiao064 Apr 20, 2025
8af3ec1
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 20, 2025
67e0095
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 20, 2025
b966c6f
add test for top k > 1
hebiao064 Apr 20, 2025
9985da0
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 20, 2025
73c1868
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 21, 2025
188 changes: 119 additions & 69 deletions python/sglang/srt/layers/attention/flashattention_backend.py
@@ -20,9 +20,10 @@

from sglang.srt.utils import is_flashinfer_available

if is_flashinfer_available():
from flashinfer.cascade import merge_state
# if is_flashinfer_available():
# from flashinfer.cascade import merge_state

from vllm.v1.attention.backends.flash_attn import merge_attn_states

@dataclass
class FlashAttentionMetadata:
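
This commit swaps flashinfer's merge_state for vLLM's merge_attn_states when combining the two attention passes. As a point of reference only (not the vLLM kernel; the function name and shapes below are illustrative assumptions), a minimal PyTorch sketch of the log-sum-exp merge it performs:

# Minimal reference sketch of merging two partial attention results via their
# log-sum-exp (LSE) values; illustrative only, not the vLLM CUDA kernel.
import torch

def merge_attn_states_reference(o_a, lse_a, o_b, lse_b):
    # o_a, o_b:     [num_tokens, num_heads, head_dim] partial outputs
    # lse_a, lse_b: [num_tokens, num_heads] log-sum-exp of each softmax denominator
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)  # unnormalized weight of the first pass
    w_b = torch.exp(lse_b - max_lse)  # unnormalized weight of the second pass
    denom = w_a + w_b
    return (o_a * (w_a / denom).unsqueeze(-1)) + (o_b * (w_b / denom).unsqueeze(-1))
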
@@ -474,15 +475,21 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
]

metadata_expand = FlashAttentionMetadata()
metadata_expand.cache_seqlens_int32 = torch.full(
(
forward_batch.seq_lens.numel()
* self.speculative_num_draft_tokens,
),
self.speculative_num_draft_tokens,
# metadata_expand.cache_seqlens_int32 = torch.full(
# (
# forward_batch.seq_lens.numel()
# * self.speculative_num_draft_tokens,
# ),
# self.speculative_num_draft_tokens,
# device=device,
# dtype=torch.int32,
# )
metadata_expand.cache_seqlens_int32 = torch.arange(
1,
self.speculative_num_draft_tokens + 1,
device=device,
dtype=torch.int32,
)
).repeat(forward_batch.seq_lens.numel())
metadata_expand.max_seq_len_q = 1
metadata_expand.max_seq_len_k = self.speculative_num_draft_tokens
metadata_expand.cu_seqlens_q = torch.arange(
@@ -491,14 +498,20 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
dtype=torch.int32,
device=device,
)
metadata_expand.cu_seqlens_k = torch.arange(
0,
metadata_expand.cache_seqlens_int32.numel()
* self.speculative_num_draft_tokens
+ 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=device,
# metadata_expand.cu_seqlens_k = torch.arange(
# 0,
# metadata_expand.cache_seqlens_int32.numel()
# * self.speculative_num_draft_tokens
# + 1,
# step=self.speculative_num_draft_tokens,
# dtype=torch.int32,
# device=device,
# )
metadata.cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
),
(1, 0),
)
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
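
For the expand metadata built above, the per-row key lengths are now a repeated arange rather than a constant, so the key offsets come from a padded cumsum instead of a fixed stride. A small numeric sketch, assuming batch size 2 and speculative_num_draft_tokens = 3 (example values only):

# Illustrative values for the expand-path sequence-length tensors,
# assuming batch_size = 2 and speculative_num_draft_tokens = 3.
import torch

batch_size, num_draft_tokens = 2, 3

cache_seqlens = torch.arange(1, num_draft_tokens + 1, dtype=torch.int32).repeat(batch_size)
# tensor([1, 2, 3, 1, 2, 3], dtype=torch.int32) -> row i of each request has i + 1 keys

cu_seqlens_q = torch.arange(0, batch_size * num_draft_tokens + 1, dtype=torch.int32)
# tensor([0, 1, 2, 3, 4, 5, 6], dtype=torch.int32) -> one query token per expanded row

cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32), (1, 0)
)
# tensor([0, 1, 3, 6, 7, 9, 12], dtype=torch.int32) -> ragged offsets, no fixed stride
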
@@ -776,12 +789,20 @@ def forward_extend(
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
# o, _ = merge_state(
# o,
# softmax_lse.T.contiguous(),
# o_expand,
# softmax_lse_expand.T.contiguous(),
# )

# Run the kernel.
output = torch.empty_like(o)
merge_attn_states(output,
o, softmax_lse.contiguous(),
o_expand,
softmax_lse_expand.contiguous())
o = output
# torch.distributed.breakpoint()
else:
# Do absorbed multi-latent attention
@@ -932,12 +953,18 @@ def forward_decode(
v_descale=v_descale,
return_softmax_lse=True,
)
o, _ = merge_state(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
# o, _ = merge_state(
# o,
# softmax_lse.T.contiguous(),
# o_expand,
# softmax_lse_expand.T.contiguous(),
# )
output = torch.empty_like(o)
merge_attn_states(output,
o, softmax_lse.contiguous(),
o_expand,
softmax_lse_expand.contiguous())
o = output
# print("draft decode merge finish! id:", self.speculative_step_id)
# torch.distributed.breakpoint()
else:
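
Both forward_extend and forward_decode above now follow the same two-pass pattern: run attention once over the regular KV and once over the expanded draft-token KV, each returning its softmax LSE, then merge into a fresh buffer. A condensed sketch of that flow, where run_attention is a hypothetical stand-in for the flash-attention call with return_softmax_lse=True and only the merge step mirrors the diff:

# Condensed sketch of the two-pass attention + merge used above.
# run_attention is a hypothetical callable returning (output, softmax_lse).
import torch
from vllm.v1.attention.backends.flash_attn import merge_attn_states

def merged_attention(q, prefix_kv, draft_kv, run_attention):
    o, lse = run_attention(q, prefix_kv)               # pass 1: prefix KV cache
    o_expand, lse_expand = run_attention(q, draft_kv)  # pass 2: expanded draft-token KV
    output = torch.empty_like(o)
    merge_attn_states(output, o, lse.contiguous(), o_expand, lse_expand.contiguous())
    return output
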
@@ -1121,35 +1148,48 @@ def init_cuda_graph_state(self, max_bs: int):
}

self.target_verify_metadata_topk_expand = {
"cache_seqlens": torch.full(
(max_bs * self.speculative_num_draft_tokens,),
self.speculative_num_draft_tokens,
dtype=torch.int32,
# "cache_seqlens": torch.full(
# (max_bs * self.speculative_num_draft_tokens,),
# self.speculative_num_draft_tokens,
# dtype=torch.int32,
# device=self.device,
# ),
"cache_seqlens": torch.arange(
1,
self.speculative_num_draft_tokens + 1,
device=self.device,
),
dtype=torch.int32,
).repeat(max_bs),
"cu_seqlens_q": torch.arange(
0,
max_bs * self.speculative_num_draft_tokens + 1,
dtype=torch.int32,
device=self.device,
),
"cu_seqlens_k": torch.arange(
0,
max_bs
* self.speculative_num_draft_tokens
* self.speculative_num_draft_tokens
+ 1,
step=self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
# "cu_seqlens_k": torch.arange(
# 0,
# max_bs
# * self.speculative_num_draft_tokens
# * self.speculative_num_draft_tokens
# + 1,
# step=self.speculative_num_draft_tokens,
# dtype=torch.int32,
# device=self.device,
# ),
"page_table": torch.zeros(
max_bs * self.speculative_num_draft_tokens,
self.speculative_num_draft_tokens,
dtype=torch.int32,
device=self.device,
),
}
self.target_verify_metadata_topk_expand["cu_seqlens_k"] = torch.nn.functional.pad(
torch.cumsum(
self.target_verify_metadata_topk_expand["cache_seqlens"], dim=0, dtype=torch.int32
),
(1, 0),
)


# Biao's Note: Consolidate the encoder metadata
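
The buffers above are allocated once at max_bs so that CUDA-graph capture records fixed tensor addresses; at replay time the code only writes into leading slices with copy_() and fill_(), as seen in init_forward_metadata_replay_cuda_graph below. A small sketch of that preallocate-then-update-in-place pattern (buffer names and sizes here are illustrative):

# Sketch of the preallocate-then-copy pattern required by CUDA graphs:
# allocate at max_bs once, then update slices in place at replay time.
import torch

max_bs, num_draft_tokens = 8, 4
device = "cuda" if torch.cuda.is_available() else "cpu"

# Capture-time allocation (addresses stay fixed across replays).
cache_seqlens_buf = torch.arange(
    1, num_draft_tokens + 1, dtype=torch.int32, device=device
).repeat(max_bs)
page_table_buf = torch.zeros(
    max_bs * num_draft_tokens, num_draft_tokens, dtype=torch.int32, device=device
)

# Replay-time update for an actual batch of bs <= max_bs requests:
# write into the leading slice, zero the tail, never reallocate.
bs = 2
new_rows = bs * num_draft_tokens
new_page_table = torch.randint(
    0, 100, (new_rows, num_draft_tokens), dtype=torch.int32, device=device
)
page_table_buf[:new_rows].copy_(new_page_table)
page_table_buf[new_rows:].fill_(0)
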

@@ -1442,7 +1482,12 @@ def init_forward_metadata_replay_cuda_graph(
cache_loc[:, :decode_length].contiguous().to(torch.int32)
)
# may not need
metadata_expand.page_table[cache_loc.shape[0] :].fill_(0)
# metadata_expand.page_table[cache_loc.shape[0] :].fill_(0)
# metadata.page_table[cache_loc.shape[0] :].fill_(0)

# print("decode draft step id", self.speculative_step_id)
# print("metadata", metadata)
# print("metadata_expand", metadata_expand)
else:
metadata = self.decode_cuda_graph_metadata[bs]
# Normal Decode
@@ -1509,6 +1554,7 @@ def init_forward_metadata_replay_cuda_graph(
req_pool_indices, : metadata.max_seq_len_k
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# metadata.page_table[spec_info.positions.numel() :].fill_(0)

# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.target_verify_metadata_topk_expand[bs]
@@ -1555,31 +1601,35 @@ def init_forward_metadata_replay_cuda_graph(
(non_masked_page_table * mask).to(torch.int32)
)
# may not need
metadata_expand.page_table[spec_info.positions.numel() :].fill_(0)

if encoder_lens is not None:
# Only support encoder size 1 for now
metadata.encoder_max_seq_len_k = encoder_lens[0]
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
metadata.encoder_cu_seqlens_k.copy_(
torch.nn.functional.pad(
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
(1, 0),
)
)

metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
)
# metadata_expand.page_table[spec_info.positions.numel() :].fill_(0)

# print("target verify")
# print("metadata", metadata)
# print("metadata_expand", metadata_expand)

# Update the regular page table
page_table = self.req_to_token[
req_pool_indices,
metadata.encoder_max_seq_len_k : (
metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
),
]
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
# if encoder_lens is not None:
# # Only support encoder size 1 for now
# metadata.encoder_max_seq_len_k = encoder_lens[0]
# metadata.encoder_lens_int32.copy_(encoder_lens[:1])
# metadata.encoder_cu_seqlens_k.copy_(
# torch.nn.functional.pad(
# torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
# (1, 0),
# )
# )

# metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
# self.req_to_token[req_pool_indices, : metadata.encoder_max_seq_len_k]
# )

# # Update the regular page table
# page_table = self.req_to_token[
# req_pool_indices,
# metadata.encoder_max_seq_len_k : (
# metadata.encoder_max_seq_len_k + metadata.max_seq_len_k
# ),
# ]
# metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)

self.forward_metadata = metadata
self.forward_metadata_spec_decode_expand = metadata_expand
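
The expanded page table for the draft tokens is built by masking a per-row copy of the draft cache locations, which is why its cache_seqlens can be ragged. A small sketch for a single request, assuming a simple causal chain over the draft tokens purely for illustration (in the backend the mask comes from spec_info and can encode a tree):

# Illustrative expanded page table for one request with 3 draft tokens,
# assuming a causal chain; the real mask comes from spec_info (may be a tree).
import torch

num_draft_tokens = 3
draft_cache_locs = torch.tensor([101, 102, 103], dtype=torch.int32)

# One expanded row per draft token, all starting from the same locations.
non_masked_page_table = draft_cache_locs.expand(num_draft_tokens, -1)

# Chain assumption: draft token i attends to draft tokens 0..i.
mask = torch.tril(torch.ones(num_draft_tokens, num_draft_tokens, dtype=torch.bool))

page_table_expand = (non_masked_page_table * mask).to(torch.int32)
# tensor([[101,   0,   0],
#         [101, 102,   0],
#         [101, 102, 103]], dtype=torch.int32)

cache_seqlens_expand = mask.sum(dim=1).to(torch.int32)
# tensor([1, 2, 3], dtype=torch.int32) -> matches torch.arange(1, num_draft_tokens + 1)
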