diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index 0c71bf05bff..368c205a236 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -87,6 +87,26 @@ jobs:
           cd test/srt
           python3 run_suite.py --suite per-commit-2-gpu
 
+  unit-test-backend-8-gpu:
+    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
+      github.event.pull_request.draft == false
+    runs-on: 8-gpu-runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install dependencies
+        env:
+          FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.5/flashinfer-python' || 'https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python' }}
+        run: |
+          bash scripts/ci_install_dependency.sh
+
+      - name: Run test
+        timeout-minutes: 30
+        run: |
+          cd test/srt
+          python3 run_suite.py --suite per-commit-8-gpu
+
   performance-test-1-gpu-part-1:
     if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
       github.event.pull_request.draft == false
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 7d8899b72fa..e57f9ce6b7c 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -44,7 +44,13 @@
 )
 
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
+DEFAULT_MODEL_NAME_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
+DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = "lmsys/sglang-ci-dsv3-test-NextN"
 DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
+DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
+    "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+)
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
 DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 80775e866ab..69b3db2a39f 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -30,7 +30,6 @@ class TestFile:
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
-        TestFile("test_fa3.py", 200),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
@@ -91,6 +90,9 @@ class TestFile:
         TestFile("test_update_weights_from_distributed.py", 100),
         TestFile("test_verl_engine.py", 100),
     ],
+    "per-commit-8-gpu": [
+        TestFile("test_fa3.py", 30),
+    ],
     "nightly": [
         TestFile("test_nightly_gsm8k_eval.py"),
     ],
diff --git a/test/srt/test_fa3.py b/test/srt/test_fa3.py
index 886e19db1d2..d02c799facd 100644
--- a/test/srt/test_fa3.py
+++ b/test/srt/test_fa3.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 from types import SimpleNamespace
 
@@ -8,47 +9,83 @@
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
+    DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA,
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )
 
+GSM_DATASET_PATH = None
+
+# In case a machine lacks internet access, set OFFLINE_MODE to True.
+OFFLINE_MODE = False
+
+# Change the paths below when OFFLINE_MODE is True.
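+# For example, with OFFLINE_MODE = True, DEFAULT_MODEL_NAME_FOR_TEST resolves
+# to a local checkout via the dict below. Note that GSM_DATASET_PATH is still
+# None at this point, so the default dataset path is stored under the key None
+# and retrieved by the lookup OFFLINE_PATH_DICT[GSM_DATASET_PATH].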
+OFFLINE_PATH_DICT = {
+    DEFAULT_MODEL_NAME_FOR_TEST: "/shared/public/elr-models/meta-llama/Meta-Llama-3.1-8B-Instruct",
+    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3: "/shared/public/elr-models/jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA: "/shared/public/sharing/deepseek/dsv3-test/snapshots/",
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN: "/shared/public/sharing/deepseek/dsv3-test-NextN/snapshots/",
+    GSM_DATASET_PATH: "/shared/public/data/gsm8k/test.jsonl",
+}
+
+
+if OFFLINE_MODE:
+    DEFAULT_MODEL_NAME_FOR_TEST = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST]
+    DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = OFFLINE_PATH_DICT[
+        DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3
+    ]
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA = OFFLINE_PATH_DICT[DEFAULT_MODEL_NAME_FOR_TEST_MLA]
+    DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN = OFFLINE_PATH_DICT[
+        DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN
+    ]
+    GSM_DATASET_PATH = OFFLINE_PATH_DICT[GSM_DATASET_PATH]
+
+
+# Default server arguments shared across all tests
+DEFAULT_SERVER_ARGS = [
+    "--trust-remote-code",
+    "--enable-torch-compile",
+    "--cuda-graph-max-bs",
+    "2",
+    "--attention-backend",
+    "fa3",
+]
+
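+# For reference, popen_launch_server turns these into a server command roughly
+# like the following (an illustrative sketch, not the exact invocation):
+#   python3 -m sglang.launch_server --model-path <model> --port <port> \
+#       --trust-remote-code --enable-torch-compile --cuda-graph-max-bs 2 \
+#       --attention-backend fa3
+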
 """
 Integration test for python/sglang/srt/layers/attention/flashattention_backend.py
 """
-# Change to your own model if testing model is not public.
-MODEL_USED_FOR_TEST = DEFAULT_MODEL_NAME_FOR_TEST
-MODEL_USED_FOR_TEST_MLA = "lmsys/sglang-ci-dsv3-test"
-# Setting data path to None uses default data path in few_shot_gsm8k eval test.
-DATA_PATH = None
 
 
 @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
 class BaseFlashAttentionTest(unittest.TestCase):
-    """Base class for FlashAttention tests to reduce code duplication."""
+    """Base class for testing FlashAttention3."""
 
-    model = MODEL_USED_FOR_TEST
+    model = DEFAULT_MODEL_NAME_FOR_TEST
     base_url = DEFAULT_URL_FOR_TEST
-    accuracy_threshold = 0.62
+    accuracy_threshold = 0.65  # derived tests may override this
+    speculative_decode = False
+    spec_decode_threshold = 1.0  # derived speculative decoding tests need to override this
 
     @classmethod
     def get_server_args(cls):
         """Return the arguments for the server launch. Override in subclasses."""
-        args = [
-            "--trust-remote-code",
-            "--enable-torch-compile",
-            "--attention-backend",
-            "fa3",
-        ]
-        return args
+        return DEFAULT_SERVER_ARGS
 
     @classmethod
     def setUpClass(cls):
+        # Disable DeepGEMM precompilation so the test server launches faster.
+        # Do not do this for real inference workloads; precompilation makes them faster.
+        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False"
         cls.process = popen_launch_server(
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
+            env=os.environ,
         )
 
     @classmethod
@@ -57,13 +94,13 @@ def tearDownClass(cls):
 
     def test_gsm8k(self):
         args = SimpleNamespace(
-            num_shots=5,
-            num_questions=200,
+            num_shots=4,
+            num_questions=100,
             max_new_tokens=512,
             parallel=128,
             host="http://127.0.0.1",
             port=int(self.base_url.split(":")[-1]),
-            data_path=DATA_PATH,
+            data_path=GSM_DATASET_PATH,
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(metrics)
@@ -72,61 +109,85 @@ def test_gsm8k(self):
         metric_key = "accuracy"
         self.assertGreater(metrics[metric_key], self.accuracy_threshold)
 
+        if self.speculative_decode:
+            server_info = requests.get(self.base_url + "/get_server_info")
+            avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+            print(f"{avg_spec_accept_length=}")
+            self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
+
+
+class TestFlashAttention3MLA(BaseFlashAttentionTest):
+    """Test FlashAttention3 with MLA, e.g. the DeepSeek V3 test model."""
 
-class TestFlashAttention3(BaseFlashAttentionTest):
-    """Test FlashAttention3 with MLA model and CUDA graph enabled."""
+    accuracy_threshold = 0.60
+    model = DEFAULT_MODEL_NAME_FOR_TEST_MLA
 
     @classmethod
     def get_server_args(cls):
-        args = super().get_server_args()
-        args.extend(
-            [
-                "--cuda-graph-max-bs",
-                "2",
-            ]
-        )
-        return args
+        return DEFAULT_SERVER_ARGS
+
+
+class TestFlashAttention3LocalAttn(BaseFlashAttentionTest):
+    """Test FlashAttention3 with a model that uses local attention, e.g. Llama 4."""
 
-class TestFlashAttention3DisableCudaGraph(BaseFlashAttentionTest):
-    """Test FlashAttention3 with CUDA graph disabled."""
+    accuracy_threshold = 0.70
+    model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
 
     @classmethod
     def get_server_args(cls):
-        args = super().get_server_args()
-        args.extend(
-            [
-                "--disable-cuda-graph",
-            ]
-        )
-        return args
+        cloned_args = DEFAULT_SERVER_ARGS.copy()
+        # Remove --enable-torch-compile since Llama 4 does not support it for now.
+        cloned_args.remove("--enable-torch-compile")
+        # We cannot use Scout's 10M context due to this bug: https://github.com/sgl-project/sglang/issues/5755
+        cloned_args.extend(["--tp", "4", "--context-length", "1000000"])
+        return cloned_args
 
 
-class TestFlashAttention3MLA(BaseFlashAttentionTest):
-    """Test FlashAttention3 with MLA."""
+class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
+    """Test FlashAttention3 with speculative decoding, using Llama 3.1 8B and its EAGLE3 draft model."""
 
-    model = MODEL_USED_FOR_TEST_MLA
+    model = DEFAULT_MODEL_NAME_FOR_TEST
+    accuracy_threshold = 0.65
+    speculative_decode = True
+    spec_decode_threshold = 1.5
 
     @classmethod
     def get_server_args(cls):
-        args = super().get_server_args()
+        args = DEFAULT_SERVER_ARGS.copy()
         args.extend(
             [
                 "--cuda-graph-max-bs",
                 "2",
+                "--speculative-algorithm",
+                "EAGLE3",
+                "--speculative-draft",
+                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
+                "--speculative-num-steps",
+                "3",
+                "--speculative-eagle-topk",
+                "1",
+                "--speculative-num-draft-tokens",
+                "4",
+                "--dtype",
+                "float16",
             ]
         )
         return args
 
 
-class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
-    """Test FlashAttention3 with speculative decode enabled."""
+class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
+    """Test FlashAttention3 with speculative decoding, using Llama 3.1 8B and EAGLE3.
+    This test uses a top-k value > 1, which exercises the other branches of the FA3 code.
+    """
 
-    model = "meta-llama/Llama-3.1-8B-Instruct"
+    model = DEFAULT_MODEL_NAME_FOR_TEST
+    accuracy_threshold = 0.65
+    speculative_decode = True
+    spec_decode_threshold = 1.5
 
     @classmethod
     def get_server_args(cls):
-        args = super().get_server_args()
+        args = DEFAULT_SERVER_ARGS.copy()
         args.extend(
             [
                 "--cuda-graph-max-bs",
@@ -134,49 +195,24 @@ def get_server_args(cls):
                 "--speculative-algorithm",
                 "EAGLE3",
                 "--speculative-draft",
-                "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
+                DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
                 "--speculative-num-steps",
-                "3",
+                "5",
                 "--speculative-eagle-topk",
-                "1",
+                "4",
                 "--speculative-num-draft-tokens",
-                "3",
+                "8",
                 "--dtype",
                 "float16",
             ]
         )
         return args
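+
+    # Roughly speaking: --speculative-num-steps is the number of autoregressive
+    # draft steps per round, --speculative-eagle-topk is the branching factor
+    # kept at each step, and --speculative-num-draft-tokens caps the total
+    # draft tokens sent to verification (here 5 steps with top-k 4 feed at
+    # most 8 draft tokens per verify pass).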
- """ - requests.get(self.base_url + "/flush_cache") - - args = SimpleNamespace( - num_shots=5, - data_path=DATA_PATH, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) - - self.assertGreater(metrics["accuracy"], 0.60) - - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.5) - class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest): """Test FlashAttention3 with speculative decode enabled, topk > 1""" - model = "meta-llama/Llama-3.1-8B-Instruct" + model = DEFAULT_MODEL_NAME_FOR_TEST @classmethod def get_server_args(cls): @@ -188,7 +224,7 @@ def get_server_args(cls): "--speculative-algorithm", "EAGLE3", "--speculative-draft", - "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, "--speculative-num-steps", "5", "--speculative-eagle-topk", @@ -209,7 +245,7 @@ def test_gsm8k(self): args = SimpleNamespace( num_shots=5, - data_path=DATA_PATH, + data_path=GSM_DATASET_PATH, num_questions=200, max_new_tokens=512, parallel=128, @@ -228,13 +264,16 @@ def test_gsm8k(self): class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): - """Test FlashAttention3 with speculative decode enabled.""" + """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model""" - model = MODEL_USED_FOR_TEST_MLA + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + accuracy_threshold = 0.60 + speculative_decode = True + spec_decode_threshold = 1.5 @classmethod def get_server_args(cls): - args = super().get_server_args() + args = DEFAULT_SERVER_ARGS args.extend( [ "--cuda-graph-max-bs", @@ -242,41 +281,48 @@ def get_server_args(cls): "--speculative-algorithm", "EAGLE", "--speculative-draft", - "lmsys/sglang-ci-dsv3-test-NextN", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, "--speculative-num-steps", "3", "--speculative-eagle-topk", "1", "--speculative-num-draft-tokens", - "3", + "4", ] ) return args - def test_gsm8k(self): - """ - Override the test_gsm8k to further test for average speculative accept length. 
- """ - requests.get(self.base_url + "/flush_cache") - args = SimpleNamespace( - num_shots=5, - data_path=DATA_PATH, - num_questions=200, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval_few_shot_gsm8k(args) - print(metrics) +class TestFlashAttention3MLASpeculativeDecodeTopk(BaseFlashAttentionTest): + """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model + This test will be using top-k value > 1 which would verify the other branches of the FA3 code + """ - self.assertGreater(metrics["accuracy"], 0.60) + model = DEFAULT_MODEL_NAME_FOR_TEST_MLA + accuracy_threshold = 0.60 + speculative_decode = True + spec_decode_threshold = 1.5 - server_info = requests.get(self.base_url + "/get_server_info") - avg_spec_accept_length = server_info.json()["avg_spec_accept_length"] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.5) + @classmethod + def get_server_args(cls): + args = DEFAULT_SERVER_ARGS + args.extend( + [ + "--cuda-graph-max-bs", + "2", + "--speculative-algorithm", + "EAGLE", + "--speculative-draft", + DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "4", + "--speculative-num-draft-tokens", + "8", + ] + ) + return args if __name__ == "__main__":