diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c3f494a3ce7..7a5c4310b94 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -51,7 +51,7 @@ class TestFile: TestFile("test_mla_int8_deepseek_v3.py", 389), TestFile("test_mla_flashinfer.py", 395), TestFile("test_mla_fp8.py", 153), - TestFile("test_flash_mla_attention_backend.py", 300), + TestFile("test_flashmla.py", 300), TestFile("test_no_chunked_prefill.py", 108), TestFile("test_no_overlap_scheduler.py", 216), TestFile("test_openai_server.py", 149), diff --git a/test/srt/test_flash_mla_attention_backend.py b/test/srt/test_flash_mla_attention_backend.py deleted file mode 100644 index 8d895d2ebaf..00000000000 --- a/test/srt/test_flash_mla_attention_backend.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Usage: -python3 -m unittest test_flash_mla_attention_backend.TestFlashMLAAttnBackend.test_mmlu -""" - -import unittest -from types import SimpleNamespace - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - is_in_ci, - popen_launch_server, - run_bench_one_batch, -) - - -class TestFlashMLAAttnBackend(unittest.TestCase): - def test_latency(self): - output_throughput = run_bench_one_batch( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - [ - "--attention-backend", - "flashmla", - "--enable-torch-compile", - "--cuda-graph-max-bs", - "16", - "--trust-remote-code", - ], - ) - - if is_in_ci(): - self.assertGreater(output_throughput, 153) - - def test_mmlu(self): - model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_TEST - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--attention-backend", "flashmla", "--trust-remote-code"], - ) - - try: - args = SimpleNamespace( - base_url=base_url, - model=model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.2) - finally: - kill_process_tree(process.pid) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/test_flashmla.py b/test/srt/test_flashmla.py new file mode 100644 index 00000000000..f546322a751 --- /dev/null +++ b/test/srt/test_flashmla.py @@ -0,0 +1,86 @@ +""" +Usage: +python3 test/srt/test_flashmla.py +""" + +import os +import unittest +from types import SimpleNamespace + +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + is_in_ci, + popen_launch_server, + run_bench_one_batch, +) + + +class TestFlashMLAAttnBackend(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = ["--trust-remote-code"] + if torch.cuda.is_available() and torch.version.cuda: + other_args.extend( + [ + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + "--attention-backend", + "flashmla", + ] + ) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + +class TestFlashMLAAttnLatency(unittest.TestCase): + def test_latency(self): + output_throughput = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + [ + "--attention-backend", + "flashmla", + "--enable-torch-compile", + "--cuda-graph-max-bs", + "16", + "--trust-remote-code", + ], + ) + + if is_in_ci(): + self.assertGreater(output_throughput, 100) + + +if __name__ == "__main__": + unittest.main()