Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/environ.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ class Envs:
SGLANG_ROPE_CACHE_ALIGN = EnvInt(128)

# Overlap Spec V2
SGLANG_ENABLE_SPEC_V2 = EnvBool(False)
SGLANG_ENABLE_SPEC_V2 = EnvBool(True)
SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False)

# Spec Config
Expand Down
41 changes: 17 additions & 24 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,11 +1962,6 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Enable multi-layer EAGLE speculative decoding for MiMoV2 model."
)
if not envs.SGLANG_ENABLE_SPEC_V2.get():
envs.SGLANG_ENABLE_SPEC_V2.set(True)
logger.warning(
"Spec v2 is enabled for multi-layer EAGLE speculative decoding."
)

if self.enable_hierarchical_cache:
self.swa_full_tokens_ratio = 1.0
Expand All @@ -1983,11 +1978,6 @@ def _handle_model_specific_adjustments(self):
logger.info(
"Enable multi-layer EAGLE speculative decoding for Step3p5ForCausalLM model."
)
if not envs.SGLANG_ENABLE_SPEC_V2.get():
envs.SGLANG_ENABLE_SPEC_V2.set(True)
logger.warning(
"Spec v2 is enabled for multi-layer EAGLE speculative decoding."
)
if self.enable_hierarchical_cache:
self.swa_full_tokens_ratio = 1.0
logger.warning(
Expand Down Expand Up @@ -3386,26 +3376,29 @@ def _handle_speculative_decoding(self):
"Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests."
)

spec_v1_reason = None
if (
self.speculative_algorithm in ["EAGLE", "EAGLE3", "STANDALONE"]
and envs.SGLANG_ENABLE_SPEC_V2.get()
self.speculative_eagle_topk is not None
and self.speculative_eagle_topk > 1
and not self.disable_overlap_schedule
):
self.disable_overlap_schedule = True
spec_v1_reason = "spec v2 currently only supports topk = 1"
elif (
not envs.SGLANG_ENABLE_SPEC_V2.get()
and not self.disable_overlap_schedule
):
self.disable_overlap_schedule = False
self.disable_overlap_schedule = True
spec_v1_reason = "SGLANG_ENABLE_SPEC_V2=False"

if self.disable_overlap_schedule:
logger.warning(
"Spec v2 is enabled for eagle/eagle3 speculative decoding and overlap schedule is turned on."
"Spec v1 is used for eagle/eagle3/standalone speculative decoding because %s.",
spec_v1_reason or "overlap schedule is disabled",
)
if (
self.speculative_eagle_topk is not None
and self.speculative_eagle_topk > 1
):
raise ValueError(
"Spec v2 currently only supports topk = 1 for speculative decoding."
)
else:
self.disable_overlap_schedule = True
logger.warning(
"Overlap scheduler is disabled when spec v2 is off or using unsupported speculative algorithm. "
"You can set env SGLANG_ENABLE_SPEC_V2=True to enable the experimental overlap scheduler. "
"Spec v2 is enabled by default for eagle/eagle3/standalone speculative decoding."
)

if self.enable_mixed_chunk:
Expand Down
1 change: 0 additions & 1 deletion test/manual/ascend/test_ascend_deepseek_mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def setUpClass(cls):
]

envs.SGLANG_NPU_USE_MLAPO.set(True)
envs.SGLANG_ENABLE_SPEC_V2.set(True)
envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.set(True)

def test_a_gsm8k(self):
Expand Down
1 change: 0 additions & 1 deletion test/manual/test_deepseek_v31.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_deepseek_v31_all_variants(self):
DEEPSEEK_V31_MODEL_PATH,
tp_size=8,
extra_args=base_args + mtp_args,
env={"SGLANG_ENABLE_SPEC_V2": "1"},
variant="TP8+MTP",
),
]
Expand Down
1 change: 0 additions & 1 deletion test/manual/test_glm_46_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def test_glm_46_fp8_all_variants(self):
GLM_4_6_FP8_MODEL_PATH,
tp_size=8,
extra_args=base_args + mtp_args,
env={"SGLANG_ENABLE_SPEC_V2": "1"},
variant="TP8+MTP",
),
]
Expand Down
1 change: 0 additions & 1 deletion test/manual/test_qwen3_235b.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def test_qwen3_235b_fp8_all_variants(self):
QWEN3_235B_FP8_MODEL_PATH,
tp_size=8,
extra_args=base_args + eagle3_args,
env={"SGLANG_ENABLE_SPEC_V2": "1"},
variant="TP8+EP2+EAGLE3",
),
]
Expand Down
242 changes: 242 additions & 0 deletions test/registered/4-gpu-models/test_qwen35_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.accuracy_test_runner import AccuracyTestParams
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.kits.reasoning_kit import ReasoningTokenUsageMixin

