diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 591e3ca604a0..fb99878f818b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -169,7 +169,7 @@ ], "per-commit-4-gpu-b200": [ TestFile("test_deepseek_v3_fp4_4gpu.py", 1800), - TestFile("test_flash_attention_4.py", 300), + TestFile("test_flash_attention_4.py", 90), TestFile("test_fp8_blockwise_gemm.py", 280), TestFile("test_gpt_oss_4gpu.py", 600), TestFile("test_llama31_fp4.py", 300), diff --git a/test/srt/test_flash_attention_4.py b/test/srt/test_flash_attention_4.py index 44623a132c3a..9d81ccd85b49 100644 --- a/test/srt/test_flash_attention_4.py +++ b/test/srt/test_flash_attention_4.py @@ -1,5 +1,6 @@ import unittest from types import SimpleNamespace +from urllib.parse import urlparse from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k @@ -18,8 +19,6 @@ def setUpClass(cls): cls.base_url = DEFAULT_URL_FOR_TEST other_args = [ "--trust-remote-code", - "--mem-fraction-static", - "0.8", "--prefill-attention-backend", "fa4", "--decode-attention-backend", @@ -37,19 +36,20 @@ def tearDownClass(cls): kill_process_tree(cls.process.pid) def test_gsm8k(self): + parsed_url = urlparse(self.base_url) args = SimpleNamespace( - num_shots=4, + num_shots=5, data_path=None, - num_questions=100, + num_questions=1319, max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), + parallel=200, + host=f"{parsed_url.scheme}://{parsed_url.hostname}", + port=parsed_url.port, ) metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.75) + self.assertGreater(metrics["accuracy"], 0.89) if __name__ == "__main__":