Skip to content
Merged
Changes from 19 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4e2b3f4
fix flashmla bug
sleepcoo Apr 11, 2025
718e565
Merge branch 'main' into fix_mla_head
sleepcoo Apr 11, 2025
42a5015
Merge branch 'main' into fix_mla_head
zhyncs Apr 13, 2025
69df8a4
Merge branch 'main' into fix_mla_head
sleepcoo Apr 14, 2025
dc5841f
Merge branch 'main' into fix_mla_head
sleepcoo Apr 14, 2025
9ccb3a4
Merge branch 'main' into fix_mla_head
Fridge003 Apr 17, 2025
e12fe56
Description of your changes
PopSoda2002 Apr 19, 2025
411f3a8
Add test
PopSoda2002 Apr 19, 2025
08f8940
Add FlashMLA backend unit tests
PopSoda2002 Apr 19, 2025
2c64672
Add FlashMLA backend unit tests
PopSoda2002 Apr 19, 2025
01d2a3d
Merge remote-tracking branch 'sleepcoo/fix_mla_head' into feature/test
PopSoda2002 Apr 19, 2025
97b8cb1
Merge
PopSoda2002 Apr 21, 2025
2b31408
using dpsk v2 lite model
PopSoda2002 Apr 23, 2025
f2fee49
Description of your changes
PopSoda2002 Apr 19, 2025
fe9f706
Add test
PopSoda2002 Apr 19, 2025
37d97d8
Add FlashMLA backend unit tests
PopSoda2002 Apr 19, 2025
59a90a6
Add FlashMLA backend unit tests
PopSoda2002 Apr 19, 2025
afb3ff8
using dpsk v2 lite model
PopSoda2002 Apr 23, 2025
aaa5856
using dpsk v2 lite model
PopSoda2002 Apr 23, 2025
dd54edd
Update flash mla attention backend test
PopSoda2002 Apr 30, 2025
5cafa02
Add test in CI
PopSoda2002 Apr 30, 2025
fe62c71
Delete gsm8k
PopSoda2002 Apr 30, 2025
05fb284
resolve conflicts
PopSoda2002 Apr 30, 2025
9a39286
Add CI
PopSoda2002 Apr 30, 2025
27bb705
Fix linting
PopSoda2002 Apr 30, 2025
d1c7bb0
fix
PopSoda2002 Apr 30, 2025
e548b6b
Merge branch 'main' into feature/test
Fridge003 Apr 30, 2025
cce7948
Merge branch 'main' into feature/test
Fridge003 May 1, 2025
11c3d2e
Add FlashMLA dependency to CMakeLists.txt
PopSoda2002 May 3, 2025
4b94d7b
Add FlashMLA to CI installation script
PopSoda2002 May 3, 2025
7e470ea
Remove kernel flashmla
PopSoda2002 May 3, 2025
089e1b6
Remove kernel flashmla
PopSoda2002 May 3, 2025
6542be8
Merge branch 'main' into feature/test
PopSoda2002 May 3, 2025
dbdece2
Merge branch 'main' into feature/test
PopSoda2002 May 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions test/srt/test_flash_mla_attention_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""
Usage:
python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_mmlu
"""

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)

# Model checkpoint used by every test in this module.
# NOTE: despite the "DSV3" prefix in the constant's name, the value below is
# the DeepSeek-V2-Lite checkpoint, not DeepSeek V3.
DSV3_MODEL_FOR_TEST = "deepseek-ai/DeepSeek-V2-Lite"


class TestFlashMLAAttnBackend(unittest.TestCase):
    """Latency and accuracy checks for the FlashMLA attention backend.

    Every test passes ``--attention-backend flashmla`` so that the backend
    named by this class is the one actually exercised.
    """

    # Value handed to ``--attention-backend`` by every test in this class.
    ATTENTION_BACKEND = "flashmla"

    def _run_eval_on_server(self, eval_name):
        """Launch a server with the FlashMLA backend and run one evaluation.

        Args:
            eval_name: Value forwarded to ``run_eval`` as ``args.eval_name``.

        Returns:
            Whatever ``run_eval`` returns; callers read ``metrics["score"]``.

        The server process tree is always killed, even when the evaluation
        raises.
        """
        model = DSV3_MODEL_FOR_TEST
        base_url = DEFAULT_URL_FOR_TEST
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--attention-backend",
                self.ATTENTION_BACKEND,
                "--trust-remote-code",
            ],
        )

        try:
            args = SimpleNamespace(
                base_url=base_url,
                model=model,
                eval_name=eval_name,
                num_examples=64,
                num_threads=32,
            )
            return run_eval(args)
        finally:
            kill_process_tree(process.pid)

    def test_latency(self):
        """Benchmark one batch and, in CI only, enforce a throughput floor."""
        output_throughput = run_bench_one_batch(
            DSV3_MODEL_FOR_TEST,
            [
                "--attention-backend",
                # Previously "flashinfer", which benchmarked a different
                # backend than the one this test class is named for.
                self.ATTENTION_BACKEND,
                "--enable-torch-compile",
                "--cuda-graph-max-bs",
                "16",
                "--trust-remote-code",
            ],
        )

        if is_in_ci():
            # NOTE(review): 153 is carried over unchanged from the earlier
            # version of this test; confirm it is calibrated for flashmla.
            self.assertGreater(output_throughput, 153)

    def test_mmlu(self):
        """MMLU score on 64 examples must be at least 0.2."""
        metrics = self._run_eval_on_server("mmlu")
        self.assertGreaterEqual(metrics["score"], 0.2)

    def test_gsm8k(self):
        """GSM8K score on 64 examples must be at least 0.97."""
        metrics = self._run_eval_on_server("gsm8k")
        # Higher threshold based on DSV3 GSM8K score from PR.
        # NOTE(review): DSV3_MODEL_FOR_TEST currently names DeepSeek-V2-Lite,
        # so this threshold may not match the model under test, and "gsm8k"
        # should be confirmed as an eval name accepted by run_eval.
        self.assertGreaterEqual(metrics["score"], 0.97)


# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Loading