# This eval harness applies the chat_template, which is critical for qwen3.5
# to get good accuracy on gsm8k
from sglang.test.run_combined_tests import run_combined_tests
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
ModelLaunchSettings,
popen_launch_server,
)

# Register this test file with the CUDA CI in the "stage-c-test-4-gpu-b200" suite.
# NOTE(review): est_time is presumably the expected runtime in seconds — confirm against ci_register.
register_cuda_ci(est_time=768, suite="stage-c-test-4-gpu-b200")

# Model checkpoint launched by the test classes below.
QWEN35_FP4_MODEL = "nvidia/Qwen3.5-397B-A17B-NVFP4"
# Per-model, per-dataset score thresholds; the tests below use the "gsm8k" entry as the baseline / minimum score.
ACC_THRESHOLDS = {QWEN35_FP4_MODEL: {"gsm8k": 0.95}}


class TestQwen35FP4(CustomTestCase):
    """GSM8K accuracy run for the NVFP4 Qwen3.5 checkpoint without speculative decoding."""

    def test_gsm8k(self):
        """Run the combined test harness over every enabled launch variant."""
        # Flag -> value, in the order the flags are passed on the command line.
        flag_values = {
            "--tp-size": "4",
            "--chunked-prefill-size": "2048",
            "--mamba-scheduler-strategy": "extra_buffer",
            "--mamba-track-interval": "128",
            "--mamba-ssm-dtype": "bfloat16",
            "--max-running-requests": "128",
            "--reasoning-parser": "qwen3",
            "--attention-backend": "trtllm_mha",
            "--quantization": "modelopt_fp4",
            "--model-loader-extra-config": '{"enable_multithread_load": true,"num_threads": 64}',
        }
        # Flatten to the ["--flag", "value", ...] list form used for extra_args.
        base_args = [token for pair in flag_values.items() for token in pair]

        launch_settings = [
            ModelLaunchSettings(
                QWEN35_FP4_MODEL,
                extra_args=base_args,
                variant="Triton",
            ),
            # TODO: Fix this and re-enable it
            # ModelLaunchSettings(
            #     QWEN35_FP4_MODEL,
            #     extra_args=base_args + ["--linear-attn-decode-backend", "flashinfer"],
            #     variant="FlashInfer",
            # ),
        ]

        gsm8k_params = AccuracyTestParams(
            dataset="gsm8k",
            baseline_accuracy=ACC_THRESHOLDS[QWEN35_FP4_MODEL]["gsm8k"],
            num_examples=200,
            num_threads=128,
            max_tokens=16000,
            thinking_mode="qwen3",
            temperature=0.6,
            top_p=0.95,
            top_k=20,
        )

        run_combined_tests(
            models=launch_settings,
            test_name="Qwen3.5-397B-A17B-NVFP4",
            accuracy_params=gsm8k_params,
        )


class TestQwen35FP4MTP(ReasoningTokenUsageMixin, CustomTestCase):
    """GSM8K accuracy and speculative accept-length check with NEXTN speculative decoding.

    One server process is launched for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.

    NOTE(review): the launch arguments and assertions are identical to
    ``TestQwen35FP4MTPV2`` in this file, and nothing here selects a spec
    version explicitly — confirm the two classes are still meant to cover
    different code paths.
    """

    reasoning_parser_name = "qwen3"

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by every test in this class."""
        cls.model = QWEN35_FP4_MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.init_reasoning_token_verifier()
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--tp-size",
                "4",
                "--chunked-prefill-size",
                "2048",
                "--mamba-scheduler-strategy",
                "extra_buffer",
                "--mamba-track-interval",
                "128",
                "--mamba-ssm-dtype",
                "bfloat16",
                "--max-running-requests",
                "128",
                "--reasoning-parser",
                "qwen3",
                "--attention-backend",
                "trtllm_mha",
                "--quantization",
                "modelopt_fp4",
                "--speculative-algorithm",
                "NEXTN",
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "4",
                "--mem-fraction-static",
                "0.8",
                "--model-loader-extra-config",
                '{"enable_multithread_load": true,"num_threads": 64}',
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Check the GSM8K score and the server's average speculative accept length.

        Raises:
            requests.HTTPError: if ``/server_info`` answers with an error status.
            requests.Timeout: if ``/server_info`` does not answer in time.
        """
        args = SimpleNamespace(
            model=self.model,
            eval_name="gsm8k",
            num_shots=5,
            num_examples=200,
            max_tokens=16000,
            num_threads=128,
            repeat=1,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            base_url=self.base_url,
            host="http://127.0.0.1",
            # The port is the last ":"-separated component of the base URL.
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreaterEqual(metrics["score"], ACC_THRESHOLDS[self.model]["gsm8k"])

        # Bound the request so a wedged server fails the test instead of
        # hanging the CI job, and surface HTTP errors before parsing JSON.
        server_info = requests.get(self.base_url + "/server_info", timeout=30)
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 3.3)


