diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index f4add35b391d..6a637805f4ad 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -275,15 +275,15 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: w13_weight.size(0), # num_experts ) - # Set flashinfer parameters + # Set flashinfer parameters in-place + copy_or_rebind_param(layer, "w13_weight", gemm1_weights_fp4_shuffled.contiguous()) + copy_or_rebind_param(layer, "w2_weight", gemm2_weights_fp4_shuffled.contiguous()) copy_or_rebind_param( - layer, "gemm1_weights_fp4_shuffled", gemm1_weights_fp4_shuffled + layer, "w13_weight_scale", gemm1_scales_fp4_shuffled.contiguous() ) copy_or_rebind_param( - layer, "gemm2_weights_fp4_shuffled", gemm2_weights_fp4_shuffled + layer, "w2_weight_scale", gemm2_scales_fp4_shuffled.contiguous() ) - copy_or_rebind_param(layer, "gemm1_scales_fp4_shuffled", gemm1_scales_fp4_shuffled) - copy_or_rebind_param(layer, "gemm2_scales_fp4_shuffled", gemm2_scales_fp4_shuffled) # Compute additional scaling factor needed for TRT-LLM w2_input_scale_quant = cast(torch.Tensor, layer.w2_input_scale_quant) @@ -294,14 +294,6 @@ def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: (w2_input_scale_quant * g1_alphas).to(torch.float32), ) - # Clean up weights that won't be used by TRT-LLM - del ( - layer.w2_weight, - layer.w2_weight_scale, - layer.w13_weight, - layer.w13_weight_scale, - ) - @dataclass class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): @@ -560,11 +552,10 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( class FlashInferTrtllmFp4MoeQuantInfo(MoeQuantInfo): """Quantization payload consumed by FlashInfer TRT-LLM FP4 MoE kernels.""" - # Shuffled FP4 weights (processed by align_fp4_moe_weights_for_flashinfer_trtllm) - gemm1_weights_fp4_shuffled: torch.Tensor - gemm2_weights_fp4_shuffled: torch.Tensor - gemm1_scales_fp4_shuffled: torch.Tensor - gemm2_scales_fp4_shuffled: torch.Tensor + w13_weight: torch.Tensor + w2_weight: torch.Tensor + w13_weight_scale: torch.Tensor + w2_weight_scale: torch.Tensor # Scaling factors g1_scale_c: torch.Tensor @@ -666,18 +657,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( routing_bias=None, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=quant_info.g1_scale_c, output1_scale_gate_scalar=quant_info.g1_alphas, @@ -716,18 +703,14 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=quant_info.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=quant_info.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=quant_info.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=quant_info.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=quant_info.g1_scale_c, output1_scale_gate_scalar=quant_info.g1_alphas, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py index 5898a078dbba..6b285809ba16 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py @@ -20,6 +20,7 @@ from sglang.srt.layers.quantization.utils import ( prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, + replace_parameter, swizzle_blockscale, ) from sglang.srt.utils import next_power_of_2, set_weight_attrs @@ -257,30 +258,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: ) logger.debug("Finished shuffling weights for TRT-LLM MOE") - layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter( - gemm1_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter( - gemm2_weights_fp4_shuffled, requires_grad=False - ) - layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter( - gemm1_scales_fp4_shuffled, requires_grad=False - ) - layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter( - gemm2_scales_fp4_shuffled, requires_grad=False - ) + replace_parameter(layer, "w13_weight", gemm1_weights_fp4_shuffled) + replace_parameter(layer, "w2_weight", gemm2_weights_fp4_shuffled) + replace_parameter(layer, "w13_weight_scale", gemm1_scales_fp4_shuffled) + replace_parameter(layer, "w2_weight_scale", gemm2_scales_fp4_shuffled) # Additional parameter needed for TRT-LLM layer.g1_scale_c = torch.nn.Parameter( (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) - - # Clean up weights that won't be used by TRT-LLM - del layer.w2_weight - del layer.w2_weight_scale - del layer.w13_weight - del layer.w13_weight_scale else: # swizzle weight scales layer.w13_weight_scale = torch.nn.Parameter( @@ -370,18 +357,14 @@ def apply_weights( routing_bias=correction_bias, hidden_states=hs_fp4, hidden_states_scale=hs_scale, - gemm1_weights=layer.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm1_weights=layer.w13_weight, + gemm1_weights_scale=layer.w13_weight_scale.view(torch.float8_e4m3fn), gemm1_bias=None, gemm1_alpha=None, gemm1_beta=None, gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), + gemm2_weights=layer.w2_weight, + gemm2_weights_scale=layer.w2_weight_scale.view(torch.float8_e4m3fn), gemm2_bias=None, output1_scale_scalar=layer.g1_scale_c, output1_scale_gate_scalar=layer.g1_alphas, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 3115bf0a5b6b..c24185bdf633 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1980,9 +1980,8 @@ def apply( ), f"{activation=} missing from {ACT_STR_TO_TYPE_MAP.keys()=}" moe_runner_config = self.moe_runner_config - # FlashInfer TRTLLM FP4 path - layer has shuffled weights only when - # backend is flashinfer_trtllm - if hasattr(layer, "gemm1_weights_fp4_shuffled"): + # FlashInfer TRTLLM FP4 path + if self.enable_flashinfer_trtllm_moe and hasattr(layer, "g1_scale_c"): from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( FlashInferTrtllmFp4MoeQuantInfo, ) @@ -1994,10 +1993,10 @@ def apply( ) quant_info = FlashInferTrtllmFp4MoeQuantInfo( - gemm1_weights_fp4_shuffled=layer.gemm1_weights_fp4_shuffled.data, - gemm2_weights_fp4_shuffled=layer.gemm2_weights_fp4_shuffled.data, - gemm1_scales_fp4_shuffled=layer.gemm1_scales_fp4_shuffled.data, - gemm2_scales_fp4_shuffled=layer.gemm2_scales_fp4_shuffled.data, + w13_weight=layer.w13_weight.data, + w2_weight=layer.w2_weight.data, + w13_weight_scale=layer.w13_weight_scale.data, + w2_weight_scale=layer.w2_weight_scale.data, g1_scale_c=layer.g1_scale_c.data, g1_alphas=layer.g1_alphas.data, g2_alphas=layer.g2_alphas.data, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 31fb5b5353da..037890443dd2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2780,8 +2780,9 @@ def _handle_moe_kernel_config(self): assert self.quantization in [ "fp8", "mxfp8", + "modelopt_fp4", None, - ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', or bfloat16 (None)." + ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8', 'mxfp8', 'modelopt_fp4', or bfloat16 (None)." self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set." diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py index b63447a60cd5..9c7260b19966 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py @@ -157,6 +157,49 @@ def test_gsm8k(self): self.assertGreater(metrics["score"], 0.93) +class FlashinferTrtllmGenMoeBackendNVFP4Base: + backend = None + + @classmethod + def setUpClass(cls): + cls.model = "nvidia/Qwen3-30B-A3B-NVFP4" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--moe-runner-backend", + cls.backend, + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="gsm8k", + api="completion", + max_tokens=512, + num_examples=200, + num_threads=128, + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["score"], 0.89) + + class TestFlashinferTrtllmGenMoeBackendFP8( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -175,6 +218,12 @@ class TestFlashinferTrtllmGenMoeBackendBF16( backend = "flashinfer_trtllm" +class TestFlashinferTrtllmGenMoeBackendNVFP4( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + class TestFlashinferTrtllmGenMoeBackendFP8Routed( FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase ): @@ -193,5 +242,11 @@ class TestFlashinferTrtllmGenMoeBackendBF16Routed( backend = "flashinfer_trtllm_routed" +class TestFlashinferTrtllmGenMoeBackendNVFP4Routed( + FlashinferTrtllmGenMoeBackendNVFP4Base, CustomTestCase +): + backend = "flashinfer_trtllm_routed" + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/rl/test_update_weights_from_disk_mxfp8.py b/test/registered/rl/test_update_weights_from_disk_blackwell.py similarity index 69% rename from test/registered/rl/test_update_weights_from_disk_mxfp8.py rename to test/registered/rl/test_update_weights_from_disk_blackwell.py index 12ac0758cdb6..58d79c375699 100644 --- a/test/registered/rl/test_update_weights_from_disk_mxfp8.py +++ b/test/registered/rl/test_update_weights_from_disk_blackwell.py @@ -15,37 +15,43 @@ ) -class TestServerUpdateWeightsFromDiskMXFP8(CustomTestCase): - model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" +class UpdateWeightsFromDiskBase: + model = None base_url = DEFAULT_URL_FOR_TEST request_timeout = 120 update_timeout = 240 + launch_env = None decode_payload = { "text": "The capital of France is", "sampling_params": {"temperature": 0, "max_new_tokens": 16}, } - backend_test_suites = ( - { - "fp8_gemm_backend": "flashinfer_trtllm", - "moe_runner_backend": "flashinfer_trtllm_routed", - }, + backend_test_suites = () + update_test_suites = ( + {"flush_cache": True, "abort_all_requests": False}, + {"flush_cache": False, "abort_all_requests": False}, ) - def _launch_server(self, fp8_gemm_backend, moe_runner_backend): + @classmethod + def setUpClass(cls): + super().setUpClass() + if cls.model is None: + raise NotImplementedError("Subclass must set 'model' attribute") + if not cls.backend_test_suites: + raise NotImplementedError( + "Subclass must set non-empty 'backend_test_suites'" + ) + + def _launch_server(self, backend_test_suite): + launch_kwargs = {} + if self.launch_env is not None: + launch_kwargs["env"] = self.launch_env + other_args = backend_test_suite.get("other_args") return popen_launch_server( self.model, self.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--base-gpu-id", - "0", - "--tp-size", - "4", - "--fp8-gemm-backend", - fp8_gemm_backend, - "--moe-runner-backend", - moe_runner_backend, - ], + other_args=other_args, + **launch_kwargs, ) def _get_json(self, endpoint, timeout=None): @@ -119,30 +125,19 @@ def _run_update_weights( timeout=self.update_timeout, ) - def test_parameterized_update_weights_mxfp8(self): - update_test_suites = ( - {"flush_cache": True, "abort_all_requests": False}, - {"flush_cache": False, "abort_all_requests": False}, - ) + def test_parameterized_update_weights_from_disk(self): for backend_test_suite in self.backend_test_suites: - with self.subTest(**backend_test_suite): - process = self._launch_server( - backend_test_suite["fp8_gemm_backend"], - backend_test_suite["moe_runner_backend"], - ) + case_name = backend_test_suite.get("name", "default") + with self.subTest(model=self.model, case_name=case_name): + process = self._launch_server(backend_test_suite) try: origin_model_path = self._get_model_info() self.assertEqual(origin_model_path, self.model) self._assert_non_empty_decode() baseline_sig = self._get_decode_logprob_signature() - for update_test_suite in update_test_suites: - with self.subTest( - fp8_gemm_backend=backend_test_suite["fp8_gemm_backend"], - moe_runner_backend=backend_test_suite["moe_runner_backend"], - flush_cache=update_test_suite["flush_cache"], - abort_all_requests=update_test_suite["abort_all_requests"], - ): + for update_test_suite in self.update_test_suites: + with self.subTest(case_name=case_name, **update_test_suite): ret = self._run_update_weights( self.model, flush_cache=update_test_suite["flush_cache"], @@ -161,5 +156,43 @@ def test_parameterized_update_weights_mxfp8(self): kill_process_tree(process.pid) +class TestServerUpdateWeightsFromDiskMXFP8(UpdateWeightsFromDiskBase, CustomTestCase): + model = "zianglih/Qwen3-30B-A3B-Instruct-2507-MXFP8-last-8-BF16" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_routed_mxfp8", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp8-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm_routed", + ), + }, + ) + + +class TestServerUpdateWeightsFromDiskNVFP4(UpdateWeightsFromDiskBase, CustomTestCase): + model = "nvidia/Qwen3-30B-A3B-NVFP4" + backend_test_suites = ( + { + "name": "flashinfer_trtllm_nvfp4", + "other_args": ( + "--base-gpu-id", + "0", + "--tp-size", + "4", + "--fp4-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + "flashinfer_trtllm_routed", + ), + }, + ) + + if __name__ == "__main__": unittest.main()