Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/nightly-test-nvidia.yml
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,7 @@ jobs:
SGLANG_ENABLE_JIT_DEEPGEMM: "0"
run: |
rm -rf test/performance_profiles_mistral_large3/
rm -rf test/performance_profiles_mistral_large3_eagle/
cd test
IS_BLACKWELL=1 python3 nightly/test_mistral_large3_perf.py

Expand All @@ -442,6 +443,7 @@ jobs:
GITHUB_RUN_NUMBER: ${{ github.run_number }}
run: |
python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_mistral_large3
python3 scripts/ci/publish_traces.py --traces-dir test/performance_profiles_mistral_large3_eagle

- name: Run DeepSeek v3.1 nightly performance test
if: always()
Expand Down
99 changes: 99 additions & 0 deletions test/nightly/test_mistral_large3_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Register this test module with the CI scheduler — presumably enrolls it in the
# 8x B200 nightly suite with a 600 s time estimate; confirm against register_cuda_ci.
register_cuda_ci(est_time=600, suite="nightly-8-gpu-b200", nightly=True)

# HuggingFace model IDs: the target model and its Eagle draft model
# (the draft model is used for speculative decoding).
MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512"
MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle"
# Directory where the (non-Eagle) benchmark writes its performance profiles;
# the Eagle suite below uses its own "performance_profiles_mistral_large3_eagle" dir.
PROFILE_DIR = "performance_profiles_mistral_large3"


Expand Down Expand Up @@ -101,5 +102,103 @@ def test_accuracy_mgsm(self):
kill_process_tree(process.pid)


class TestNightlyMistralLarge3EaglePerformance(unittest.TestCase):
"""Test Mistral Large 3 with Eagle speculative decoding."""

@classmethod
def setUpClass(cls):
# Set environment variable to disable JIT DeepGemm
os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"

cls.model = MISTRAL_LARGE3_MODEL_PATH
cls.base_url = DEFAULT_URL_FOR_TEST
cls.batch_sizes = [1, 1, 8, 16, 64]
cls.input_lens = tuple(_parse_int_list_env("NIGHTLY_INPUT_LENS", "4096"))
cls.output_lens = tuple(_parse_int_list_env("NIGHTLY_OUTPUT_LENS", "512"))

# Mistral-Large-3 with Eagle speculative decoding
# Eagle model is used as draft model for speculative decoding
cls.other_args = [
"--tp",
"8",
"--attention-backend",
"trtllm_mla",
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
MISTRAL_LARGE3_EAGLE_MODEL_PATH,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"1",
"--speculative-num-draft-tokens",
"4",
"--kv-cache-dtype",
"auto",
"--model-loader-extra-config",
'{"enable_multithread_load": true}',
"--chat-template",
"mistral",
]

cls.runner = NightlyBenchmarkRunner(
"performance_profiles_mistral_large3_eagle", cls.__name__, cls.base_url
)
cls.runner.setup_profile_directory()

@classmethod
def tearDownClass(cls):
# Clean up environment variable
if "SGLANG_ENABLE_JIT_DEEPGEMM" in os.environ:
del os.environ["SGLANG_ENABLE_JIT_DEEPGEMM"]

def test_eagle_bench_one_batch(self):
results, success = self.runner.run_benchmark_for_model(
model_path=self.model,
batch_sizes=self.batch_sizes,
input_lens=self.input_lens,
output_lens=self.output_lens,
other_args=self.other_args,
)

self.runner.add_report(results)
self.runner.write_final_report()

if not success:
raise AssertionError(
f"Benchmark failed for {self.model} with Eagle. Check the logs for details."
)

def test_eagle_accuracy_mgsm(self):
"""Run MGSM accuracy evaluation for Mistral Large 3 with Eagle."""
process = popen_launch_server(
model=self.model,
base_url=self.base_url,
other_args=self.other_args,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
)

try:
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mgsm_en",
num_examples=None,
num_threads=1024,
)
metrics = run_eval(args)
print(f"MGSM accuracy for {self.model} with Eagle: {metrics['score']}")

# Placeholder threshold - adjust after first successful run
expected_threshold = 0.90
self.assertGreaterEqual(
metrics["score"],
expected_threshold,
f"MGSM accuracy {metrics['score']} below threshold {expected_threshold}",
)
finally:
kill_process_tree(process.pid)


# Allow running this nightly suite directly:
#   python3 nightly/test_mistral_large3_perf.py
if __name__ == "__main__":
    unittest.main()
Loading