class TestQwen35FP4MTPV2(ReasoningTokenUsageMixin, CustomTestCase):
    """GSM8K accuracy and speculative accept-length check with NEXTN speculative decoding.

    One server process is launched for the whole class in ``setUpClass`` and
    killed in ``tearDownClass``.

    NOTE(review): the launch arguments and assertions are identical to
    ``TestQwen35FP4MTP`` in this file, and nothing here selects spec v2
    explicitly, so this class presumably relies on the server-side default —
    confirm the two classes are still meant to cover different code paths.
    """

    reasoning_parser_name = "qwen3"

    @classmethod
    def setUpClass(cls):
        """Launch the server shared by every test in this class."""
        cls.model = QWEN35_FP4_MODEL
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.init_reasoning_token_verifier()
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--tp-size",
                "4",
                "--chunked-prefill-size",
                "2048",
                "--mamba-scheduler-strategy",
                "extra_buffer",
                "--mamba-track-interval",
                "128",
                "--mamba-ssm-dtype",
                "bfloat16",
                "--max-running-requests",
                "128",
                "--reasoning-parser",
                "qwen3",
                "--attention-backend",
                "trtllm_mha",
                "--quantization",
                "modelopt_fp4",
                "--speculative-algorithm",
                "NEXTN",
                "--speculative-num-steps",
                "3",
                "--speculative-eagle-topk",
                "1",
                "--speculative-num-draft-tokens",
                "4",
                "--mem-fraction-static",
                "0.8",
                "--model-loader-extra-config",
                '{"enable_multithread_load": true,"num_threads": 64}',
            ],
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Check the GSM8K score and the server's average speculative accept length.

        Raises:
            requests.HTTPError: if ``/server_info`` answers with an error status.
            requests.Timeout: if ``/server_info`` does not answer in time.
        """
        args = SimpleNamespace(
            model=self.model,
            eval_name="gsm8k",
            num_shots=5,
            num_examples=200,
            max_tokens=16000,
            num_threads=128,
            repeat=1,
            temperature=0.6,
            top_p=0.95,
            top_k=20,
            base_url=self.base_url,
            host="http://127.0.0.1",
            # The port is the last ":"-separated component of the base URL.
            port=int(self.base_url.split(":")[-1]),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreaterEqual(metrics["score"], ACC_THRESHOLDS[self.model]["gsm8k"])

        # Bound the request so a wedged server fails the test instead of
        # hanging the CI job, and surface HTTP errors before parsing JSON.
        server_info = requests.get(self.base_url + "/server_info", timeout=30)
        server_info.raise_for_status()
        avg_spec_accept_length = server_info.json()["internal_states"][0][
            "avg_spec_accept_length"
        ]
        print(f"{avg_spec_accept_length=}")
        self.assertGreater(avg_spec_accept_length, 3.3)


# Script entry point: run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
11 changes: 0 additions & 11 deletions test/registered/4-gpu-models/test_qwen3_next_models_mtp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import unittest

from sglang.srt.environ import envs
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.kits.eval_accuracy_kit import GSM8KMixin
from sglang.test.kits.kl_divergence_kit import KLDivergenceMixin
Expand Down Expand Up @@ -94,16 +93,6 @@ class TestQwen3NextMTPV2(GSM8KMixin, KLDivergenceMixin, DefaultServerBase):
"128",
]

@classmethod
def setUpClass(cls):
envs.SGLANG_ENABLE_SPEC_V2.set(True)
super().setUpClass()

@classmethod
def tearDownClass(cls):
envs.SGLANG_ENABLE_SPEC_V2.set(False)
super().tearDownClass()


if __name__ == "__main__":
unittest.main()
2 changes: 0 additions & 2 deletions test/registered/8-gpu-models/test_deepseek_v32.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_deepseek_v32_all_variants(self):
DEEPSEEK_V32_MODEL_PATH,
tp_size=8,
extra_args=BASE_ARGS + DP_ARGS + TOOL_CALL_ARGS + MTP_ARGS,
env={"SGLANG_ENABLE_SPEC_V2": "1"},
variant="DP8+MTP",
),
# Variant: "tp" - Pure TP=8 only
Expand All @@ -83,7 +82,6 @@ def test_deepseek_v32_all_variants(self):
DEEPSEEK_V32_MODEL_PATH,
tp_size=8,
extra_args=BASE_ARGS + TP_ARGS + TOOL_CALL_ARGS + MTP_ARGS,
env={"SGLANG_ENABLE_SPEC_V2": "1"},
variant="TP8+MTP",
),
]
Expand Down
Loading
Loading