diff --git a/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py b/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py index c38126679638..04d1b112db4a 100644 --- a/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py +++ b/python/sglang/srt/layers/moe/flashinfer_trtllm_moe.py @@ -28,6 +28,7 @@ def _fake_fp8_block_scale_moe( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Optional[int] = None, + num_fused_shared_experts: Optional[int] = None, ) -> torch.Tensor: return torch.empty( hidden_states.shape, dtype=torch.bfloat16, device=hidden_states.device @@ -58,6 +59,7 @@ def trtllm_fp8_block_scale_moe_wrapper( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Optional[int] = None, + num_fused_shared_experts: Optional[int] = None, ) -> torch.Tensor: try: from flashinfer.fused_moe import trtllm_fp8_block_scale_moe @@ -93,7 +95,15 @@ def trtllm_fp8_block_scale_moe_wrapper( from flashinfer.fused_moe import Fp8QuantizationType kwargs["fp8_quantization_type"] = Fp8QuantizationType(fp8_quantization_type) + if num_fused_shared_experts is not None and num_fused_shared_experts > 0: + kwargs["num_fused_shared_experts"] = num_fused_shared_experts + from sglang.srt.utils import print_warning_once + + print_warning_once( + f"[trtllm_fp8_block_scale_moe] num_fused_shared_experts={kwargs.get('num_fused_shared_experts', 'NOT SET')}, " + f"top_k={kwargs['top_k']}, num_experts={kwargs['num_experts']}" + ) return trtllm_fp8_block_scale_moe(**kwargs) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 80735f3da2c7..0b343fdadc13 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1187,9 +1187,6 @@ def forward_impl(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert ( topk_output.topk_config.renormalize ), "Renormalize is required for flashinfer trtllm moe" - assert ( - self.num_fused_shared_experts == 0 - ), "Fused shared experts are not supported for flashinfer trtllm moe" assert ( self.moe_runner_config.is_gated ), "Only gated MoEs are supported for flashinfer trtllm moe" diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 7914265dd3b1..a6427faf18ba 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -419,6 +419,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( else: assert TopKOutputChecker.format_is_bypassed(topk_output) + _nfse = runner_config.num_fused_shared_experts or 0 output = _trtllm_fp8_block_scale_moe_wrapper( routing_logits=( router_logits.to(torch.float32) @@ -433,7 +434,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( gemm2_weights=quant_info.w2_weight, gemm2_weights_scale=quant_info.w2_weight_scale_inv, num_experts=quant_info.global_num_experts, - top_k=topk_config.top_k, + top_k=topk_config.top_k - _nfse, n_group=topk_config.num_expert_group, topk_group=topk_config.topk_group, intermediate_size=quant_info.intermediate_size, @@ -448,6 +449,7 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( use_shuffled_weight=use_shuffled_weight, tune_max_num_tokens=next_power_of_2(a_q.shape[0]), fp8_quantization_type=int(fp8_quantization_type), + num_fused_shared_experts=_nfse if _nfse > 0 else None, ) symm_output.copy_(output) output = symm_output diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 98716292297e..ad726f831c5c 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1588,8 +1588,9 @@ def apply( # router logits directly (no separate apply_with_router_logits needed). # FlashInfer TRT-LLM routed backend consumes SGLang-computed # top-k ids/weights (packed into int32) instead of router logits. - global_num_experts = int(getattr(layer, "num_experts")) - num_local_experts = int(getattr(layer, "num_local_experts")) + _nfse = int(getattr(layer, "num_fused_shared_experts", 0)) + global_num_experts = int(getattr(layer, "num_experts")) - _nfse + num_local_experts = int(getattr(layer, "num_local_experts")) - _nfse moe_ep_rank = int(getattr(layer, "moe_ep_rank")) quant_info = FlashInferTrtllmFp8MoeQuantInfo( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0d88f541a1dd..33af8b2ca741 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2152,6 +2152,8 @@ def determine_num_fused_shared_experts( ) elif get_moe_expert_parallel_world_size() > 1 and ( not _is_hip or torch.cuda.get_device_capability("cuda") < (9, 4) + ) and not ( + _is_cuda and get_moe_runner_backend().is_flashinfer_trtllm() ): disable_reason = "Only Deepseek V3/R1 on AMD-platform with capability >= gfx942(MI30x) can use shared experts fusion optimization under expert parallelism." elif disable_reason is None and ( diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a0f0704f3469..532430321041 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2630,10 +2630,14 @@ def _handle_moe_kernel_config(self): "compressed-tensors", None, ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', 'modelopt_mixed', 'compressed-tensors', or bfloat16 (None)." - self.disable_shared_experts_fusion = True - logger.warning( - "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set." - ) + # FP8 block-scale (fp8, mxfp8) supports fused shared experts via + # num_fused_shared_experts in trtllm_fp8_block_scale_moe; other + # quant types (BF16, FP4, per-tensor FP8) do not. + if self.quantization not in ["fp8", "mxfp8"]: + self.disable_shared_experts_fusion = True + logger.warning( + "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set." + ) if self.moe_runner_backend == "flashinfer_trtllm_routed": assert self.quantization in [