diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index ef2225eccd14..1359c8313905 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -795,24 +795,4 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): if get_moe_a2a_backend().is_ascend_fuseep(): return NpuFuseEPMoE - if get_moe_runner_backend().is_flashinfer_trtllm(): - # NEW: Direct FP4 detection (bypasses EP requirements) - # Check for FP4 quantization with TRTLLM flag, regardless of EP - # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod. - if quant_config is not None and quant_config.get_name() == "modelopt_fp4": - from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE - - return FlashInferFP4MoE - elif ( - quant_config is None - or quant_config.get_name() == "fp8" - or quant_config.get_name() == "mxfp8" - or quant_config.get_name() == "modelopt_fp8" - or quant_config.get_name() == "compressed_tensors" - ): - # FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors - return FusedMoE - - if get_moe_runner_backend().is_flashinfer_cutlass(): - return FusedMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 43ad055a2716..f3fc5a544f99 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -40,7 +40,6 @@ from sglang.srt.layers.moe.token_dispatcher.flashinfer import FlashinferDispatcher from sglang.srt.layers.moe.token_dispatcher.standard import ( StandardDispatcher, - StandardDispatchOutput, ) from sglang.srt.layers.moe.topk import ( BypassedTopKOutput, @@ -66,16 +65,11 @@ cpu_has_amx_support, get_bool_env_var, is_cpu, - is_flashinfer_available, is_hip, - next_power_of_2, round_up, ) from sglang.srt.utils.custom_op import register_custom_op -if is_flashinfer_available(): - from flashinfer import fp4_quantize - _is_hip = is_hip() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() @@ -1146,267 +1140,6 @@ def clear_overlap_args(self) -> None: self.meta_overlap_args = None -class FlashInferFusedMoE(FusedMoE): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert TopKOutputChecker.format_is_bypassed( - topk_output - ), "Only bypassed topk output is supported for flashinfer trtllm moe" - - if is_in_piecewise_cuda_graph(): - return flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl( - hidden_states, - topk_output.router_logits, - topk_output.topk_config.top_k, - topk_output.topk_config.topk_group, - topk_output.topk_config.num_expert_group, - topk_output.topk_config.correction_bias, - topk_output.topk_config.renormalize, - self.layer_id, - ) - else: - return self.forward_impl(hidden_states, topk_output) - - def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert ( - self.moe_runner_config.activation == "silu" - ), "Only silu is supported for flashinfer trtllm moe" - assert self.quant_method is not None - assert ( - topk_output.topk_config.renormalize - ), "Renormalize is required for flashinfer trtllm moe" - assert ( - self.num_fused_shared_experts == 0 - ), "Fused shared experts are not supported for flashinfer trtllm moe" - assert ( - self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer trtllm moe" - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - correction_bias = topk_config.correction_bias - routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - - if isinstance(self.quant_method, UnquantizedFusedMoEMethod): - # lazy import - try: - from flashinfer.fused_moe import trtllm_bf16_moe - except ImportError as e: - raise ImportError( - "Can't import trtllm_bf16_moe from flashinfer. " - "Please check flashinfer version to use bf16 with flashinfer_trtllm backend." - ) from e - - # Allocate output inside symmetric memory context - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - # TODO: Now trtllm_bf16_moe doesn't support inplace output, - # we can move this out when it support that. - symm_output = torch.empty( - hidden_states.shape[0], - hidden_states.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - - # Move kernel call outside context manager to avoid graph breaks - # during torch.compile for piecewise cuda graph - moe_result = trtllm_bf16_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hidden_states, - gemm1_weights=self.w13_weight, - gemm2_weights=self.w2_weight, - num_experts=self.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.moe_ep_rank * self.num_local_experts, - local_num_experts=self.num_local_experts, - routing_method_type=self.routing_method_type, - tune_max_num_tokens=next_power_of_2(hidden_states.shape[0]), - ) - # Copy result to symmetric memory output - symm_output.copy_(moe_result) - final_hidden_states = symm_output - - else: - - final_hidden_states = self.quant_method.apply( - layer=self, - dispatch_output=StandardDispatchOutput( - hidden_states=hidden_states, - hidden_states_scale=None, - topk_output=topk_output, - ), - ).hidden_states - - # NOTE for symmetric memory tagging: - # We do not create the context in this function. - # Instead, we create the context and tagging inside each FusedMoEMethodBase - # This can allow fine-grained tagging. - - if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states - - -class FlashInferFP4MoE(FusedMoE): - """FP4 TRTLLM MoE implementation using FlashInfer.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # --------------------------------------------------------------------- - # Helper: quantize hidden states to FP4 each forward pass - # --------------------------------------------------------------------- - def _quantize_hidden_states_fp4(self, hidden_states: torch.Tensor): - """ - Quantize hidden states using global scale factor from quantization method. - - Global scale factor is set by ModelOptNvFp4FusedMoEMethod during weight loading. - Only block scales are computed at runtime for efficiency. - - Returns (packed_fp4_uint8, scale_float8_e4m3fn_runtime, global_scale_float32) - """ - - # flashinfer.fp4_quantize returns (packed_uint8, scale_fp8) - # Only the block scales are computed at runtime - hs_fp4_bytes, hs_sf_bytes = fp4_quantize( - hidden_states, - self.w13_input_scale_quant, - 16, # sf_vec_size - False, # use_ue8m0 - False, # is_sf_swizzled_layout - ) - - seq_len, hidden_size = hidden_states.shape - hs_fp4 = hs_fp4_bytes.reshape(seq_len, hidden_size // 2) - # TRT-LLM expects hidden state scales shaped as [seq_len, hidden_size // 16] - hs_sf = hs_sf_bytes.view(torch.float8_e4m3fn).reshape( - seq_len, hidden_size // 16 - ) - - return hs_fp4, hs_sf - - def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - assert TopKOutputChecker.format_is_bypassed( - topk_output - ), "Only bypassed topk output is supported for flashinfer fp4 moe" - - if is_in_piecewise_cuda_graph(): - return flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl( - hidden_states, - topk_output.router_logits, - topk_output.topk_config.top_k, - topk_output.topk_config.topk_group, - topk_output.topk_config.num_expert_group, - topk_output.topk_config.correction_bias, - self.layer_id, - ) - else: - return self.forward_impl(hidden_states, topk_output) - - def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): - """Forward pass using FP4 TRTLLM kernel. - - Args: - hidden_states: Input tensor - topk_output: TopKOutput object with Bypassed format - """ - from flashinfer.fused_moe import trtllm_fp4_block_scale_moe - - assert isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) - - assert ( - self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer fp4 moe" - - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - - hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) - routing_method_type = self.routing_method_type - assert ( - routing_method_type is not None - ), "flashinfer trtllm moe nvfp4 backend has not been adapted for the current moe layer, you can set routing_method_type (See definition of RoutingMethodType please) for the moe layer explicitly for a quick adaptation." - - # DeepSeekV3 style routing requires float32 router logits, - # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6 - if routing_method_type == RoutingMethodType.DeepSeekV3: - router_logits = router_logits.to(torch.float32) - - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) - ) - - with use_symmetric_memory( - get_tp_group(), disabled=not is_allocation_symmetric() - ): - num_tokens = hs_fp4.shape[0] - hidden_size = ( - hs_fp4.shape[-1] * 2 - if hs_fp4.dtype == torch.uint8 - else hs_fp4.shape[-1] - ) - symm_output = torch.empty( - num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device - ) - result = trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=correction_bias, - hidden_states=hs_fp4, - hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).reshape( - *hs_scale_linear.shape[:-1], -1 - ), - gemm1_weights=self.gemm1_weights_fp4_shuffled.data, - gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=self.gemm2_weights_fp4_shuffled.data, - gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view( - torch.float8_e4m3fn - ), - gemm2_bias=None, - output1_scale_scalar=self.g1_scale_c.data, - output1_scale_gate_scalar=self.g1_alphas.data, - output2_scale_scalar=self.g2_alphas.data, - num_experts=self.num_experts, - top_k=topk_config.top_k, - n_group=topk_config.num_expert_group, - topk_group=topk_config.topk_group, - intermediate_size=self.intermediate_size_per_partition, - local_expert_offset=self.moe_ep_rank * self.num_local_experts, - local_num_experts=self.num_local_experts, - routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, - # Respect the routing method configured for this layer (e.g., Renormalize for Qwen3), - # instead of always assuming DeepSeekV3. - routing_method_type=( - self.routing_method_type - if self.routing_method_type is not None - else RoutingMethodType.Default - ), - do_finalize=True, - tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]), - output=symm_output, - )[0] - - return result - - @register_custom_op(out_shape="hidden_states") def moe_forward_piecewise_cuda_graph_impl( hidden_states: torch.Tensor, @@ -1449,55 +1182,3 @@ def fused_moe_bypassed_piecewise_cuda_graph_impl( forward_context = get_forward_context() moe_layer = forward_context.moe_layers[layer_id] return moe_layer.forward_impl(hidden_states, topk_output) - - -@register_custom_op(out_shape="hidden_states") -def flashinfer_bf16_moe_forward_piecewise_cuda_graph_impl( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - topk_group: Optional[int], - num_expert_group: Optional[int], - correction_bias: Optional[torch.Tensor], - renormalize: bool, - layer_id: int, -) -> torch.Tensor: - topk_output = BypassedTopKOutput( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=TopKConfig( - top_k=top_k, - topk_group=topk_group, - num_expert_group=num_expert_group, - correction_bias=correction_bias, - renormalize=renormalize, - ), - ) - forward_context = get_forward_context() - moe_layer = forward_context.moe_layers[layer_id] - return moe_layer.forward_impl(hidden_states, topk_output) - - -@register_custom_op(out_shape="hidden_states") -def flashinfer_fp4_moe_forward_piecewise_cuda_graph_impl( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - topk_group: Optional[int], - num_expert_group: Optional[int], - correction_bias: Optional[torch.Tensor], - layer_id: int, -) -> torch.Tensor: - topk_output = BypassedTopKOutput( - hidden_states=hidden_states, - router_logits=router_logits, - topk_config=TopKConfig( - top_k=top_k, - topk_group=topk_group, - num_expert_group=num_expert_group, - correction_bias=correction_bias, - ), - ) - forward_context = get_forward_context() - moe_layer = forward_context.moe_layers[layer_id] - return moe_layer.forward_impl(hidden_states, topk_output) diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 68decf875ba7..5f95bf20f425 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -31,7 +31,6 @@ from sglang.srt.utils.common import ( is_cuda_alike, is_flashinfer_available, - is_sm120_supported, next_power_of_2, ) @@ -41,7 +40,7 @@ StandardDispatchOutput, ) -if is_flashinfer_available() and is_sm120_supported(): +if is_flashinfer_available(): from flashinfer import fp4_quantize elif is_cuda_alike(): from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize @@ -564,7 +563,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( """FlashInfer TRTLLM FP4 MoE forward pass. This function handles the FP4 TRTLLM MoE path that was previously in - FlashInferFP4MoE.forward_impl and ModelOptNvFp4FusedMoEMethod.apply. + ModelOptNvFp4FusedMoEMethod.apply. """ from flashinfer.fused_moe import trtllm_fp4_block_scale_moe @@ -638,7 +637,6 @@ def fused_experts_none_to_flashinfer_trtllm_fp4( local_expert_offset=quant_info.local_expert_offset, local_num_experts=quant_info.local_num_experts, routed_scaling_factor=runner_config.routed_scaling_factor, - tile_tokens_dim=None, routing_method_type=( routing_method_type if routing_method_type is not None diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 0bd87934625c..c0d9958e45ee 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1273,7 +1273,7 @@ def get_quant_method(self, layer: torch.nn.Module, prefix: str): layer, prefix, Linear=ModelOptFp4LinearMethod, - Moe=ModelOptNvFp4FusedMoEMethod, # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling + Moe=ModelOptNvFp4FusedMoEMethod, )