diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 176311eaf284..1b876b3a9c26 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -461,123 +461,114 @@ def apply( dispatch_output: StandardDispatchOutput, ) -> CombineInput: - from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput x = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids - - output = cutlass_moe_fp4( - a=x, - a1_gscale=layer.w13_input_scale_quant, - w1_fp4=layer.w13_weight, - w1_blockscale=layer.w13_weight_scale, - w1_alphas=layer.g1_alphas, - a2_gscale=layer.w2_input_scale_quant, - w2_fp4=layer.w2_weight, - w2_blockscale=layer.w2_weight_scale, - w2_alphas=layer.g2_alphas, - topk_weights=topk_weights, - topk_ids=topk_ids, - params=layer.cutlass_moe_params, - apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, - ).to(x.dtype) - return StandardCombineInput(hidden_states=output) - - def apply_with_router_logits( - self, - layer: torch.nn.Module, - dispatch_output: StandardDispatchOutput, - ) -> torch.Tensor: - assert self.use_flashinfer_trtllm - - x = dispatch_output.hidden_states - topk_output = dispatch_output.topk_output - - from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe + if self.use_flashinfer_trtllm: + from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe - from sglang.srt.layers.moe.utils import RoutingMethodType + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config + # Quantize input hidden states using fp4_quantize + hs_fp4_bytes, hs_sf_bytes = fp4_quantize( + x, + layer.w13_input_scale_quant, + self.group_size, # sf_vec_size + False, # use_ue8m0 + False, # is_sf_swizzled_layout + ) + hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2) + hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1) - # Quantize input hidden states using fp4_quantize - hs_fp4_bytes, hs_sf_bytes = fp4_quantize( - x, - layer.w13_input_scale_quant, - self.group_size, # sf_vec_size - False, # use_ue8m0 - False, # is_sf_swizzled_layout - ) - hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2) - hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1) + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(x.dtype) + ) - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(x.dtype) - ) + assert layer.routing_method_type is not None - assert layer.routing_method_type is not None + # DeepSeekV3 style routing requires float32 router logits + if layer.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) - # DeepSeekV3 style routing requires float32 router logits - if layer.routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + routed_scaling_factor = ( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ) - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - routed_scaling_factor = ( - routed_scaling_factor if routed_scaling_factor is not None else 1.0 - ) + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = hs_fp4.shape[0] + hidden_size = ( + hs_fp4.shape[-1] * 2 + if hs_fp4.dtype == torch.uint8 + else hs_fp4.shape[-1] + ) + symm_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device + ) - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - num_tokens = hs_fp4.shape[0] - hidden_size = ( - hs_fp4.shape[-1] * 2 - if hs_fp4.dtype == torch.uint8 - else hs_fp4.shape[-1] - ) - symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device - ) + output = trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=correction_bias, + hidden_states=hs_fp4, + hidden_states_scale=hs_scale, + gemm1_weights=layer.gemm1_weights_fp4_shuffled, + gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=layer.gemm2_weights_fp4_shuffled, + gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( + torch.float8_e4m3fn + ), + gemm2_bias=None, + output1_scale_scalar=layer.g1_scale_c, + output1_scale_gate_scalar=layer.g1_alphas, + output2_scale_scalar=layer.g2_alphas, + num_experts=layer.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=layer.intermediate_size_per_partition, + local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, + local_num_experts=layer.num_local_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=layer.routing_method_type, + do_finalize=True, + tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), + output=symm_output, + )[0] + else: + from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + output = cutlass_moe_fp4( + a=x, + a1_gscale=layer.w13_input_scale_quant, + w1_fp4=layer.w13_weight, + w1_blockscale=layer.w13_weight_scale, + w1_alphas=layer.g1_alphas, + a2_gscale=layer.w2_input_scale_quant, + w2_fp4=layer.w2_weight, + w2_blockscale=layer.w2_weight_scale, + w2_alphas=layer.g2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + params=layer.cutlass_moe_params, + apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input, + ).to(x.dtype) - return trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale, - gemm1_weights=layer.gemm1_weights_fp4_shuffled, - gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.gemm2_weights_fp4_shuffled, - gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c, - output1_scale_gate_scalar=layer.g1_alphas, - output2_scale_scalar=layer.g2_alphas, - num_experts=layer.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, - local_num_experts=layer.num_local_experts, - routed_scaling_factor=routed_scaling_factor, - routing_method_type=layer.routing_method_type, - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] + return StandardCombineInput(hidden_states=output) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): diff --git a/test/registered/8-gpu-models/test_mistral_large3.py b/test/registered/8-gpu-models/test_mistral_large3.py index 3892399823a4..c010d1f629d6 100644 --- a/test/registered/8-gpu-models/test_mistral_large3.py +++ b/test/registered/8-gpu-models/test_mistral_large3.py @@ -9,9 +9,10 @@ # Runs on both H200 and B200 via nightly-8-gpu-common suite # Note: trtllm_mla backend may have hardware-specific behavior -register_cuda_ci(est_time=1800, suite="nightly-8-gpu-common", nightly=True) +register_cuda_ci(est_time=3000, suite="nightly-8-gpu-common", nightly=True) -MISTRAL_LARGE3_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512" +MISTRAL_LARGE3_FP8_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512" +MISTRAL_LARGE3_NVFP4_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4" MISTRAL_LARGE3_EAGLE_MODEL_PATH = "mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle" @@ -19,9 +20,10 @@ class TestMistralLarge3(unittest.TestCase): """Unified test class for Mistral-Large-3 performance and accuracy. - Two variants: - - basic: TP=8 + trtllm_mla backend + Three variants: + - basic: FP8 model + TP=8 + trtllm_mla backend - eagle: basic + EAGLE speculative decoding with draft model + - nvfp4: NVFP4 model + TP=8 + trtllm_mla backend Each variant runs BOTH: - Performance test (using NightlyBenchmarkRunner) @@ -56,22 +58,33 @@ def test_mistral_large3_all_variants(self): "--speculative-num-draft-tokens=4", "--kv-cache-dtype=auto", ] + # TODO: add this to base args when FP8 TRTLLM moe is supported + nvfp4_args = [ + "--moe-runner-backend=flashinfer_trtllm", + ] variants = [ - # Variant: "basic" - TP=8 + trtllm_mla backend + # Variant: "basic" - FP8 model + TP=8 + trtllm_mla backend ModelLaunchSettings( - MISTRAL_LARGE3_MODEL_PATH, + MISTRAL_LARGE3_FP8_MODEL_PATH, tp_size=8, extra_args=base_args, variant="TP8", ), - # Variant: "eagle" - TP=8 + trtllm_mla + EAGLE with draft model + # Variant: "eagle" - FP8 model + TP=8 + trtllm_mla + EAGLE with draft model ModelLaunchSettings( - MISTRAL_LARGE3_MODEL_PATH, + MISTRAL_LARGE3_FP8_MODEL_PATH, tp_size=8, extra_args=base_args + eagle_args, variant="TP8+MTP", ), + # Variant: "nvfp4" - NVFP4 model + TP=8 + trtllm_mla backend + ModelLaunchSettings( + MISTRAL_LARGE3_NVFP4_MODEL_PATH, + tp_size=8, + extra_args=base_args + nvfp4_args, + variant="NVFP4", + ), ] run_combined_tests(