diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 53c832e1d1bd..5f77356da749 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -536,27 +536,22 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): if get_moe_a2a_backend().is_ascend_fuseep(): return NpuFuseEPMoE - # NEW: Direct FP4 detection (bypasses EP requirements) - # Check for FP4 quantization with TRTLLM flag, regardless of EP if get_moe_runner_backend().is_flashinfer_trtllm(): + # NEW: Direct FP4 detection (bypasses EP requirements) + # Check for FP4 quantization with TRTLLM flag, regardless of EP # FlashInferFP4MoE must be paired with ModelOptNvFp4FusedMoEMethod. - # If UnquantizedFusedMoEMethod is detected, fall back to FusedMoE instead. - if quant_config is None: - return FusedMoE - try: - # Check the quantization argument directly - if quant_config is not None and quant_config.get_name() == "modelopt_fp4": - from sglang.srt.layers.moe.fused_moe_triton.layer import ( - FlashInferFP4MoE, - ) - - return FlashInferFP4MoE - except: - pass + if quant_config is not None and quant_config.get_name() == "modelopt_fp4": + from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE + + return FlashInferFP4MoE + elif ( + quant_config is None + or quant_config.get_name() == "fp8" + or quant_config.get_name() == "modelopt_fp8" + ): + # FlashInferFusedMoE supports bf16 and fp8 + return FlashInferFusedMoE - if get_moe_runner_backend().is_flashinfer_trtllm() and quant_config is not None: - # FIXME: FlashInferFusedMoE only supports fp8 quant now - return FlashInferFusedMoE if get_moe_runner_backend().is_flashinfer_cutlass(): return FusedMoE return FusedMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 255e619c1fd6..3563d3e3d412 100644 ---
a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -192,6 +192,7 @@ def __init__( self.use_presharded_weights = use_presharded_weights self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels() + self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm() self.quant_config = quant_config self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4() @@ -236,7 +237,9 @@ def __init__( if quant_config is not None: self.quant_method = quant_config.get_quant_method(self, prefix) if self.quant_method is None: - self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels) + self.quant_method = UnquantizedFusedMoEMethod( + self.use_triton_kernels, self.use_flashinfer_trtllm_moe + ) self.quant_method.create_weights( layer=self, @@ -640,9 +643,10 @@ def _weight_loader_impl( raise ValueError(f"shard_id must be ['w1','w2','w3'] but got {shard_id}.") # Flashinfer assumes w31 format for w13_weight. Same for the scales. 
- if get_moe_runner_backend().is_flashinfer_trtllm() and ( + if self.use_flashinfer_trtllm_moe and ( isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or isinstance(self.quant_method, Fp8MoEMethod) + or isinstance(self.quant_method, UnquantizedFusedMoEMethod) ): shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id] @@ -1036,29 +1040,66 @@ def __init__(self, *args, **kwargs): def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert ( self.moe_runner_config.activation == "silu" - ), "Only silu is supported for flashinfer blockscale fp8 moe" + ), "Only silu is supported for flashinfer trtllm moe" assert self.quant_method is not None assert ( topk_output.topk_config.renormalize - ), "Renormalize is required for flashinfer blockscale fp8 moe" + ), "Renormalize is required for flashinfer trtllm moe" assert ( self.num_fused_shared_experts == 0 - ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe" + ), "Fused shared experts are not supported for flashinfer trtllm moe" assert ( self.moe_runner_config.is_gated - ), "Only gated MoEs are supported for flashinfer blockscale fp8 moe" + ), "Only gated MoEs are supported for flashinfer trtllm moe" assert TopKOutputChecker.format_is_bypassed(topk_output) - # Matrix multiply. - final_hidden_states = self.quant_method.apply_with_router_logits( - layer=self, - dispatch_output=StandardDispatchOutput( - hidden_states=hidden_states, - hidden_states_scale=None, - topk_output=topk_output, - ), - ) + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = topk_config.correction_bias + + if isinstance(self.quant_method, UnquantizedFusedMoEMethod): + # lazy import + try: + from flashinfer.fused_moe import trtllm_bf16_moe + except ImportError as e: + raise ImportError( + "Can't import trtllm_bf16_moe from flashinfer. " + "Please check flashinfer version to use bf16 with flashinfer_trtllm backend." 
+ ) from e + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + # TODO: Now trtllm_bf16_moe doesn't support inplace output, + # we can move this out when it supports that. + final_hidden_states = trtllm_bf16_moe( + routing_logits=router_logits, + routing_bias=correction_bias, + hidden_states=hidden_states, + gemm1_weights=self.w13_weight, + gemm2_weights=self.w2_weight, + num_experts=self.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.moe_ep_rank * self.num_local_experts, + local_num_experts=self.num_local_experts, + routing_method_type=self.routing_method_type, + ) + + else: + + # FP8 Matrix multiply. + final_hidden_states = self.quant_method.apply_with_router_logits( + layer=self, + dispatch_output=StandardDispatchOutput( + hidden_states=hidden_states, + hidden_states_scale=None, + topk_output=topk_output, + ), + ) # NOTE for symmetric memory tagging: # We do not create the context in this function.
diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 67c65d5f3664..630b600687b4 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -146,11 +146,15 @@ def apply( class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" - def __init__(self, use_triton_kernels: bool = False): + def __init__( + self, use_triton_kernels: bool = False, use_flashinfer_trtllm_moe: bool = False + ): super().__init__() self.use_flashinfer_cutlass = get_moe_runner_backend().is_flashinfer_cutlass() self.use_triton_kernels = use_triton_kernels self.with_bias = False + self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe + self._cache_permute_indices = dict({}) def create_weights( self, @@ -227,6 +231,71 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _is_cpu and _is_cpu_amx_available: _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + # Reorder rows of W1 for fused gated activation + if self.use_flashinfer_trtllm_moe: + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + convert_to_block_layout, + get_w2_permute_indices_with_cache, + ) + + # w1 and w3 have been swapped, so we don't need to do that here + epilogue_tile_m = 128 + block_k = 128 + old_shape_w13 = layer.w13_weight.data[0].shape + old_shape_w2 = layer.w2_weight.data[0].shape + new_shape_w13 = None + new_shape_w2 = None + for i in range(layer.num_local_experts): + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + layer.w13_weight.data[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights1 = ( + layer.w13_weight.data[i] + .clone() + .view(torch.uint8)[permute_indices.to(layer.w13_weight.data.device)] + .contiguous() + ) + + permute_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, +
layer.w2_weight.data[i].view(torch.uint8), + epilogue_tile_m, + ) + tmp_weights2 = ( + layer.w2_weight.data[i] + .clone() + .view(torch.uint8)[permute_indices.to(layer.w2_weight.data.device)] + .contiguous() + ) + + tmp_weights1 = convert_to_block_layout( + tmp_weights1.view(torch.uint8), block_k + ) + tmp_weights2 = convert_to_block_layout( + tmp_weights2.view(torch.uint8), block_k + ) + + new_shape_w13 = tmp_weights1.view(torch.bfloat16).shape + new_shape_w2 = tmp_weights2.view(torch.bfloat16).shape + layer.w13_weight.data[i] = ( + tmp_weights1.view(torch.bfloat16) + .contiguous() + .reshape(old_shape_w13) + ) + layer.w2_weight.data[i] = ( + tmp_weights2.view(torch.bfloat16).contiguous().reshape(old_shape_w2) + ) + + layer.w13_weight.data = layer.w13_weight.data.reshape( + layer.num_local_experts, *new_shape_w13 + ) + layer.w2_weight.data = layer.w2_weight.data.reshape( + layer.num_local_experts, *new_shape_w2 + ) + return def create_moe_runner( diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 2b1c3c04d6cf..24477e69f026 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -679,8 +679,13 @@ def __init__( apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk, fused_shared_experts_scaling_factor=fused_shared_experts_scaling_factor, # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized - # and requires the output format to be standard. We use quant_config to determine the output format. - output_format=TopKOutputFormat.STANDARD if quant_config is None else None, + # and requires the output format to be standard (except trtllm). We use quant_config to determine the output format. 
+ output_format=( + TopKOutputFormat.STANDARD + if (quant_config is None) + and (not get_moe_runner_backend().is_flashinfer_trtllm()) + else None + ), ) self.shared_experts_is_int8 = False diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a0a943f6dd05..203728bb07cc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1591,11 +1591,12 @@ def _handle_moe_kernel_config(self): ], "The expert parallel size must be 1 or the same as the tensor parallel size" if self.moe_runner_backend == "flashinfer_trtllm": - assert ( - self.quantization == "modelopt_fp4" - or self.quantization == "modelopt_fp8" - or self.quantization == "fp8" - ), "modelopt_fp4, modelopt_fp8 or fp8 quantization is required for Flashinfer TRTLLM MoE" + assert self.quantization in [ + "modelopt_fp4", + "fp8", + "modelopt_fp8", + None, + ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM MOE supports only: 'modelopt_fp4', 'fp8', 'modelopt_fp8', or bfloat16 (None)." self.disable_shared_experts_fusion = True logger.warning( "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set." 
diff --git a/test/nightly/test_flashinfer_trtllm_gen_moe_backend.py b/test/nightly/test_flashinfer_trtllm_gen_moe_backend.py index 6f29dbc83684..63db2b2ad1cc 100644 --- a/test/nightly/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/nightly/test_flashinfer_trtllm_gen_moe_backend.py @@ -15,7 +15,7 @@ register_cuda_ci(est_time=300, suite="nightly-4-gpu-b200", nightly=True) -class TestFlashinferTrtllmGenMoeBackend(CustomTestCase): +class TestFlashinferTrtllmGenMoeBackendFP8(CustomTestCase): @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" @@ -60,5 +60,51 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.93) +class TestFlashinferTrtllmGenMoeBackendBF16(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "triton", + "--moe-runner-backend", + "flashinfer_trtllm", + "--cuda-graph-max-bs", + "512", + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + "--mamba-ssm-dtype", + "bfloat16", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.93) + + if __name__ == "__main__": unittest.main()