From 1dd327e28b84efbc1d39ca7066f584d964faf9bc Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 3 Nov 2025 08:05:44 +0000 Subject: [PATCH 1/6] Enable flashinfer-trtllm-gen-moe fp8 blockwise backend for models like Qwen3-Next --- python/sglang/srt/layers/moe/ep_moe/layer.py | 2 ++ .../srt/layers/moe/fused_moe_triton/layer.py | 3 +++ python/sglang/srt/layers/quantization/fp8.py | 14 +++++++------- python/sglang/srt/models/qwen2_moe.py | 6 +++++- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 81b91e50d136..63278eb6b161 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -69,6 +69,7 @@ def __init__( prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, + **kwargs, ): super().__init__( num_experts=num_experts, @@ -82,6 +83,7 @@ def __init__( prefix=prefix, activation=activation, routed_scaling_factor=routed_scaling_factor, + **kwargs, ) if _use_aiter or _is_npu: diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 5b527a386e07..78771de6fd86 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -142,6 +142,7 @@ def __init__( gemm1_clamp_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, + routing_method_type: Optional[RoutingMethodType] = None, ): super().__init__() if params_dtype is None: @@ -235,6 +236,8 @@ def __init__( and self.quant_method._should_use_cutlass_fused_experts() ) + self.routing_method_type = routing_method_type + def _load_per_tensor_weight_scale( self, shard_id: str, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 91b54e1257ea..9487c27a3125 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1187,6 +1187,8 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + + from flashinfer import RoutingMethodType from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker @@ -1201,19 +1203,17 @@ def apply_with_router_logits( # NOTE: scales of hidden states have to be transposed! a_sf_t = a_sf.t().contiguous() - assert ( - topk_config.num_expert_group is not None - and topk_config.topk_group is not None - ), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None" - correction_bias = ( None if topk_config.correction_bias is None else topk_config.correction_bias.to(x.dtype) ) + routing_method_type = getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) + + return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32), + routing_logits=router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits, routing_bias=correction_bias, hidden_states=a_q, hidden_states_scale=a_sf_t, @@ -1234,7 +1234,7 @@ def apply_with_router_logits( tile_tokens_dim=get_tile_tokens_dim( x.shape[0], topk_config.top_k, layer.num_experts ), - routing_method_type=2, # DeepSeek-styled routing method + routing_method_type=routing_method_type, use_shuffled_weight=False, ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 051095e61433..5ae3812ad128 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -70,7 +70,10 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import model_forward_maybe_tbo -from sglang.srt.utils import add_prefix, is_cuda, make_layers +from sglang.srt.utils import add_prefix, is_cuda, make_layers, is_flashinfer_available + +if is_flashinfer_available(): + from flashinfer import RoutingMethodType logger = logging.getLogger(__name__) @@ -162,6 +165,7 @@ def __init__( intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), + routing_method_type=RoutingMethodType.RenormalizeNaive if is_flashinfer_available() else None, ) self.gate = ReplicatedLinear( From b5a0207402be929ccc08fa9a5e319ad3d8c8c255 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 3 Nov 2025 00:12:10 -0800 Subject: [PATCH 2/6] reformat --- python/sglang/srt/layers/quantization/fp8.py | 16 ++++++++++------ python/sglang/srt/models/qwen2_moe.py | 8 ++++++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 1f638ed9f8d6..8d0bb863fc79 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1186,7 +1186,6 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer import RoutingMethodType from flashinfer.fused_moe import trtllm_fp8_block_scale_moe @@ -1208,11 +1207,16 @@ def apply_with_router_logits( else topk_config.correction_bias.to(x.dtype) ) - routing_method_type = getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) - - + routing_method_type = getattr( + layer, "routing_method_type", RoutingMethodType.DeepSeekV3 + ) + return trtllm_fp8_block_scale_moe( - routing_logits=router_logits.to(torch.float32) if routing_method_type == RoutingMethodType.DeepSeekV3 else router_logits, + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), routing_bias=correction_bias, hidden_states=a_q, hidden_states_scale=a_sf_t, @@ -1233,7 +1237,7 @@ def apply_with_router_logits( tile_tokens_dim=get_tile_tokens_dim( x.shape[0], topk_config.top_k, layer.num_experts ), - routing_method_type=routing_method_type, + routing_method_type=routing_method_type, use_shuffled_weight=False, ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 5ae3812ad128..83bd7e76b5f6 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -70,7 +70,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import model_forward_maybe_tbo -from sglang.srt.utils import add_prefix, is_cuda, make_layers, is_flashinfer_available +from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers if is_flashinfer_available(): from flashinfer import RoutingMethodType @@ -165,7 +165,11 @@ def __init__( intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), - routing_method_type=RoutingMethodType.RenormalizeNaive if is_flashinfer_available() else None, + routing_method_type=( + RoutingMethodType.RenormalizeNaive + if is_flashinfer_available() + else None + ), ) self.gate = ReplicatedLinear( From 836dc27a61ac48eeced8fca00bd65f41a7839fb9 Mon Sep 17 00:00:00 2001 From: Sam Date: Sat, 8 Nov 2025 13:15:38 +0000 Subject: [PATCH 3/6] Add unit test for trtllm-gen-moe --- test/srt/run_suite.py | 1 + .../test_flashinfer_trtllm_gen_moe_backend.py | 65 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 test/srt/test_flashinfer_trtllm_gen_moe_backend.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1e9df7340997..562c94283492 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -178,6 +178,7 @@ class TestFile: TestFile("test_flash_attention_4.py", 300), TestFile("test_gpt_oss_4gpu.py", 600), TestFile("test_llama31_fp4.py", 300), + TestFile("test_flashinfer_trtllm_gen_moe_backend.py", 300), ], "per-commit-4-gpu-gb200": [ TestFile("test_cutedsl_moe.py", 300), diff --git a/test/srt/test_flashinfer_trtllm_gen_moe_backend.py b/test/srt/test_flashinfer_trtllm_gen_moe_backend.py new file mode 100644 index 000000000000..890f8fe9727c --- /dev/null +++ b/test/srt/test_flashinfer_trtllm_gen_moe_backend.py @@ -0,0 +1,65 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestFlashinferTrtllmGenMoeBackend(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--attention-backend", + "triton", + "--moe-runner-backend", + "flashinfer_trtllm", + "--cuda-graph-max-bs", + "512", + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + "--mamba-ssm-dtype", + "bfloat16", + "--quantization", + "fp8", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.93) + + +if __name__ == "__main__": + unittest.main() From 4e84d10beccdc9a8846735ef395f7fd107c2c994 Mon Sep 17 00:00:00 2001 From: Sam Date: Sun, 9 Nov 2025 02:03:20 -0800 Subject: [PATCH 4/6] enable test case --- test/srt/run_suite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 1f51d0d4c8a3..848da0ad7356 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -180,7 +180,6 @@ class TestFile: TestFile("test_flash_attention_4.py", 300), TestFile("test_gpt_oss_4gpu.py", 600), TestFile("test_llama31_fp4.py", 300), - TestFile("test_flashinfer_trtllm_gen_moe_backend.py", 300), ], "per-commit-4-gpu-gb200": [ TestFile("test_cutedsl_moe.py", 300), @@ -220,6 +219,7 @@ class TestFile: TestFile("test_deepseek_v3_deterministic.py", 240), ], "nightly-4-gpu-b200": [ + TestFile("test_flashinfer_trtllm_gen_moe_backend.py", 300), TestFile("test_fp4_moe.py", 300), TestFile("nightly/test_gpt_oss_4gpu_perf.py", 600), ], From 7d875fc1cf002f3b5142edc598013a18296b956a Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 10 Nov 2025 09:01:07 +0000 Subject: [PATCH 5/6] Refactor RoutingMethodType to avoid import error --- .../srt/layers/moe/fused_moe_triton/layer.py | 3 ++- python/sglang/srt/layers/moe/utils.py | 21 ++++++++++++++++++- python/sglang/srt/layers/quantization/fp8.py | 2 +- python/sglang/srt/models/qwen2_moe.py | 12 +++-------- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index fcd6631ed0a6..b38820449248 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -32,6 +32,7 @@ StandardDispatchOutput, ) from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker +from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -56,7 +57,7 @@ ) if is_flashinfer_available(): - from flashinfer import RoutingMethodType, fp4_quantize + from flashinfer import fp4_quantize # Try to import FP4 TRTLLM function if flashinfer is available trtllm_fp4_block_scale_moe = None diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 3e902847d0cc..a7e9ce26fd47 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -2,7 +2,7 @@ import logging from contextlib import contextmanager -from enum import Enum +from enum import Enum, IntEnum from functools import lru_cache from typing import TYPE_CHECKING, Optional @@ -248,3 +248,22 @@ def speculative_moe_backend_context(): yield finally: MOE_RUNNER_BACKEND = original_backend + + +# The type of method in top-K routing, for use in torch custom op +# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h +class RoutingMethodType(IntEnum): + # Default: Softmax -> TopK + Default = (0,) + # Renormalize: TopK -> Softmax + Renormalize = (1,) + # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups + DeepSeekV3 = (2,) + # Llama4: Top1 -> Sigmoid + Llama4 = (3,) + # Qwen3: Softmax -> TopK -> Renormalize + RenormalizeNaive = (4,) + # TopK only (no softmax) + TopK = (5,) + # Unspecified + Unspecified = 6 diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 426d41d54c89..e29706b84ca7 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1190,10 +1190,10 @@ def apply_with_router_logits( activation = self.moe_runner_config.activation routed_scaling_factor = self.moe_runner_config.routed_scaling_factor - from flashinfer import RoutingMethodType from flashinfer.fused_moe import trtllm_fp8_block_scale_moe from sglang.srt.layers.moe.topk import TopKOutputChecker + from sglang.srt.layers.moe.utils import RoutingMethodType assert TopKOutputChecker.format_is_bypassed(topk_output) router_logits = topk_output.router_logits diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 83bd7e76b5f6..32a002682f84 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -57,6 +57,7 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -70,10 +71,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.server_args import get_global_server_args from sglang.srt.two_batch_overlap import model_forward_maybe_tbo -from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers - -if is_flashinfer_available(): - from flashinfer import RoutingMethodType +from sglang.srt.utils import add_prefix, is_cuda, make_layers logger = logging.getLogger(__name__) @@ -165,11 +163,7 @@ def __init__( intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), - routing_method_type=( - RoutingMethodType.RenormalizeNaive - if is_flashinfer_available() - else None - ), + routing_method_type=RoutingMethodType.RenormalizeNaive, ) self.gate = ReplicatedLinear( From 82f96f1313e5c34b9cb5f812b5b14a90c75d2f9b Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 11 Nov 2025 04:15:24 +0000 Subject: [PATCH 6/6] move unit test into nightly dir --- .../srt/{ => nightly}/test_flashinfer_trtllm_gen_moe_backend.py | 0 test/srt/run_suite.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename test/srt/{ => nightly}/test_flashinfer_trtllm_gen_moe_backend.py (100%) diff --git a/test/srt/test_flashinfer_trtllm_gen_moe_backend.py b/test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py similarity index 100% rename from test/srt/test_flashinfer_trtllm_gen_moe_backend.py rename to test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index ad8d35afb33c..7e3f474e0d1f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -219,7 +219,7 @@ class TestFile: TestFile("test_deepseek_v3_deterministic.py", 240), ], "nightly-4-gpu-b200": [ - TestFile("test_flashinfer_trtllm_gen_moe_backend.py", 300), + TestFile("nightly/test_flashinfer_trtllm_gen_moe_backend.py", 300), TestFile("test_fp4_moe.py", 300), TestFile("nightly/test_gpt_oss_4gpu_perf.py", 600), ],