From 4cf6b6b2f2a6a8d58b7188419db876fdd25b6312 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 6 Apr 2026 12:46:22 -0700 Subject: [PATCH 1/6] Initial implementation --- .../moe/moe_runner/flashinfer_trtllm.py | 51 ++--- .../compressed_tensors_w4a4_nvfp4_moe.py | 35 +-- .../srt/layers/quantization/modelopt_quant.py | 13 +- ...nvfp4_online_input_scale_update_weights.py | 211 ++++++++++++++++++ .../rl/test_update_weights_from_disk_mxfp8.py | 107 ++++++--- 5 files changed, 315 insertions(+), 102 deletions(-) create mode 100644 test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index f4add35b391d..6a637805f4ad 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -275,15 +275,15 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: w13_weight.size(0), # num_experts ) - # Set flashinfer parameters + # Set flashinfer parameters in-place + copy_or_rebind_param(layer, "w13_weight", gemm1_weights_fp4_shuffled.contiguous()) + copy_or_rebind_param(layer, "w2_weight", gemm2_weights_fp4_shuffled.contiguous()) copy_or_rebind_param( - layer, "gemm1_weights_fp4_shuffled", gemm1_weights_fp4_shuffled + layer, "w13_weight_scale", gemm1_scales_fp4_shuffled.contiguous() ) copy_or_rebind_param( - layer, "gemm2_weights_fp4_shuffled", gemm2_weights_fp4_shuffled + layer, "w2_weight_scale", gemm2_scales_fp4_shuffled.contiguous() ) - copy_or_rebind_param(layer, "gemm1_scales_fp4_shuffled", gemm1_scales_fp4_shuffled) - copy_or_rebind_param(layer, "gemm2_scales_fp4_shuffled", gemm2_scales_fp4_shuffled) # Compute additional scaling factor needed for TRT-LLM w2_input_scale_quant = cast(torch.Tensor, layer.w2_input_scale_quant) @@ -294,14 +294,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: (w2_input_scale_quant * g1_alphas).to(torch.float32), ) - # Clean up weights that won't be used by TRT-LLM - del ( - layer.w2_weight, - layer.w2_weight_scale, - layer.w13_weight, - layer.w13_weight_scale, - ) - @dataclass class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): @@ -560,11 +552,10 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( class FlashInferTrtllmFp4MoeQuantInfo(MoeQuantInfo): """Quantization payload consumed by FlashInfer TRT-LLM FP4 MoE kernels.""" - # Shuffled FP4 weights (processed by align_fp4_moe_weights_for_flashinfer_trtllm) - gemm1_weights_fp4_shuffled: torch.Tensor - gemm2_weights_fp4_shuffled: torch.Tensor - gemm1_scales_fp4_shuffled: torch.Tensor - gemm2_scales_fp4_shuffled: torch.Tensor + w13_weight: torch.Tensor + w2_weight: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor # Scaling factors g1_scale_c: torch.Tensor @@ -666,18 +657,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( routing_bias=None, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=quant_info.g1_scale_c, output1_scale_gate_scalar=quant_info.g1_alphas, @@ -716,18 +703,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=quant_info.g1_scale_c, output1_scale_gate_scalar=quant_info.g1_alphas, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..6b285809ba16 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -20,6 +20,7 @@ from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, + replace_parameter, swizzle_blockscale, ) from sglang.srt.utils import next_power_of_2, set_weight_attrs @@ -257,30 +258,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) logger.debug("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) + replace_parameter(layer, "w13_weight", gemm1_weights_fp4_shuffled) + replace_parameter(layer, "w2_weight", gemm2_weights_fp4_shuffled) + replace_parameter(layer, "w13_weight_scale", gemm1_scales_fp4_shuffled) + replace_parameter(layer, "w2_weight_scale", gemm2_scales_fp4_shuffled) # Additional parameter needed for TRT-LLM layer.g1_scale_c = torch.nn.Parameter( (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale else: # swizzle weight scales layer.w13_weight_scale = torch.nn.Parameter( @@ -370,18 +357,14 @@ def apply_weights( routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=layer.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c, output1_scale_gate_scalar=layer.g1_alphas, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 3115bf0a5b6b..814f9e602bb1 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1980,9 +1980,8 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config - # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when - # backend is flashinfer_trtllm - if hasattr(layer, "gemm1_weights_fp4_shuffled"): + # FlashInfer TRTLLM FP4 path + if self.enable_flashinfer_trtllm_moe: from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( FlashInferTrtllmFp4MoeQuantInfo, ) @@ -1994,10 +1993,10 @@ def apply( ) quant_info = FlashInferTrtllmFp4MoeQuantInfo( - gemm1_weights_fp4_shuffled=layer.gemm1_weights_fp4_shuffled.data, - gemm2_weights_fp4_shuffled=layer.gemm2_weights_fp4_shuffled.data, - gemm1_scales_fp4_shuffled=layer.gemm1_scales_fp4_shuffled.data, - gemm2_scales_fp4_shuffled=layer.gemm2_scales_fp4_shuffled.data, + w13_weight=layer.w13_weight.data, + w2_weight=layer.w2_weight.data, + w13_weight_scale=layer.w13_weight_scale.data, + w2_weight_scale=layer.w2_weight_scale.data, g1_scale_c=layer.g1_scale_c.data, g1_alphas=layer.g1_alphas.data, g2_alphas=layer.g2_alphas.data, diff --git a/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py b/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py new file mode 100644 index 000000000000..8ea1f95b0212 --- /dev/null +++ b/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py @@ -0,0 +1,211 @@ +import os +import unittest +from types import SimpleNamespace + +import requests +import torch + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=1800, suite="nightly-4-gpu-b200", nightly=True) + +QWEN3_NVFP4_MODEL = "nvidia/Qwen3-30B-A3B-NVFP4" +GSM8K_QUESTION_COUNT = 200 +GSM8K_NUM_SHOTS = 8 +GSM8K_MIN_ACCURACY = 0.88 +MAX_ALLOWED_ACCURACY_DROP = 0.05 + +CASE_MATRIX = [ + { + "name": "case1_trtllm_moe_trtllm_gemm_tp1", + "other_args": [ + "--quantization", + "modelopt_fp4", + "--moe-runner-backend", + "flashinfer_trtllm", + "--fp4-gemm-backend", + "flashinfer_trtllm", + "--tp", + "1", + "--moe-a2a-backend", + "none", + "--trust-remote-code", + ], + }, + { + "name": "case2_flashinfer_cutlass_moe_cutlass_gemm_tp1", + "other_args": [ + "--quantization", + "modelopt_fp4", + "--moe-runner-backend", + "flashinfer_cutlass", + "--fp4-gemm-backend", + "flashinfer_cutlass", + "--tp", + "1", + "--moe-a2a-backend", + "none", + "--trust-remote-code", + ], + }, + { + "name": "case3_cutlass_moe_cutlass_gemm_tp1", + "other_args": [ + "--quantization", + "modelopt_fp4", + "--moe-runner-backend", + "cutlass", + "--fp4-gemm-backend", + "flashinfer_cutlass", + "--tp", + "1", + "--moe-a2a-backend", + "none", + "--trust-remote-code", + ], + }, + { + "name": "case4_flashinfer_a2a_cutlass_moe_tp2_ep2", + "other_args": [ + "--quantization", + "modelopt_fp4", + "--moe-runner-backend", + "flashinfer_cutlass", + "--fp4-gemm-backend", + "flashinfer_cutlass", + "--moe-a2a-backend", + "flashinfer", + "--tp", + "2", + "--ep", + "2", + "--trust-remote-code", + ], + }, + { + "name": "case5_dp_allgather_flashinfer_cutlass_tp4_dp4_ep4", + "other_args": [ + "--quantization", + "modelopt_fp4", + "--moe-runner-backend", + "flashinfer_cutlass", + "--fp4-gemm-backend", + "flashinfer_cutlass", + "--moe-a2a-backend", + "none", + "--tp", + "4", + "--dp", + "4", + "--ep", + "4", + "--enable-dp-attention", + "--trust-remote-code", + ], + }, +] + + +@unittest.skipIf( + get_device_sm() < 100, "Test requires CUDA SM 100 or higher (Blackwell)" +) +@unittest.skipIf(torch.cuda.device_count() < 4, "Test requires at least 4 CUDA GPUs") +class TestQwen3Nvfp4OnlineInputScaleUpdateWeights(CustomTestCase): + model = QWEN3_NVFP4_MODEL + base_url = DEFAULT_URL_FOR_TEST + port = int(base_url.split(":")[-1]) + + def _run_gsm8k_eval(self, case_name): + args = SimpleNamespace( + num_shots=GSM8K_NUM_SHOTS, + data_path=None, + num_questions=GSM8K_QUESTION_COUNT, + max_new_tokens=512, + parallel=GSM8K_QUESTION_COUNT, + host="http://127.0.0.1", + port=self.port, + ) + metrics = run_eval(args) + print(f"{case_name=}, {metrics=}") + return metrics + + def _update_weights_from_disk(self, case_name): + response = requests.post( + self.base_url + "/update_weights_from_disk", + json={ + "model_path": self.model, + "flush_cache": True, + "abort_all_requests": False, + }, + timeout=300, + ) + self.assertEqual(response.status_code, 200, msg=f"{case_name}: {response.text}") + + result = response.json() + self.assertTrue(result.get("success"), msg=f"{case_name}: {result}") + + def _assert_model_path_is_expected(self, case_name): + response = requests.get(self.base_url + "/get_model_info", timeout=30) + self.assertEqual(response.status_code, 200, msg=f"{case_name}: {response.text}") + self.assertEqual( + response.json()["model_path"], + self.model, + msg=f"{case_name}: model mismatch", + ) + + def _run_single_case(self, case): + process = None + case_name = case["name"] + try: + process = popen_launch_server( + self.model, + self.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=case["other_args"], + env={ + **os.environ, + "SGLANG_NVFP4_ONLINE_INPUT_SCALE": "1", + }, + ) + + metrics_before = self._run_gsm8k_eval(case_name) + self.assertGreaterEqual( + metrics_before["accuracy"], + GSM8K_MIN_ACCURACY, + msg=f"{case_name}: before accuracy too low", + ) + + self._update_weights_from_disk(case_name) + self._assert_model_path_is_expected(case_name) + + metrics_after = self._run_gsm8k_eval(case_name) + self.assertGreaterEqual( + metrics_after["accuracy"], + GSM8K_MIN_ACCURACY, + msg=f"{case_name}: after accuracy too low", + ) + self.assertGreaterEqual( + metrics_after["accuracy"], + metrics_before["accuracy"] - MAX_ALLOWED_ACCURACY_DROP, + msg=f"{case_name}: post-update accuracy regressed too much", + ) + finally: + if process is not None: + kill_process_tree(process.pid) + + def test_update_weights_with_online_input_scale_matrix(self): + for case in CASE_MATRIX: + with self.subTest(case=case["name"]): + self._run_single_case(case) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/rl/test_update_weights_from_disk_mxfp8.py b/test/registered/rl/test_update_weights_from_disk_mxfp8.py index 12ac0758cdb6..fa00e43fbc0c 100644 --- a/test/registered/rl/test_update_weights_from_disk_mxfp8.py +++ b/test/registered/rl/test_update_weights_from_disk_mxfp8.py @@ -15,37 +15,43 @@ ) -class TestServerUpdateWeightsFromDiskMXFP8(CustomTestCase): - model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" +class UpdateWeightsFromDiskModelBase: + model = None base_url = DEFAULT_URL_FOR_TEST request_timeout = 120 update_timeout = 240 + launch_env = None decode_payload = { "text": "The capital of France is", "sampling_params": {"temperature": 0, "max_new_tokens": 16}, } - backend_test_suites = ( - { - "fp8_gemm_backend": "flashinfer_trtllm", - "moe_runner_backend": "flashinfer_trtllm_routed", - }, + backend_test_suites = () + update_test_suites = ( + {"flush_cache": True, "abort_all_requests": False}, + {"flush_cache": False, "abort_all_requests": False}, ) - def _launch_server(self, fp8_gemm_backend, moe_runner_backend): + @classmethod + def setUpClass(cls): + super().setUpClass() + if cls.model is None: + raise NotImplementedError("Subclass must set 'model' attribute") + if not cls.backend_test_suites: + raise NotImplementedError( + "Subclass must set non-empty 'backend_test_suites'" + ) + + def _launch_server(self, backend_test_suite): + launch_kwargs = {} + if self.launch_env is not None: + launch_kwargs["env"] = self.launch_env + other_args = backend_test_suite.get("other_args") return popen_launch_server( self.model, self.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--base-gpu-id", - "0", - "--tp-size", - "4", - "--fp8-gemm-backend", - fp8_gemm_backend, - "--moe-runner-backend", - moe_runner_backend, - ], + other_args=other_args, + **launch_kwargs, ) def _get_json(self, endpoint, timeout=None): @@ -119,30 +125,19 @@ def _run_update_weights( timeout=self.update_timeout, ) - def test_parameterized_update_weights_mxfp8(self): - update_test_suites = ( - {"flush_cache": True, "abort_all_requests": False}, - {"flush_cache": False, "abort_all_requests": False}, - ) + def test_parameterized_update_weights_from_disk(self): for backend_test_suite in self.backend_test_suites: - with self.subTest(**backend_test_suite): - process = self._launch_server( - backend_test_suite["fp8_gemm_backend"], - backend_test_suite["moe_runner_backend"], - ) + case_name = backend_test_suite.get("name", "default") + with self.subTest(model=self.model, case_name=case_name): + process = self._launch_server(backend_test_suite) try: origin_model_path = self._get_model_info() self.assertEqual(origin_model_path, self.model) self._assert_non_empty_decode() baseline_sig = self._get_decode_logprob_signature() - for update_test_suite in update_test_suites: - with self.subTest( - fp8_gemm_backend=backend_test_suite["fp8_gemm_backend"], - moe_runner_backend=backend_test_suite["moe_runner_backend"], - flush_cache=update_test_suite["flush_cache"], - abort_all_requests=update_test_suite["abort_all_requests"], - ): + for update_test_suite in self.update_test_suites: + with self.subTest(case_name=case_name, **update_test_suite): ret = self._run_update_weights( self.model, flush_cache=update_test_suite["flush_cache"], @@ -161,5 +156,47 @@ def test_parameterized_update_weights_mxfp8(self): kill_process_tree(process.pid) +class TestServerUpdateWeightsFromDiskMXFP8( + UpdateWeightsFromDiskModelBase, CustomTestCase +): + model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_routed_mxfp8", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp8-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm_routed", + ), + }, + ) + + +class TestServerUpdateWeightsFromDiskNVFP4( + UpdateWeightsFromDiskModelBase, CustomTestCase +): + model = "nvidia/Qwen3-30B-A3B-NVFP4" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_nvfp4", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp4-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm", + ), + }, + ) + + if __name__ == "__main__": unittest.main() From 0f642d72c5d1ddc4a20cd4b6351c361448feca80 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 6 Apr 2026 12:49:42 -0700 Subject: [PATCH 2/6] Rename test --- ...m_disk_mxfp8.py => test_update_weights_from_disk_blackwell.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/registered/rl/{test_update_weights_from_disk_mxfp8.py => test_update_weights_from_disk_blackwell.py} (100%) diff --git a/test/registered/rl/test_update_weights_from_disk_mxfp8.py b/test/registered/rl/test_update_weights_from_disk_blackwell.py similarity index 100% rename from test/registered/rl/test_update_weights_from_disk_mxfp8.py rename to test/registered/rl/test_update_weights_from_disk_blackwell.py From 5e54e8bd29e834bbe16dd8b94982a407c4487f67 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 6 Apr 2026 12:59:46 -0700 Subject: [PATCH 3/6] Minor fix --- python/sglang/srt/layers/quantization/modelopt_quant.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 814f9e602bb1..c24185bdf633 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1981,7 +1981,7 @@ def apply( moe_runner_config = self.moe_runner_config # FlashInfer TRTLLM FP4 path - if self.enable_flashinfer_trtllm_moe: + if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"): from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( FlashInferTrtllmFp4MoeQuantInfo, ) From 3a769122b3e6ec5e4bde0de2827e368c7101ab92 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 7 Apr 2026 02:11:11 -0700 Subject: [PATCH 4/6] Expand tests --- python/sglang/srt/server_args.py | 3 +- .../test_flashinfer_trtllm_gen_moe_backend.py | 55 +++++++++++++++++++ ...test_update_weights_from_disk_blackwell.py | 2 +- 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 93eba2b1e74f..d309f7819763 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2776,8 +2776,9 @@ def _handle_moe_kernel_config(self): assert self.quantization in [ "fp8", "mxfp8", + "modelopt_fp4", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', or bfloat16 (None)." + ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)." self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set." diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py index b63447a60cd5..9c7260b19966 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py @@ -157,6 +157,49 @@ def test_gsm8k(self): self.assertGreater(metrics["score"], 0.93) +class FlashinferTrtllmGenMoeBackendNVFP4Base: + backend = None + + @classmethod + def setUpClass(cls): + cls.model = "nvidia/Qwen3-30B-A3B-NVFP4" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--moe-runner-backend", + cls.backend, + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["score"], 0.89) + + class TestFlashinferTrtllmGenMoeBackendFP8( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -175,6 +218,12 @@ class TestFlashinferTrtllmGenMoeBackendBF16( backend = "flashinfer_trtllm" +class TestFlashinferTrtllmGenMoeBackendNVFP4( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + class TestFlashinferTrtllmGenMoeBackendFP8Routed( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -193,5 +242,11 @@ class TestFlashinferTrtllmGenMoeBackendBF16Routed( backend = "flashinfer_trtllm_routed" +class TestFlashinferTrtllmGenMoeBackendNVFP4Routed( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm_routed" + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/rl/test_update_weights_from_disk_blackwell.py b/test/registered/rl/test_update_weights_from_disk_blackwell.py index fa00e43fbc0c..d4a2b71dc1e0 100644 --- a/test/registered/rl/test_update_weights_from_disk_blackwell.py +++ b/test/registered/rl/test_update_weights_from_disk_blackwell.py @@ -192,7 +192,7 @@ class TestServerUpdateWeightsFromDiskNVFP4( "--fp4-gemm-backend", "flashinfer_trtllm", "--moe-runner-backend", - "flashinfer_trtllm", + "flashinfer_trtllm_routed", ), }, ) From d5d128a897c4e46a727a2db41e768c50e0f098af Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 7 Apr 2026 17:10:05 -0700 Subject: [PATCH 5/6] Drop file --- ...nvfp4_online_input_scale_update_weights.py | 211 ------------------ 1 file changed, 211 deletions(-) delete mode 100644 test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py diff --git a/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py b/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py deleted file mode 100644 index 8ea1f95b0212..000000000000 --- a/test/registered/backends/test_qwen3_nvfp4_online_input_scale_update_weights.py +++ /dev/null @@ -1,211 +0,0 @@ -import os -import unittest -from types import SimpleNamespace - -import requests -import torch - -from sglang.srt.utils import get_device_sm, kill_process_tree -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - -register_cuda_ci(est_time=1800, suite="nightly-4-gpu-b200", nightly=True) - -QWEN3_NVFP4_MODEL = "nvidia/Qwen3-30B-A3B-NVFP4" -GSM8K_QUESTION_COUNT = 200 -GSM8K_NUM_SHOTS = 8 -GSM8K_MIN_ACCURACY = 0.88 -MAX_ALLOWED_ACCURACY_DROP = 0.05 - -CASE_MATRIX = [ - { - "name": "case1_trtllm_moe_trtllm_gemm_tp1", - "other_args": [ - "--quantization", - "modelopt_fp4", - "--moe-runner-backend", - "flashinfer_trtllm", - "--fp4-gemm-backend", - "flashinfer_trtllm", - "--tp", - "1", - "--moe-a2a-backend", - "none", - "--trust-remote-code", - ], - }, - { - "name": "case2_flashinfer_cutlass_moe_cutlass_gemm_tp1", - "other_args": [ - "--quantization", - "modelopt_fp4", - "--moe-runner-backend", - "flashinfer_cutlass", - "--fp4-gemm-backend", - "flashinfer_cutlass", - "--tp", - "1", - "--moe-a2a-backend", - "none", - "--trust-remote-code", - ], - }, - { - "name": "case3_cutlass_moe_cutlass_gemm_tp1", - "other_args": [ - "--quantization", - "modelopt_fp4", - "--moe-runner-backend", - "cutlass", - "--fp4-gemm-backend", - "flashinfer_cutlass", - "--tp", - "1", - "--moe-a2a-backend", - "none", - "--trust-remote-code", - ], - }, - { - "name": "case4_flashinfer_a2a_cutlass_moe_tp2_ep2", - "other_args": [ - "--quantization", - "modelopt_fp4", - "--moe-runner-backend", - "flashinfer_cutlass", - "--fp4-gemm-backend", - "flashinfer_cutlass", - "--moe-a2a-backend", - "flashinfer", - "--tp", - "2", - "--ep", - "2", - "--trust-remote-code", - ], - }, - { - "name": "case5_dp_allgather_flashinfer_cutlass_tp4_dp4_ep4", - "other_args": [ - "--quantization", - "modelopt_fp4", - "--moe-runner-backend", - "flashinfer_cutlass", - "--fp4-gemm-backend", - "flashinfer_cutlass", - "--moe-a2a-backend", - "none", - "--tp", - "4", - "--dp", - "4", - "--ep", - "4", - "--enable-dp-attention", - "--trust-remote-code", - ], - }, -] - - -@unittest.skipIf( - get_device_sm() < 100, "Test requires CUDA SM 100 or higher (Blackwell)" -) -@unittest.skipIf(torch.cuda.device_count() < 4, "Test requires at least 4 CUDA GPUs") -class TestQwen3Nvfp4OnlineInputScaleUpdateWeights(CustomTestCase): - model = QWEN3_NVFP4_MODEL - base_url = DEFAULT_URL_FOR_TEST - port = int(base_url.split(":")[-1]) - - def _run_gsm8k_eval(self, case_name): - args = SimpleNamespace( - num_shots=GSM8K_NUM_SHOTS, - data_path=None, - num_questions=GSM8K_QUESTION_COUNT, - max_new_tokens=512, - parallel=GSM8K_QUESTION_COUNT, - host="http://127.0.0.1", - port=self.port, - ) - metrics = run_eval(args) - print(f"{case_name=}, {metrics=}") - return metrics - - def _update_weights_from_disk(self, case_name): - response = requests.post( - self.base_url + "/update_weights_from_disk", - json={ - "model_path": self.model, - "flush_cache": True, - "abort_all_requests": False, - }, - timeout=300, - ) - self.assertEqual(response.status_code, 200, msg=f"{case_name}: {response.text}") - - result = response.json() - self.assertTrue(result.get("success"), msg=f"{case_name}: {result}") - - def _assert_model_path_is_expected(self, case_name): - response = requests.get(self.base_url + "/get_model_info", timeout=30) - self.assertEqual(response.status_code, 200, msg=f"{case_name}: {response.text}") - self.assertEqual( - response.json()["model_path"], - self.model, - msg=f"{case_name}: model mismatch", - ) - - def _run_single_case(self, case): - process = None - case_name = case["name"] - try: - process = popen_launch_server( - self.model, - self.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=case["other_args"], - env={ - **os.environ, - "SGLANG_NVFP4_ONLINE_INPUT_SCALE": "1", - }, - ) - - metrics_before = self._run_gsm8k_eval(case_name) - self.assertGreaterEqual( - metrics_before["accuracy"], - GSM8K_MIN_ACCURACY, - msg=f"{case_name}: before accuracy too low", - ) - - self._update_weights_from_disk(case_name) - self._assert_model_path_is_expected(case_name) - - metrics_after = self._run_gsm8k_eval(case_name) - self.assertGreaterEqual( - metrics_after["accuracy"], - GSM8K_MIN_ACCURACY, - msg=f"{case_name}: after accuracy too low", - ) - self.assertGreaterEqual( - metrics_after["accuracy"], - metrics_before["accuracy"] - MAX_ALLOWED_ACCURACY_DROP, - msg=f"{case_name}: post-update accuracy regressed too much", - ) - finally: - if process is not None: - kill_process_tree(process.pid) - - def test_update_weights_with_online_input_scale_matrix(self): - for case in CASE_MATRIX: - with self.subTest(case=case["name"]): - self._run_single_case(case) - - -if __name__ == "__main__": - unittest.main() From cc6e152b4b484dc91d861b2c11edd5eaf527a809 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 7 Apr 2026 17:13:17 -0700 Subject: [PATCH 6/6] Minor rename --- .../rl/test_update_weights_from_disk_blackwell.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/test/registered/rl/test_update_weights_from_disk_blackwell.py b/test/registered/rl/test_update_weights_from_disk_blackwell.py index d4a2b71dc1e0..58d79c375699 100644 --- a/test/registered/rl/test_update_weights_from_disk_blackwell.py +++ b/test/registered/rl/test_update_weights_from_disk_blackwell.py @@ -15,7 +15,7 @@ ) -class UpdateWeightsFromDiskModelBase: +class UpdateWeightsFromDiskBase: model = None base_url = DEFAULT_URL_FOR_TEST request_timeout = 120 @@ -156,9 +156,7 @@ def test_parameterized_update_weights_from_disk(self): kill_process_tree(process.pid) -class TestServerUpdateWeightsFromDiskMXFP8( - UpdateWeightsFromDiskModelBase, CustomTestCase -): +class TestServerUpdateWeightsFromDiskMXFP8(UpdateWeightsFromDiskBase, CustomTestCase): model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" backend_test_suites = ( { @@ -177,9 +175,7 @@ class TestServerUpdateWeightsFromDiskMXFP8( ) -class TestServerUpdateWeightsFromDiskNVFP4( - UpdateWeightsFromDiskModelBase, CustomTestCase -): +class TestServerUpdateWeightsFromDiskNVFP4(UpdateWeightsFromDiskBase, CustomTestCase): model = "nvidia/Qwen3-30B-A3B-NVFP4" backend_test_suites = ( {