Merged

65 commits
4cc5501
Add Speculative Decoding Eagle3 topk > 1
qingquansong Apr 11, 2025
6d5247c
Merge remote-tracking branch 'upstream/main' into qsong/sdtopk
qingquansong Apr 12, 2025
2cabcde
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 12, 2025
a431024
Support Cuda Graph for Draft Decode when topk > 1
hebiao064 Apr 12, 2025
855755a
Support CUDA Graph for Target Verfy
hebiao064 Apr 12, 2025
d29608a
set metadata expand
hebiao064 Apr 12, 2025
121021f
fix
hebiao064 Apr 13, 2025
2afa5fb
update cuda graph
qingquansong Apr 13, 2025
202867a
Fix problem which break normal path
hebiao064 Apr 13, 2025
e1eb605
switch to vllm merge state
qingquansong Apr 13, 2025
2122a0a
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 14, 2025
d6e9cc9
clean up
qingquansong Apr 14, 2025
02a1a15
support deepseek
hebiao064 Apr 14, 2025
3bf8e77
remove verify expand attention mask pad
qingquansong Apr 14, 2025
0bdb3b8
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
64f1c0f
switch to merge_state v2
qingquansong Apr 15, 2025
e73c9c4
update to triton
qingquansong Apr 15, 2025
e140f6d
addd mode
qingquansong Apr 15, 2025
6381207
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
6c67661
rebase
qingquansong Apr 15, 2025
36668c2
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
0ec0422
remove comment
hebiao064 Apr 15, 2025
e45368a
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 15, 2025
e78a96f
cleanup
qingquansong Apr 15, 2025
57cec66
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 15, 2025
ab16678
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 15, 2025
f133545
fix
hebiao064 Apr 16, 2025
5f9a0d6
fix
hebiao064 Apr 16, 2025
336009f
fix test
hebiao064 Apr 16, 2025
beb981f
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 16, 2025
190dcfa
fix
hebiao064 Apr 17, 2025
0a2606a
Revert "fix"
hebiao064 Apr 17, 2025
1452022
Revert "fix test"
hebiao064 Apr 17, 2025
de0bb77
fix
hebiao064 Apr 17, 2025
43377ab
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 17, 2025
669ae0d
fix
hebiao064 Apr 17, 2025
d7ecff3
Merge remote-tracking branch 'upstream/main' into qsong/sdtopk
qingquansong Apr 17, 2025
deb083c
format
qingquansong Apr 17, 2025
30bfa5c
remove submodule
qingquansong Apr 17, 2025
039d5b8
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 17, 2025
48caaaf
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 17, 2025
e16c2e2
fix
hebiao064 Apr 18, 2025
5909701
fix rebase
qingquansong Apr 18, 2025
3fca5dd
Merge branch 'main' into qsong/sdtopk
hebiao064 Apr 18, 2025
92a9307
address comment about return_softmax_lse
hebiao064 Apr 18, 2025
595dd70
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 18, 2025
352a83e
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 18, 2025
506653a
fix
hebiao064 Apr 19, 2025
7a40ffe
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 19, 2025
68d43d2
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 19, 2025
ef01c2d
enable fa3 for broader use case
qingquansong Apr 19, 2025
cf7cb69
fix format and remove is_no_spec_infer_or_topk_one
hebiao064 Apr 19, 2025
67691f2
support page size > 1 for top k = 1
hebiao064 Apr 19, 2025
788051f
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 19, 2025
c1b6f87
fix
hebiao064 Apr 19, 2025
6fe3ec0
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 19, 2025
f648298
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 19, 2025
beee6a0
Update model_runner.py typo
hebiao064 Apr 19, 2025
cac34ec
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 20, 2025
6469e62
auto adjust draft_tokens = num_steps + 1
hebiao064 Apr 20, 2025
8af3ec1
Merge branch 'qsong/sdtopk' of https://github.com/qingquansong/sglang…
hebiao064 Apr 20, 2025
67e0095
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 20, 2025
b966c6f
add test for top k > 1
hebiao064 Apr 20, 2025
9985da0
Merge branch 'main' into qsong/sdtopk
zhyncs Apr 20, 2025
73c1868
Merge branch 'main' into qsong/sdtopk
qingquansong Apr 21, 2025
931 changes: 781 additions & 150 deletions python/sglang/srt/layers/attention/flashattention_backend.py

Large diffs are not rendered by default.
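The bulk of the PR lives in this unrendered file. From the commit log ("switch to vllm merge state", "switch to merge_state v2", "address comment about return_softmax_lse"), the backend merges partial attention results. As background only, here is a minimal sketch of the standard LSE-weighted merge of two attention states computed over disjoint key sets — not the PR's actual code:

```python
import torch


def merge_attention_states(o1, lse1, o2, lse2):
    """Merge two partial attention outputs over disjoint key sets.

    o1, o2: [..., d] partial outputs; lse1, lse2: [...] log-sum-exp of the
    corresponding attention logits. Standard technique, shown for background.
    """
    # Subtract the running max before exponentiating for numerical stability.
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    # Each partial output is weighted by its share of the total softmax mass.
    o = (o1 * w1.unsqueeze(-1) + o2 * w2.unsqueeze(-1)) / denom.unsqueeze(-1)
    lse = max_lse + torch.log(denom)
    return o, lse
```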

15 changes: 11 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -221,7 +221,16 @@ def model_specific_adjustment(self):
server_args = self.server_args

if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
"""
We auto select the fastest attention backend according to the current offering
1. Models with MHA Architecture (e.g: Llama, QWen)
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
2. Models with MLA Architecture and using FA3
2.1 We will use FA3 backend on hopper.
2.2 Otherwise, we will use triton backend.
"""

if not self.use_mla_backend:
if (
is_hopper_with_cuda_12_3()
@@ -234,9 +243,7 @@ def model_specific_adjustment(self):
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
server_args
):
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
else:
server_args.attention_backend = "triton"
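To make the new default-selection policy concrete, here is a minimal, self-contained sketch of the logic the docstring describes. The helper names mirror the diff; the simplified ServerArgs and the explicit boolean parameters are illustrative assumptions, not the real signatures.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ServerArgs:
    # Illustrative stand-in for sglang's ServerArgs; only the fields
    # relevant to backend selection are modeled here.
    attention_backend: Optional[str] = None
    speculative_eagle_topk: Optional[int] = None
    page_size: int = 1


def select_attention_backend(
    args: ServerArgs, use_mla_backend: bool, on_hopper: bool, flashinfer_ok: bool
) -> str:
    if not use_mla_backend:
        # MHA models: FA3 on Hopper, unless spec decode with topk > 1
        # or page_size > 1 is in play (see the TODOs in utils.py below).
        topk = args.speculative_eagle_topk
        if on_hopper and (topk is None or topk == 1) and args.page_size == 1:
            return "fa3"
        return "flashinfer" if flashinfer_ok else "triton"
    # MLA models: FA3 on Hopper, triton everywhere else.
    return "fa3" if on_hopper else "triton"


# MHA Llama on Hopper with no spec decode -> "fa3"
print(select_attention_backend(ServerArgs(), False, True, True))
```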
13 changes: 12 additions & 1 deletion python/sglang/srt/server_args.py
@@ -358,7 +358,18 @@ def __post_init__(self):

if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
logger.info(
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
)

if (
self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
):
logger.info(
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
)
self.speculative_num_draft_tokens = self.speculative_num_steps + 1

# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
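The two adjustments above compose: when page_size > 1 forces topk back to 1, the second branch then re-pins the draft-token budget. A runnable sketch with the same field names and messages (the standalone SpecArgs dataclass and its defaults are illustrative, not the real ServerArgs):

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class SpecArgs:
    page_size: int = 1
    speculative_eagle_topk: int = 1
    speculative_num_steps: int = 5
    speculative_num_draft_tokens: int = 8

    def __post_init__(self):
        # Paged KV cache with page_size > 1 does not yet support topk > 1.
        if self.page_size > 1 and self.speculative_eagle_topk > 1:
            self.speculative_eagle_topk = 1
            logger.info("speculative_eagle_topk is adjusted to 1 when page_size > 1")

        # With topk == 1 the draft chain is linear, so exactly
        # num_steps + 1 tokens (drafts plus the verify token) are useful.
        if (
            self.speculative_eagle_topk == 1
            and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
        ):
            logger.info(
                "speculative_num_draft_tokens is adjusted to "
                "speculative_num_steps + 1 when speculative_eagle_topk == 1"
            )
            self.speculative_num_draft_tokens = self.speculative_num_steps + 1


args = SpecArgs(page_size=4, speculative_eagle_topk=4)
assert args.speculative_eagle_topk == 1
assert args.speculative_num_draft_tokens == 6  # num_steps (5) + 1
```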
2 changes: 2 additions & 0 deletions python/sglang/srt/utils.py
@@ -1913,6 +1913,8 @@ def is_page_size_one(server_args):
return server_args.page_size == 1


# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
return server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None
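The diff truncates the body of is_no_spec_infer_or_topk_one. A hedged reconstruction from its name, the is_page_size_one helper above it, and the selection policy in model_runner.py — treat the exact body as an assumption:

```python
def is_no_spec_infer_or_topk_one(server_args) -> bool:
    # True when speculative decoding is off, or when it runs with
    # topk == 1 and page_size == 1 -- per the docstring in
    # model_specific_adjustment(), the configurations where FA3 is
    # the default for MHA models.
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk == 1
        and server_args.page_size == 1
    )
```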
2 changes: 1 addition & 1 deletion test/srt/run_suite.py
@@ -29,7 +29,7 @@ class TestFile:
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fa3.py", 5),
TestFile("test_fa3.py", 200),
TestFile("test_fp8_kernel.py", 8),
TestFile("test_embedding_openai_server.py", 36),
TestFile("test_hidden_states.py", 55),
54 changes: 54 additions & 0 deletions test/srt/test_fa3.py
@@ -173,6 +173,60 @@ def test_gsm8k(self):
self.assertGreater(avg_spec_accept_length, 1.5)


class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""

model = "meta-llama/Llama-3.1-8B-Instruct"

@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE3",
"--speculative-draft",
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"8",
"--dtype",
"float16",
]
)
return args

def test_gsm8k(self):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")

args = SimpleNamespace(
num_shots=5,
data_path=DATA_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)

self.assertGreater(metrics["accuracy"], 0.60)

server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.8)


class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled."""

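Outside the unittest harness, the same acceptance-length check the new test performs can be run against a live server. A minimal sketch, assuming a server already launched with the EAGLE3 flags from get_server_args() above (the port is a placeholder):

```python
import requests

base_url = "http://127.0.0.1:30000"  # placeholder port
requests.get(f"{base_url}/flush_cache")
info = requests.get(f"{base_url}/get_server_info").json()
print(info["avg_spec_accept_length"])  # the test above asserts > 1.8
```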