diff --git a/python/sglang/srt/models/mistral_large_3_eagle.py b/python/sglang/srt/models/mistral_large_3_eagle.py
index f136640fde78..08f7271fde6c 100644
--- a/python/sglang/srt/models/mistral_large_3_eagle.py
+++ b/python/sglang/srt/models/mistral_large_3_eagle.py
@@ -4,8 +4,8 @@
 from torch import nn
 from transformers import PretrainedConfig
 
-from python.sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
 from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import RowParallelLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
diff --git a/scripts/ci/slash_command_handler.py b/scripts/ci/slash_command_handler.py
index dfd1df1b0ffe..5826d96bcc12 100644
--- a/scripts/ci/slash_command_handler.py
+++ b/scripts/ci/slash_command_handler.py
@@ -161,6 +161,7 @@ def handle_rerun_stage(
         "unit-test-deepep-8-gpu",
         "unit-test-backend-4-gpu-b200",
         "unit-test-backend-4-gpu-gb200",
+        "unit-test-backend-8-gpu-b200",
     ]
 
     # Valid AMD stage names that support target_stage
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index e17c25fcdae6..c5f7fedbb29b 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -182,6 +182,7 @@
     ],
     "per-commit-8-gpu-b200": [
         TestFile("test_mistral_large3_basic.py", 275),
+        TestFile("test_mistral_large3_eagle_basic.py", 275),
     ],
     "per-commit-4-gpu-gb200": [
         TestFile("test_cutedsl_moe.py", 300),
diff --git a/test/srt/test_mistral_large3_eagle_basic.py b/test/srt/test_mistral_large3_eagle_basic.py
new file mode 100644
index 000000000000..6d923054c4f0
--- /dev/null
+++ b/test/srt/test_mistral_large3_eagle_basic.py
@@ -0,0 +1,105 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.send_one import BenchArgs, send_one_prompt
+from sglang.test.test_utils import (
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    is_in_ci,
+    popen_launch_server,
+    write_github_step_summary,
+)
+
+# Base model and Eagle draft model
+MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
+MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"
+
+
+class TestMistralLarge3EagleBasic(CustomTestCase):
+    @classmethod
+    def setUpClass(cls):
+        # Set environment variable to disable JIT DeepGemm
+        os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"
+
+        cls.model = MISTRAL_LARGE3_MODEL_PATH
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        # Mistral-Large-3 with Eagle speculative decoding
+        # Eagle model is used as draft model for speculative decoding
+        other_args = [
+            "--tp",
+            "8",
+            "--attention-backend",
+            "trtllm_mla",
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft-model-path",
+            MISTRAL_LARGE3_EAGLE_MODEL_PATH,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "1",
+            "--speculative-num-draft-tokens",
+            "4",
+            "--kv-cache-dtype",
+            "auto",
+            "--model-loader-extra-config",
+            '{"enable_multithread_load": true}',
+            "--chat-template",
+            "mistral",
+        ]
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 5,
+            other_args=other_args,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+        # Clean up environment variable
+        if "SGLANG_ENABLE_JIT_DEEPGEMM" in os.environ:
+            del os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"]
+
+    def test_a_gsm8k(
+        self,
+    ):  # Append an "a" to make this test run first (alphabetically) to warm up the server
+        args = SimpleNamespace(
+            num_shots=8,
+            data_path=None,
+            num_questions=1400,
+            parallel=1400,
+            max_new_tokens=512,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"{metrics=}")
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_gsm8k (mistral-large-3-eagle)\n"
+                f'{metrics["accuracy"]=:.3f}\n'
+            )
+        self.assertGreater(metrics["accuracy"], 0.90)
+
+    def test_bs_1_speed(self):
+        args = BenchArgs(port=int(self.base_url.split(":")[-1]), max_new_tokens=2048)
+        acc_length, speed = send_one_prompt(args)
+
+        print(f"{speed=:.2f}")
+
+        if is_in_ci():
+            write_github_step_summary(
+                f"### test_bs_1_speed (mistral-large-3-eagle)\n"
+                f"{speed=:.2f} token/s\n"
+            )
+        self.assertGreater(speed, 50)
+
+
+if __name__ == "__main__":
+    unittest.main()