Skip to content
Closed
2 changes: 1 addition & 1 deletion python/sglang/srt/models/mistral_large_3_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from torch import nn
from transformers import PretrainedConfig

from python.sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.attention.nsa.utils import is_nsa_enable_prefill_cp
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import RowParallelLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
Expand Down
1 change: 1 addition & 0 deletions scripts/ci/slash_command_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def handle_rerun_stage(
"unit-test-deepep-8-gpu",
"unit-test-backend-4-gpu-b200",
"unit-test-backend-4-gpu-gb200",
"unit-test-backend-8-gpu-b200",
]

# Valid AMD stage names that support target_stage
Expand Down
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
],
"per-commit-8-gpu-b200": [
TestFile("test_mistral_large3_basic.py", 275),
TestFile("test_mistral_large3_eagle_basic.py", 275),
],
"per-commit-4-gpu-gb200": [
TestFile("test_cutedsl_moe.py", 300),
Expand Down
105 changes: 105 additions & 0 deletions test/srt/test_mistral_large3_eagle_basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import unittest
from types import SimpleNamespace
from urllib.parse import urlparse

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.send_one import BenchArgs, send_one_prompt
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    is_in_ci,
    popen_launch_server,
    write_github_step_summary,
)

# Model identifiers used by this test:
#   - MISTRAL_LARGE3_MODEL_PATH is the base model handed to the server launcher.
#   - MISTRAL_LARGE3_EAGLE_MODEL_PATH is the Eagle draft model passed through
#     --speculative-draft-model-path.
MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"


class TestMistralLarge3EagleBasic(CustomTestCase):
@classmethod
def setUpClass(cls):
    """Launch the server shared by every test in this class.

    Sets ``SGLANG_ENABLE_JIT_DEEPGEMM=0`` (disables JIT DeepGemm), then
    starts the base model with EAGLE speculative decoding, using the Eagle
    checkpoint as the draft model. The process handle is stored on
    ``cls.process`` so ``tearDownClass`` can stop it.

    Raises:
        Exception: any error from ``popen_launch_server`` is re-raised
            after the environment variable has been removed again.
    """
    # Set environment variable to disable JIT DeepGemm
    os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"

    cls.model = MISTRAL_LARGE3_MODEL_PATH
    cls.base_url = DEFAULT_URL_FOR_TEST
    # Mistral-Large-3 with Eagle speculative decoding
    # Eagle model is used as draft model for speculative decoding
    other_args = [
        "--tp",
        "8",
        "--attention-backend",
        "trtllm_mla",
        "--speculative-algorithm",
        "EAGLE",
        "--speculative-draft-model-path",
        MISTRAL_LARGE3_EAGLE_MODEL_PATH,
        "--speculative-num-steps",
        "3",
        "--speculative-eagle-topk",
        "1",
        "--speculative-num-draft-tokens",
        "4",
        "--kv-cache-dtype",
        "auto",
        "--model-loader-extra-config",
        '{"enable_multithread_load": true}',
        "--chat-template",
        "mistral",
    ]
    try:
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH * 5,
            other_args=other_args,
        )
    except Exception:
        # unittest does not call tearDownClass when setUpClass raises, so
        # the override must be undone here or it leaks into later tests
        # running in the same process.
        os.environ.pop("SGLANG_ENABLE_JIT_DEEPGEMM", None)
        raise

@classmethod
def tearDownClass(cls):
    """Stop the server process tree and remove the DeepGemm override.

    The environment variable is removed in a ``finally`` clause so the
    cleanup still happens when ``kill_process_tree`` raises.
    """
    try:
        kill_process_tree(cls.process.pid)
    finally:
        # Clean up environment variable; pop() is a no-op if it is absent.
        os.environ.pop("SGLANG_ENABLE_JIT_DEEPGEMM", None)
Comment on lines +63 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a try...finally block ensures that the environment variable is cleaned up even if kill_process_tree raises an exception. This makes the test cleanup more robust.

        try:
            kill_process_tree(cls.process.pid)
        finally:
            # Clean up environment variable
            if "SGLANG_ENABLE_JIT_DEEPGEMM" in os.environ:
                del os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"]


def test_a_gsm8k(
    self,
):  # Append an "a" to make this test run first (alphabetically) to warm up the server
    """Run the few-shot GSM8K eval against the server and check accuracy.

    Prints the metrics, writes them to the GitHub step summary when running
    in CI, and fails unless accuracy is greater than 0.90.
    """
    args = SimpleNamespace(
        num_shots=8,
        data_path=None,
        num_questions=1400,
        parallel=1400,
        max_new_tokens=512,
        host="http://127.0.0.1",
        # urlparse extracts the port even if the URL gains a path or a
        # trailing slash, where splitting on ":" would fail.
        port=urlparse(self.base_url).port,
    )
    metrics = run_eval_few_shot_gsm8k(args)
    print(f"{metrics=}")

    if is_in_ci():
        write_github_step_summary(
            f"### test_gsm8k (mistral-large-3-eagle)\n"
            f'{metrics["accuracy"]=:.3f}\n'
        )
    self.assertGreater(metrics["accuracy"], 0.90)

def test_bs_1_speed(self):
    """Send a single prompt and check the reported decoding speed.

    Prints the speed, writes it to the GitHub step summary when running in
    CI, and fails unless the speed is greater than 50 (reported as token/s).
    """
    # urlparse extracts the port even if the URL gains a path or a trailing
    # slash, where splitting on ":" would fail.
    args = BenchArgs(port=urlparse(self.base_url).port, max_new_tokens=2048)
    # The first element of the result is not used by this test.
    _, speed = send_one_prompt(args)

    print(f"{speed=:.2f}")

    if is_in_ci():
        write_github_step_summary(
            f"### test_bs_1_speed (mistral-large-3-eagle)\n"
            f"{speed=:.2f} token/s\n"
        )
    self.assertGreater(speed, 50)


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading