From 36ca90e7989afbfd22d0504c043764d75c0f753d Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Thu, 17 Apr 2025 18:39:40 +0000 Subject: [PATCH 01/17] Add Llama 4 to FA3 test --- python/sglang/test/test_utils.py | 3 +++ test/srt/test_fa3.py | 35 ++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9db2b9a7c3e..13e159a794e 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -45,6 +45,9 @@ DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" +DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST = ( + "meta-llama/Llama-4-Scout-17B-16E-Instruct" +) DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B" DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 90b462aaa38..4313aa5620d 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -6,6 +6,9 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST, # Llama 4 +) from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -119,6 +122,38 @@ def get_server_args(cls): return args +class TestFlashAttention3MLA(BaseFlashAttentionTest): + """Test FlashAttention3 with Model with local attention, e.g. Llama 4.""" + + model = DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend(["--cuda-graph-max-bs", "2", "--tp", "4"]) + return args + + def test_gsm8k(self): + """ + Override the test_gsm8k to further test for average speculative accept length. + """ + requests.get(self.base_url + "/flush_cache") + + args = SimpleNamespace( + num_shots=5, + data_path=DATA_PATH, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.80) + + class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): """Test FlashAttention3 with speculative decode enabled.""" From 470c00ea04f56fde412f9c0714362baea08b9a4d Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Thu, 17 Apr 2025 18:41:34 +0000 Subject: [PATCH 02/17] fix --- test/srt/test_fa3.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 4313aa5620d..34de97f8157 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -7,9 +7,7 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( - DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST, # Llama 4 -) -from sglang.test.test_utils import ( + DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, From ea771dce733af69b5708c8fb89a3918c2bb05b4c Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Thu, 17 Apr 2025 20:06:38 +0000 Subject: [PATCH 03/17] fix --- test/srt/test_fa3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 34de97f8157..edb54c90d6f 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -149,7 +149,7 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreater(metrics["accuracy"], 0.80) + self.assertGreater(metrics["accuracy"], 0.70) class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): From 56b7726e540824c30b5daffff0483be65eb5d2e0 Mon Sep 17 00:00:00 2001 From: Stefan He Date: Thu, 17 Apr 2025 14:05:42 -0700 Subject: [PATCH 04/17] Rename test class for local attention --- test/srt/test_fa3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index edb54c90d6f..c1c6e275b3e 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -120,7 +120,7 @@ def get_server_args(cls): return args -class TestFlashAttention3MLA(BaseFlashAttentionTest): +class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): """Test FlashAttention3 with Model with local attention, e.g. Llama 4.""" model = DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST From 1d6e8823e05dd4ec455b47d46be255840db5d360 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Sat, 19 Apr 2025 00:46:59 +0000 Subject: [PATCH 05/17] delete cuda graph disabled test --- .github/workflows/pr-test.yml | 20 ++++++++++++++++++++ test/srt/run_suite.py | 4 +++- test/srt/test_fa3.py | 14 -------------- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index ee0b10c0ec9..842321262e4 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -93,6 +93,26 @@ jobs: cd test/srt python3 run_suite.py --suite per-commit-2-gpu + unit-test-backend-8-gpu: + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && + github.event.pull_request.draft == false + runs-on: 8-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + env: + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }} + run: | + bash scripts/ci_install_dependency.sh + + - name: Run test + timeout-minutes: 30 + run: | + cd test/srt + python3 run_suite.py --suite per-commit-8-gpu + performance-test-1-gpu-part-1: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3c94b2ba365..c2f41a5f60a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -28,7 +28,6 @@ class TestFile: TestFile("test_chunked_prefill.py", 336), TestFile("test_eagle_infer.py", 500), TestFile("test_ebnf_constrained.py"), - TestFile("test_fa3.py", 5), TestFile("test_fp8_kernel.py", 8), TestFile("test_embedding_openai_server.py", 36), TestFile("test_hidden_states.py", 55), @@ -89,6 +88,9 @@ class TestFile: TestFile("test_update_weights_from_distributed.py", 100), TestFile("test_verl_engine.py", 100), ], + "per-commit-8-gpu": [ + TestFile("test_fa3.py", 30), + ], "nightly": [ TestFile("test_nightly_gsm8k_eval.py"), ], diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index c1c6e275b3e..f91bc951332 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -89,20 +89,6 @@ def get_server_args(cls): return args -class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest): - """Test FlashAttention3 with CUDA graph disabled.""" - - @classmethod - def get_server_args(cls): - args = super().get_server_args() - args.extend( - [ - "--disable-cuda-graph", - ] - ) - return args - - class TestFlashAttention3MLA(BaseFlashAttentionTest): """Test FlashAttention3 with MLA.""" From edb689bee1fb1a895292be8c908bbf94b594932a Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Sat, 19 Apr 2025 04:00:47 +0000 Subject: [PATCH 06/17] fix --- python/sglang/test/test_utils.py | 5 +- test/srt/test_fa3.py | 205 ++++++++++++------------------- 2 files changed, 84 insertions(+), 126 deletions(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 13e159a794e..d64ec02c5dd 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -44,8 +44,11 @@ ) DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B" +DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test" +DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN" DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" -DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST = ( +DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = ( "meta-llama/Llama-4-Scout-17B-16E-Instruct" ) DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index f91bc951332..11c948678e1 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -6,42 +6,79 @@ from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST # llama 3.1 8B +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, # eagle3 llama 3.1 8B +) +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION, # llama 4 scout 17B 16E +) +from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST_MLA # deepseek v3 test +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, # deepseek v3 test nextN +) from sglang.test.test_utils import ( - DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST, - DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) +GSM_DATASET_PATH = None + +# In case of some machine lack internet connection, we can set OFFLINE_MODE to True. +OFFLINE_MODE = True + +# Change the path below when OFFLINE_MODE is True. +OFFLINE_PATH_DICT = { + DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B/e5ed08d66f528a95ce89f5d4fd136a28f6def714", + DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/bhe/deepseek/dsv3-test/snapshots/fed995305b0e4f9acacb74bc0c17787a966bf42a/", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/bhe/deepseek/dsv3-test-NextN/snapshots/981e68c48968bacb228b645b2b8ddf69f98ee5a6/", + GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl", +} + + +if OFFLINE_MODE: + DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST] + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[ + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 + ] + DEFAULT_MODEL_NAME_FOR_TEST_MLA = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA] + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = OFFLINE_PATH_DICT[ + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN + ] + GSM_DATASET_PATH = OFFLINE_PATH_DICT[GSM_DATASET_PATH] + + +# Default server arguments shared across all tests +DEFAULT_SERVER_ARGS = [ + "--trust-remote-code", + "--enable-torch-compile", + "--cuda-graph-max-bs", + "2", + "--attention-backend", + "fa3", +] + """ Integration test for python/sglang/srt/layers/attention/flashattention_backend.py """ -# Change to your own model if testing model is not public. -MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST -MODEL_USED_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test" -# Setting data path to None uses default data path in few_shot_gsm8k eval test. -DATA_PATH = None @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") class BaseFlashAttentionTest(unittest.TestCase): - """Base class for FlashAttention tests to reduce code duplication.""" + """Test FlashAttention3 with MHA model, also it will be base class for derived tests.""" - model = MODEL_USED_FOR_TEST + model = DEFAULT_MODEL_NAME_FOR_TEST base_url = DEFAULT_URL_FOR_TEST - accuracy_threshold = 0.62 + accuracy_threshold = 0.65 # derived tests need to override this + speculative_decode = False + spec_decode_threshold = 1.0 # derived spec decoding tests need to override this @classmethod def get_server_args(cls): """Return the arguments for the server launch. Override in subclasses.""" - args = [ - "--trust-remote-code", - "--enable-torch-compile", - "--attention-backend", - "fa3", - ] - return args + return DEFAULT_SERVER_ARGS @classmethod def setUpClass(cls): @@ -58,13 +95,13 @@ def tearDownClass(cls): def test_gsm8k(self): args = SimpleNamespace( - num_shots=5, - num_questions=200, + num_shots=4, + num_questions=100, max_new_tokens=512, parallel=128, host="http://127.0.0.1", port=int(self.base_url.split(":")[-1]), - data_path=DATA_PATH, + data_path=GSM_DATASET_PATH, ) metrics = run_eval_few_shot_gsm8k(args) print(metrics) @@ -73,75 +110,40 @@ def test_gsm8k(self): metric_key = "accuracy" self.assertGreater(metrics[metric_key], self.accuracy_threshold) - -class TestFlashAttention3(BaseFlashAttentionTest): - """Test FlashAttention3 with MLA model and CUDA graph enabled.""" - - @classmethod - def get_server_args(cls): - args = super().get_server_args() - args.extend( - [ - "--cuda-graph-max-bs", - "2", - ] - ) - return args + if self.speculative_decode: + server_info = requests.get(self.base_url + "/get_server_info") + avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold) class TestFlashAttention3MLA(BaseFlashAttentionTest): - """Test FlashAttention3 with MLA.""" + """Test FlashAttention3 with MLA, e.g. deepseek v3 test model""" - model = MODEL_USED_FOR_TEST_MLA - - @classmethod - def get_server_args(cls): - args = super().get_server_args() - args.extend( - [ - "--cuda-graph-max-bs", - "2", - ] - ) - return args + accuracy_threshold = 0.60 + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): """Test FlashAttention3 with Model with local attention, e.g. Llama 4.""" - model = DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST + self.accuracy_threshold = 0.70 + self.model = DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST @classmethod def get_server_args(cls): args = super().get_server_args() - args.extend(["--cuda-graph-max-bs", "2", "--tp", "4"]) + args.extend(["--tp", "4"]) return args - def test_gsm8k(self): - """ - Override the test_gsm8k to further test for average speculative accept length. - """ - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=DATA_PATH, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.70) - class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): - """Test FlashAttention3 with speculative decode enabled.""" + """Test FlashAttention3 with speculative decode enabled with Llama 3.1 8B and its eagle3 model""" - model = "meta-llama/Llama-3.1-8B-Instruct" + model = DEFAULT_MODEL_NAME_FOR_TEST + accuracy_threshold = 0.65 + speculative_decode = True + spec_decode_threshold = 1.5 @classmethod def get_server_args(cls): @@ -153,7 +155,7 @@ def get_server_args(cls): "--speculative-algorithm", "EAGLE3", "--speculative-draft", - "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, "--speculative-num-steps", "3", "--speculative-eagle-topk", @@ -166,36 +168,14 @@ def get_server_args(cls): ) return args - def test_gsm8k(self): - """ - Override the test_gsm8k to further test for average speculative accept length. - """ - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=DATA_PATH, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.5) - class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): - """Test FlashAttention3 with speculative decode enabled.""" + """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model""" - model = MODEL_USED_FOR_TEST_MLA + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + accuracy_threshold = 0.60 + speculative_decode = True + spec_decode_threshold = 1.5 @classmethod def get_server_args(cls): @@ -207,7 +187,7 @@ def get_server_args(cls): "--speculative-algorithm", "EAGLE", "--speculative-draft", - "lmsys/sglang-ci-dsv3-test-NextN", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, "--speculative-num-steps", "3", "--speculative-eagle-topk", @@ -218,31 +198,6 @@ def get_server_args(cls): ) return args - def test_gsm8k(self): - """ - Override the test_gsm8k to further test for average speculative accept length. - """ - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=DATA_PATH, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.5) - if __name__ == "__main__": unittest.main() From 0141af63024ec7b253f88ab32efe129b342e33ac Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Sat, 19 Apr 2025 05:10:13 +0000 Subject: [PATCH 07/17] fix --- test/srt/test_fa3.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 11c948678e1..b7c317afaf2 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -26,14 +26,14 @@ GSM_DATASET_PATH = None # In case of some machine lack internet connection, we can set OFFLINE_MODE to True. -OFFLINE_MODE = True +OFFLINE_MODE = False # Change the path below when OFFLINE_MODE is True. OFFLINE_PATH_DICT = { - DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693", - DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B/e5ed08d66f528a95ce89f5d4fd136a28f6def714", - DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/bhe/deepseek/dsv3-test/snapshots/fed995305b0e4f9acacb74bc0c17787a966bf42a/", - DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/bhe/deepseek/dsv3-test-NextN/snapshots/981e68c48968bacb228b645b2b8ddf69f98ee5a6/", + DEFAULT_MODEL_NAME_FOR_TEST: "meta-llama/Meta-Llama-3.1-8B-Instruct", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + DEFAULT_MODEL_NAME_FOR_TEST_MLA: "deepseek/dsv3-test/snapshots/", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "deepseek/dsv3-test-NextN/snapshots/", GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl", } From 24199dbb2bace60c3b528d0fc198d60508e37f9a Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 21 Apr 2025 00:13:10 +0000 Subject: [PATCH 08/17] fix --- test/srt/test_fa3.py | 91 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 16 deletions(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index b7c317afaf2..021c85c41f6 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -2,22 +2,15 @@ from types import SimpleNamespace import requests -import torch from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST # llama 3.1 8B -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, # eagle3 llama 3.1 8B -) -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION, # llama 4 scout 17B 16E -) -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST_MLA # deepseek v3 test -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, # deepseek v3 test nextN -) from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION, + DEFAULT_MODEL_NAME_FOR_TEST_MLA, + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -127,8 +120,8 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest): class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): """Test FlashAttention3 with Model with local attention, e.g. Llama 4.""" - self.accuracy_threshold = 0.70 - self.model = DEFAULT_LOCAL_ATTENTION_MODEL_NAME_FOR_TEST + accuracy_threshold = 0.70 + model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION @classmethod def get_server_args(cls): @@ -161,7 +154,41 @@ def get_server_args(cls): "--speculative-eagle-topk", "1", "--speculative-num-draft-tokens", - "3", + "4", + "--dtype", + "float16", + ] + ) + return args + + +class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): + """Tests FlashAttention3 with enhanced speculative decoding using Llama 3.1 8B and EAGLE3. + This test will be using top-k value > 1 which would verify the other branches of the FA3 code + """ + + model = DEFAULT_MODEL_NAME_FOR_TEST + accuracy_threshold = 0.65 + speculative_decode = True + spec_decode_threshold = 1.5 + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE3", + "--speculative-draft", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "8", "--dtype", "float16", ] @@ -193,7 +220,39 @@ def get_server_args(cls): "--speculative-eagle-topk", "1", "--speculative-num-draft-tokens", - "3", + "4", + ] + ) + return args + + +class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): + """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model + This test will be using top-k value > 1 which would verify the other branches of the FA3 code + """ + + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + accuracy_threshold = 0.60 + speculative_decode = True + spec_decode_threshold = 1.5 + + @classmethod + def get_server_args(cls): + args = super().get_server_args() + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "8", ] ) return args From 88200a5f7cd20f118496a8e52619379a56ba205c Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 21 Apr 2025 00:24:53 +0000 Subject: [PATCH 09/17] fix --- test/srt/test_fa3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 021c85c41f6..fe55829fb33 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -2,6 +2,7 @@ from types import SimpleNamespace import requests +import torch from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k From 9ddc78385a1020dc0c5cc8797d20ecbb214bad55 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 21 Apr 2025 00:50:51 +0000 Subject: [PATCH 10/17] fix --- test/srt/test_fa3.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index fe55829fb33..6c7302ef83f 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -61,7 +61,7 @@ @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") class BaseFlashAttentionTest(unittest.TestCase): - """Test FlashAttention3 with MHA model, also it will be base class for derived tests.""" + """Base class for testing FlashAttention3.""" model = DEFAULT_MODEL_NAME_FOR_TEST base_url = DEFAULT_URL_FOR_TEST @@ -117,6 +117,10 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest): accuracy_threshold = 0.60 model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + @classmethod + def get_server_args(cls): + DEFAULT_SERVER_ARGS + class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): """Test FlashAttention3 with Model with local attention, e.g. Llama 4.""" @@ -126,9 +130,9 @@ class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - args = super().get_server_args() - args.extend(["--tp", "4"]) - return args + cloned_args = DEFAULT_SERVER_ARGS.deepcopy() + cloned_args.extend(["--tp", "4"]) + return cloned_args class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): @@ -141,7 +145,7 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - args = super().get_server_args() + args = DEFAULT_SERVER_ARGS args.extend( [ "--cuda-graph-max-bs", @@ -175,7 +179,7 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - args = super().get_server_args() + args = DEFAULT_SERVER_ARGS args.extend( [ "--cuda-graph-max-bs", @@ -207,7 +211,7 @@ class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - args = super().get_server_args() + args = DEFAULT_SERVER_ARGS args.extend( [ "--cuda-graph-max-bs", @@ -239,7 +243,7 @@ class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - args = super().get_server_args() + args = DEFAULT_SERVER_ARGS args.extend( [ "--cuda-graph-max-bs", From 3b4e9e7b3fd149b6c31d4ab7c9f33f929d2e5b22 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Mon, 21 Apr 2025 01:12:25 +0000 Subject: [PATCH 11/17] fix --- test/srt/test_fa3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 6c7302ef83f..c1b85876714 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -119,7 +119,7 @@ class TestFlashAttention3MLA(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - DEFAULT_SERVER_ARGS + return DEFAULT_SERVER_ARGS class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): @@ -130,7 +130,7 @@ class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): @classmethod def get_server_args(cls): - cloned_args = DEFAULT_SERVER_ARGS.deepcopy() + cloned_args = DEFAULT_SERVER_ARGS.copy() cloned_args.extend(["--tp", "4"]) return cloned_args From db662b4603a88e49377611d7e9dc8e5ac50464e0 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Tue, 22 Apr 2025 17:24:14 +0000 Subject: [PATCH 12/17] fix --- python/sglang/srt/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 087ef3f3969..ce25e2f8d3a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1101,7 +1101,7 @@ def get_amdgpu_memory_capacity(): def get_device_sm(): - print(torch.cuda.is_available()) + print(f"torch.cuda.is_available(): {torch.cuda.is_available()}") if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() print(f"the version is {major} {minor}") From d365cfdd786fcb315bcc6aa352c1aa87f79a72a0 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Tue, 22 Apr 2025 17:36:08 +0000 Subject: [PATCH 13/17] fix --- python/sglang/srt/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ce25e2f8d3a..207e5c5f3cc 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1101,6 +1101,8 @@ def get_amdgpu_memory_capacity(): def get_device_sm(): + # print python interpreter path + print(f"python interpreter path: {sys.executable}") print(f"torch.cuda.is_available(): {torch.cuda.is_available()}") if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() From 8b14fce4c722e897cf818077c053fbdd6ad20a1c Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 23 Apr 2025 02:39:52 -0700 Subject: [PATCH 14/17] Update utils.py --- python/sglang/srt/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d5a8f6700f3..1796e237887 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1103,6 +1103,7 @@ def get_amdgpu_memory_capacity(): def get_device_sm(): # print python interpreter path print(f"python interpreter path: {sys.executable}") + print(f"torch.cuda.is_available(): {torch.cuda.is_available()}") if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() From bf987243fefa082a17bea79511e1fe9d1510465a Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Wed, 23 Apr 2025 23:47:35 +0000 Subject: [PATCH 15/17] fix --- python/sglang/srt/utils.py | 5 ----- test/srt/test_fa3.py | 24 ++++++++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 1796e237887..ba6bb61402c 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1101,13 +1101,8 @@ def get_amdgpu_memory_capacity(): def get_device_sm(): - # print python interpreter path - print(f"python interpreter path: {sys.executable}") - - print(f"torch.cuda.is_available(): {torch.cuda.is_available()}") if torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability() - print(f"the version is {major} {minor}") return major * 10 + minor return 0 diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 63f21569f5e..b2da6dae35d 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -1,3 +1,4 @@ +import os import unittest from types import SimpleNamespace @@ -20,19 +21,18 @@ GSM_DATASET_PATH = None # In case of some machine lack internet connection, we can set OFFLINE_MODE to True. -OFFLINE_MODE = True +OFFLINE_MODE = False # Change the path below when OFFLINE_MODE is True. OFFLINE_PATH_DICT = { - DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct/07eb05b21d191a58c577b4a45982fe0c049d0693", - DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B/e5ed08d66f528a95ce89f5d4fd136a28f6def714", - DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/bhe/deepseek/dsv3-test/snapshots/fed995305b0e4f9acacb74bc0c17787a966bf42a/", - DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/bhe/deepseek/dsv3-test-NextN/snapshots/981e68c48968bacb228b645b2b8ddf69f98ee5a6/", - GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl" + DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/", + GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl", } - if OFFLINE_MODE: DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST] DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[ @@ -77,11 +77,15 @@ def get_server_args(cls): @classmethod def setUpClass(cls): + # disable deep gemm precompile to make launch server faster + # please don't do this if you want to make your inference workload faster + os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False" cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=cls.get_server_args(), + env=os.environ, ) @classmethod @@ -205,7 +209,7 @@ def get_server_args(cls): class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): """Test FlashAttention3 with speculative decode enabled, topk > 1""" - model = "meta-llama/Llama-3.1-8B-Instruct" + model = DEFAULT_MODEL_NAME_FOR_TEST @classmethod def get_server_args(cls): @@ -217,7 +221,7 @@ def get_server_args(cls): "--speculative-algorithm", "EAGLE3", "--speculative-draft", - "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, "--speculative-num-steps", "5", "--speculative-eagle-topk", @@ -238,7 +242,7 @@ def test_gsm8k(self): args = SimpleNamespace( num_shots=5, - data_path=DATA_PATH, + data_path=GSM_DATASET_PATH, num_questions=200, max_new_tokens=512, parallel=128, From 414f3115265e51203fd1b5f6ff21af827900bd37 Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Fri, 25 Apr 2025 23:33:54 +0000 Subject: [PATCH 16/17] adjust context len for llama 4 to avoid oom --- test/srt/test_fa3.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index b2da6dae35d..23b1438a738 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -136,7 +136,8 @@ class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): @classmethod def get_server_args(cls): cloned_args = DEFAULT_SERVER_ARGS.copy() - cloned_args.extend(["--tp", "4"]) + # we cannot use scout's 10m context due to this bug: https://github.com/sgl-project/sglang/issues/5755 + cloned_args.extend(["--tp", "4", "--context-length", "1000000"]) return cloned_args From 358fc9cefbb30786fbd1b038afc24f8acb0df33f Mon Sep 17 00:00:00 2001 From: hebiao064 Date: Sat, 26 Apr 2025 04:19:14 +0000 Subject: [PATCH 17/17] disable torch compile --- test/srt/test_fa3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py index 23b1438a738..d02c799facd 100644 --- a/test/srt/test_fa3.py +++ b/test/srt/test_fa3.py @@ -136,6 +136,8 @@ class TestFlashAttention3LocalAttn(BaseFlashAttentionTest): @classmethod def get_server_args(cls): cloned_args = DEFAULT_SERVER_ARGS.copy() + # remove --enable-torch-compile from cloned_args since llama4 does not support it for now + cloned_args.remove("--enable-torch-compile") # we cannot use scout's 10m context due to this bug: https://github.com/sgl-project/sglang/issues/5755 cloned_args.extend(["--tp", "4", "--context-length", "1000000"]) return cloned_args