diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 70cfeb49aedc..f0257764383f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -68,6 +68,7 @@ def __init__( prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, + **kwargs, ): super().__init__( num_experts=num_experts, @@ -81,6 +82,7 @@ def __init__( prefix=prefix, activation=activation, routed_scaling_factor=routed_scaling_factor, + **kwargs, ) if _use_aiter or _is_npu: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 543bd877bd5d..1dae8be11066 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -36,6 +36,7 @@ StandardDispatchOutput, ) from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker +from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -56,7 +57,7 @@ ) if is_flashinfer_available(): - from flashinfer import RoutingMethodType, fp4_quantize + from flashinfer import fp4_quantize # Try to import FP4 TRTLLM function if flashinfer is available trtllm_fp4_block_scale_moe = None @@ -144,6 +145,7 @@ def __init__( gemm1_clamp_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, + routing_method_type: Optional[RoutingMethodType] = None, ): super().__init__() if params_dtype is None: @@ -244,6 +246,8 @@ def __init__( and get_moe_runner_backend().is_cutlass() ) + self.routing_method_type = routing_method_type + def _load_per_tensor_weight_scale( self, shard_id: str, diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 3e902847d0cc..a7e9ce26fd47 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ 
b/python/sglang/srt/layers/moe/utils.py @@ -2,7 +2,7 @@ import logging from contextlib import contextmanager -from enum import Enum +from enum import Enum, IntEnum from functools import lru_cache from typing import TYPE_CHECKING, Optional @@ -248,3 +248,22 @@ def speculative_moe_backend_context(): yield finally: MOE_RUNNER_BACKEND = original_backend + + +# The type of method in top-K routing, for use in torch custom op +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # Qwen3: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK only (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6 diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 37f15ae2c4c3..e29706b84ca7 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1193,6 +1193,7 @@ def apply_with_router_logits( from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker + from sglang.srt.layers.moe.utils import RoutingMethodType assert TopKOutputChecker.format_is_bypassed(topk_output) router_logits = topk_output.router_logits @@ -1204,26 +1205,30 @@ def apply_with_router_logits( # NOTE: scales of hidden states have to be transposed! 
a_sf_t = a_sf.t().contiguous() - assert ( - topk_config.num_expert_group is not None - and topk_config.topk_group is not None - ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None" - correction_bias = ( None if topk_config.correction_bias is None else topk_config.correction_bias.to(x.dtype) ) + routing_method_type = getattr( + layer, "routing_method_type", RoutingMethodType.DeepSeekV3 + ) + with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. # It ignores the ``output`` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 # so we put the whole function under the ``use_symmetric_memory`` context manager. # If the bug is fixed, we can put only the output tensor allocation under the context manager. return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32), + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), routing_bias=correction_bias, hidden_states=a_q, hidden_states_scale=a_sf_t, @@ -1244,7 +1249,7 @@ def apply_with_router_logits( tile_tokens_dim=get_tile_tokens_dim( x.shape[0], topk_config.top_k, layer.num_experts ), - routing_method_type=2, # DeepSeek-styled routing method + routing_method_type=routing_method_type, use_shuffled_weight=False, ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 051095e61433..32a002682f84 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -57,6 +57,7 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import RoutingMethodType from 
sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -162,6 +163,7 @@ def __init__( intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), + routing_method_type=RoutingMethodType.RenormalizeNaive, ) self.gate = ReplicatedLinear( diff --git a/test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py b/test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py new file mode 100644 index 000000000000..890f8fe9727c --- /dev/null +++ b/test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py @@ -0,0 +1,65 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestFlashinferTrtllmGenMoeBackend(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--attention-backend", + "triton", + "--moe-runner-backend", + "flashinfer_trtllm", + "--cuda-graph-max-bs", + "512", + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + "--mamba-ssm-dtype", + "bfloat16", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + 
print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.93) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e64760f58c5d..71aef66c4622 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -219,6 +219,7 @@ class TestFile: TestFile("test_deepseek_v3_deterministic.py", 240), ], "nightly-4-gpu-b200": [ + TestFile("nightly/test_flashinfer_trtllm_gen_moe_backend.py", 300), TestFile("test_fp4_moe.py", 300), TestFile("nightly/test_gpt_oss_4gpu_perf.py", 600), TestFile("nightly/test_flashinfer_trtllm_gen_attn_backend.py", 300),