Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions python/sglang/test/gsm8k_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from abc import ABC
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class GSM8KMixin(ABC):
accuracy: float
model: str
other_args: list[str] = []

@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.other_args,
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreaterEqual(metrics["accuracy"], self.accuracy)
233 changes: 233 additions & 0 deletions python/sglang/test/mmmu_vlm_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
import glob
import json
import os
import subprocess
from abc import ABC
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)

# Set default mem_fraction_static to 0.8
DEFAULT_MEM_FRACTION_STATIC = 0.8


class MMMUVLMMixin(ABC):
parsed_args = None # Class variable to store args
other_args = []
mmmu_args = []

@classmethod
def setUpClass(cls):
# Removed argument parsing from here
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.time_out = DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

if cls.parsed_args is None:
cls.parsed_args = SimpleNamespace(
mem_fraction_static=DEFAULT_MEM_FRACTION_STATIC
)

# Set OpenAI API key and base URL environment variables. Needed for lmm-evals to work.
os.environ["OPENAI_API_KEY"] = cls.api_key
os.environ["OPENAI_API_BASE"] = f"{cls.base_url}/v1"

def run_mmmu_eval(
self,
model_version: str,
output_path: str,
*,
env: dict | None = None,
):
"""
Evaluate a VLM on the MMMU validation set with lmms‑eval.
Only `model_version` (checkpoint) and `chat_template` vary;
We are focusing only on the validation set due to resource constraints.
"""
# -------- fixed settings --------
model = "openai_compatible"
tp = 1
tasks = "mmmu_val"
batch_size = 32
log_suffix = "openai_compatible"
os.makedirs(output_path, exist_ok=True)

# -------- compose --model_args --------
model_args = f'model_version="{model_version}",' f"tp={tp}"

# -------- build command list --------
cmd = [
"python3",
"-m",
"lmms_eval",
"--model",
model,
"--model_args",
model_args,
"--tasks",
tasks,
"--batch_size",
str(batch_size),
"--log_samples",
"--log_samples_suffix",
log_suffix,
"--output_path",
str(output_path),
*self.mmmu_args,
]

subprocess.run(
cmd,
check=True,
timeout=3600,
)

def _run_vlm_mmmu_test(
self,
model,
output_path,
test_name="",
custom_env=None,
log_level="info",
capture_output=False,
):
"""
Common method to run VLM MMMU benchmark test.

Args:
model: Model to test
output_path: Path for output logs
test_name: Optional test name for logging
custom_env: Optional custom environment variables
log_level: Log level for server (default: "info")
capture_output: Whether to capture server stdout/stderr
"""
print(f"\nTesting model: {model.model}{test_name}")

process = None
mmmu_accuracy = 0 # Initialize to handle potential exceptions
server_output = ""

try:
# Prepare environment variables
process_env = os.environ.copy()
if custom_env:
process_env.update(custom_env)
# if test vlm with cuda_ipc feature, open this env_var
process_env["SGLANG_USE_CUDA_IPC_TRANSPORT"] = "1"

# Prepare stdout/stderr redirection if needed
stdout_file = None
stderr_file = None
if capture_output:
stdout_file = open("/tmp/server_stdout.log", "w")
stderr_file = open("/tmp/server_stderr.log", "w")

# Launch server for testing
process = popen_launch_server(
model.model,
base_url=self.base_url,
timeout=self.time_out,
api_key=self.api_key,
other_args=[
"--trust-remote-code",
"--cuda-graph-max-bs",
"32",
"--enable-multimodal",
"--mem-fraction-static",
str(self.parsed_args.mem_fraction_static), # Use class variable
"--log-level",
log_level,
*self.other_args,
],
env=process_env,
return_stdout_stderr=(
(stdout_file, stderr_file) if capture_output else None
),
)

# Run evaluation
self.run_mmmu_eval(model.model, output_path)

# Get the result file
# Search recursively for JSON result files (lmms-eval v0.4.1+ creates subdirectories)
result_files = glob.glob(f"{output_path}/**/*.json", recursive=True)
if not result_files:
result_files = glob.glob(f"{output_path}/*.json")

if not result_files:
raise FileNotFoundError(f"No JSON result files found in {output_path}")

result_file_path = result_files[0]

with open(result_file_path, "r") as f:
result = json.load(f)
print(f"Result{test_name}\n: {result}")

# Process the result
mmmu_accuracy = result["results"]["mmmu_val"]["mmmu_acc,none"]
print(
f"Model {model.model} achieved accuracy{test_name}: {mmmu_accuracy:.4f}"
)

# Capture server output if requested
if capture_output and process:
server_output = self._read_output_from_files()

# Assert performance meets expected threshold
self.assertGreaterEqual(
mmmu_accuracy,
model.mmmu_accuracy,
f"Model {model.model} accuracy ({mmmu_accuracy:.4f}) below expected threshold ({model.mmmu_accuracy:.4f}){test_name}",
)

return server_output

except Exception as e:
print(f"Error testing {model.model}{test_name}: {e}")
self.fail(f"Test failed for {model.model}{test_name}: {e}")

finally:
# Ensure process cleanup happens regardless of success/failure
if process is not None and process.poll() is None:
print(f"Cleaning up process {process.pid}")
try:
kill_process_tree(process.pid)
except Exception as e:
print(f"Error killing process: {e}")

# clean up temporary files
if capture_output:
if stdout_file:
stdout_file.close()
if stderr_file:
stderr_file.close()
for filename in ["/tmp/server_stdout.log", "/tmp/server_stderr.log"]:
try:
if os.path.exists(filename):
os.remove(filename)
except Exception as e:
print(f"Error removing {filename}: {e}")

def _read_output_from_files(self):
output_lines = []

log_files = [
("/tmp/server_stdout.log", "[STDOUT]"),
("/tmp/server_stderr.log", "[STDERR]"),
]
for filename, tag in log_files:
try:
if os.path.exists(filename):
with open(filename, "r") as f:
for line in f:
output_lines.append(f"{tag} {line.rstrip()}")
except Exception as e:
print(f"Error reading {tag.lower()} file: {e}")

return "\n".join(output_lines)
54 changes: 10 additions & 44 deletions test/srt/models/test_nvidia_nemotron_nano_v2.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,27 @@
import unittest
from types import SimpleNamespace

from sglang.srt.utils import is_blackwell, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
from sglang.srt.utils import is_blackwell
from sglang.test.gsm8k_mixin import GSM8KMixin
from sglang.test.test_utils import CustomTestCase


class TestNvidiaNemotronNanoV2(CustomTestCase):
class TestNvidiaNemotronNanoV2BF16(GSM8KMixin, CustomTestCase):
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
accuracy = 0.87
other_args = ["--max-mamba-cache-size", "256"]

@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--max-mamba-cache-size",
"256",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval(args)
print(f"{metrics=}")
self.assertGreaterEqual(metrics["accuracy"], self.accuracy)


class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2):

class TestNvidiaNemotronNanoV2FP8(GSM8KMixin, CustomTestCase):
accuracy = 0.87
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
other_args = ["--max-mamba-cache-size", "256"]


@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell")
class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2):
class TestNvidiaNemotronNanoV2NVFP4(GSM8KMixin, CustomTestCase):
accuracy = 0.855
model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
other_args = ["--max-mamba-cache-size", "256"]


if __name__ == "__main__":
Expand Down
Loading
Loading