Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
**kwargs,
):
super().__init__(
num_experts=num_experts,
Expand All @@ -81,6 +82,7 @@ def __init__(
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
)

if _use_aiter or _is_npu:
Expand Down
6 changes: 5 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
StandardDispatchOutput,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
Expand All @@ -56,7 +57,7 @@
)

if is_flashinfer_available():
from flashinfer import RoutingMethodType, fp4_quantize
from flashinfer import fp4_quantize

# Try to import FP4 TRTLLM function if flashinfer is available
trtllm_fp4_block_scale_moe = None
Expand Down Expand Up @@ -144,6 +145,7 @@ def __init__(
gemm1_clamp_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
routing_method_type: Optional[RoutingMethodType] = None,
):
super().__init__()
if params_dtype is None:
Expand Down Expand Up @@ -244,6 +246,8 @@ def __init__(
and get_moe_runner_backend().is_cutlass()
)

self.routing_method_type = routing_method_type

def _load_per_tensor_weight_scale(
self,
shard_id: str,
Expand Down
21 changes: 20 additions & 1 deletion python/sglang/srt/layers/moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from contextlib import contextmanager
from enum import Enum
from enum import Enum, IntEnum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional

Expand Down Expand Up @@ -248,3 +248,22 @@ def speculative_moe_backend_context():
yield
finally:
MOE_RUNNER_BACKEND = original_backend


# The type of method in top-K routing, for use in torch custom op
# Please keep this in sync with the counterpart defined in https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/trtllm/fused_moe/runner.h
class RoutingMethodType(IntEnum):
    """Top-K routing method identifiers, mirroring flashinfer's runner.h enum.

    Values are plain ints (note: the previous ``X = (0,)`` trailing-comma form
    made each value a tuple literal; Enum unpacked it back to an int, but the
    form was a footgun and inconsistent with ``Unspecified``).
    """

    # Default: Softmax -> TopK
    Default = 0
    # Renormalize: TopK -> Softmax
    Renormalize = 1
    # DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the Top4 groups
    DeepSeekV3 = 2
    # Llama4: Top1 -> Sigmoid
    Llama4 = 3
    # Qwen3: Softmax -> TopK -> Renormalize
    RenormalizeNaive = 4
    # TopK only (no softmax)
    TopK = 5
    # Unspecified
    Unspecified = 6
19 changes: 12 additions & 7 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,6 +1193,7 @@ def apply_with_router_logits(
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe

from sglang.srt.layers.moe.topk import TopKOutputChecker
from sglang.srt.layers.moe.utils import RoutingMethodType

assert TopKOutputChecker.format_is_bypassed(topk_output)
router_logits = topk_output.router_logits
Expand All @@ -1204,26 +1205,30 @@ def apply_with_router_logits(
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()

assert (
topk_config.num_expert_group is not None
and topk_config.topk_group is not None
), "Current trtllm_fp8_block_scale_moe kernel does not support these two arguments as None"

correction_bias = (
None
if topk_config.correction_bias is None
else topk_config.correction_bias.to(x.dtype)
)

routing_method_type = getattr(
layer, "routing_method_type", RoutingMethodType.DeepSeekV3
)
Comment on lines +1214 to +1216
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for getting routing_method_type can result in None being passed to the trtllm_fp8_block_scale_moe kernel. If a model using FlashInferFusedMoE does not specify routing_method_type, it defaults to None in FusedMoE.__init__. In this case, getattr(layer, "routing_method_type", ...) will return None.

The kernel previously used a hardcoded value and likely does not handle None, which could lead to a runtime error. To make this more robust and ensure backward compatibility, it's better to explicitly check for None and fall back to RoutingMethodType.DeepSeekV3.

Suggested change
routing_method_type = getattr(
layer, "routing_method_type", RoutingMethodType.DeepSeekV3
)
routing_method_type = getattr(layer, "routing_method_type", None)
if routing_method_type is None:
routing_method_type = RoutingMethodType.DeepSeekV3


with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):

# FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
# It ignores the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
# so we put the whole function under the ``use_symmetric_memory`` context manager.
# If the bug is fixed, we can put only the output tensor allocation under the context manager.
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
routing_logits=(
router_logits.to(torch.float32)
if routing_method_type == RoutingMethodType.DeepSeekV3
else router_logits
),
routing_bias=correction_bias,
hidden_states=a_q,
hidden_states_scale=a_sf_t,
Expand All @@ -1244,7 +1249,7 @@ def apply_with_router_logits(
tile_tokens_dim=get_tile_tokens_dim(
x.shape[0], topk_config.top_k, layer.num_experts
),
routing_method_type=2, # DeepSeek-styled routing method
routing_method_type=routing_method_type,
use_shuffled_weight=False,
)

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import RoutingMethodType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
Expand Down Expand Up @@ -162,6 +163,7 @@ def __init__(
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
routing_method_type=RoutingMethodType.RenormalizeNaive,
)

self.gate = ReplicatedLinear(
Expand Down
65 changes: 65 additions & 0 deletions test/srt/nightly/test_flashinfer_trtllm_gen_moe_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestFlashinferTrtllmGenMoeBackend(CustomTestCase):
    """Nightly end-to-end check of the flashinfer_trtllm MoE runner backend.

    Launches a 4-GPU TP/EP server for an FP8 Qwen3-Next MoE model and verifies
    GSM8K few-shot accuracy stays above threshold.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
        cls.base_url = DEFAULT_URL_FOR_TEST
        # Disable DeepGEMM JIT so the flashinfer_trtllm path is exercised.
        server_env = dict(os.environ)
        server_env["SGLANG_ENABLE_JIT_DEEPGEMM"] = "False"
        launch_flags = [
            "--attention-backend",
            "triton",
            "--moe-runner-backend",
            "flashinfer_trtllm",
            "--cuda-graph-max-bs",
            "512",
            "--tp-size",
            "4",
            "--ep-size",
            "4",
            "--mem-fraction-static",
            "0.7",
            "--mamba-ssm-dtype",
            "bfloat16",
            "--quantization",
            "fp8",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            env=server_env,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        # Tear down the server and any workers it spawned.
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        # Few-shot GSM8K eval against the locally launched server.
        eval_port = int(self.base_url.rsplit(":", 1)[-1])
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            host="http://127.0.0.1",
            port=eval_port,
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.93)


# Allow running this nightly test file directly as a script.
if __name__ == "__main__":
    unittest.main()
1 change: 1 addition & 0 deletions test/srt/run_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class TestFile:
TestFile("test_deepseek_v3_deterministic.py", 240),
],
"nightly-4-gpu-b200": [
TestFile("nightly/test_flashinfer_trtllm_gen_moe_backend.py", 300),
TestFile("test_fp4_moe.py", 300),
TestFile("nightly/test_gpt_oss_4gpu_perf.py", 600),
TestFile("nightly/test_flashinfer_trtllm_gen_attn_backend.py", 300),
Expand Down
Loading