diff --git a/docs/advanced_features/expert_parallelism.md b/docs/advanced_features/expert_parallelism.md
index 9abe69b5ac04..73bab333c1c2 100644
--- a/docs/advanced_features/expert_parallelism.md
+++ b/docs/advanced_features/expert_parallelism.md
@@ -32,6 +32,7 @@ Currently, DeepEP, Mooncake, `ascend_fuseep` and MORI only support cases where `
| `deep_gemm` | DeepGEMM backend optimized for MoE matrix multiplications, supporting contiguous layouts for prefill and masked layouts for decode; often JIT-compiled for performance. | Large-scale EP deployments with FP8 block-wise quantization. |
| `cutlass` | CUTLASS-based backend for efficient GEMMs. | NVIDIA architectures with CUTLASS support. |
| `flashinfer_trtllm` | FlashInfer integrated with TensorRT-LLM for accelerated MoE computations, supporting FP4 communication operators and high-performance GEMMs. | Blackwell with TRT-LLM. |
+| `flashinfer_trtllm_routed` | FlashInfer integrated with TensorRT-LLM for accelerated routed MoE computations, consuming SGLang-computed top-k expert assignments and weights. | Blackwell with TRT-LLM. |
| `flashinfer_cutlass` | FlashInfer combined with CUTLASS for high-performance grouped GEMMs in MoE layers, handling FP4/FP8 quantization efficiently. | Blackwell with FP4/FP8 models. |
| `flashinfer_mxfp4` | FlashInfer variant optimized for MXFP4 (mixed FP4) quantization in MoE runners, focusing on memory-efficient low-precision inference. | Low-precision models with MXFP4. |
| `flashinfer_cutedsl` | FlashInfer with a custom DSL for flexible and efficient MoE kernel generation, integrated with ModelOpt FP4 quantization. | Low-precision models with NVFP4. |
diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index b8d89c208e07..61fc9059b595 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -312,7 +312,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| --- | --- | --- | --- |
| `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`|
-| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
+| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` |
| `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` |
| `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
| `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) |
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 5fb874efc56d..43e8bcce08a1 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -763,10 +763,11 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
elif (
quant_config is None
or quant_config.get_name() == "fp8"
+ or quant_config.get_name() == "mxfp8"
or quant_config.get_name() == "modelopt_fp8"
or quant_config.get_name() == "compressed_tensors"
):
- # FlashInferFusedMoE support bf16, fp8 and compressed_tensors
+ # FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors
return FusedMoE
if get_moe_runner_backend().is_flashinfer_cutlass():
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 8da7d8eef330..144b050922f5 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -220,7 +220,10 @@ def __init__(
self.use_presharded_weights = use_presharded_weights
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels()
- self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm()
+ self.use_flashinfer_trtllm_moe = (
+ get_moe_runner_backend().is_flashinfer_trtllm()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ )
# flashinfer_trtllm kernel requires intermediate_size to be a multiple of 128
# Pad the intermediate_size_per_partition if necessary
@@ -302,7 +305,10 @@ def __init__(
self.quant_method, ModelOptNvFp4FusedMoEMethod
) or (
isinstance(self.quant_method, Fp8MoEMethod)
- and get_moe_runner_backend().is_cutlass()
+ and (
+ get_moe_runner_backend().is_cutlass()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ )
)
self.routing_method_type = routing_method_type
diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
index 5d6fb78c4fe8..9c3eac87c664 100644
--- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
+++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
@@ -119,6 +119,77 @@ def align_fp8_moe_weights_for_flashinfer_trtllm(
layer.output2_scales_scalar = Parameter(output2_scales_scalar, requires_grad=False)
+def align_mxfp8_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
+ """Prepare MXFP8 MoE weights/scales for FlashInfer TRT-LLM kernels."""
+ from flashinfer import (
+ reorder_rows_for_gated_act_gemm,
+ shuffle_matrix_a,
+ shuffle_matrix_sf_a,
+ )
+
+ w13_weight = cast(torch.Tensor, layer.w13_weight).contiguous()
+ w2_weight = cast(torch.Tensor, layer.w2_weight).contiguous()
+ w13_scale = cast(torch.Tensor, layer.w13_weight_scale_inv).contiguous()
+ w2_scale = cast(torch.Tensor, layer.w2_weight_scale_inv).contiguous()
+
+ assert w13_scale.dtype == torch.uint8
+ assert w2_scale.dtype == torch.uint8
+
+ num_experts, two_n, _ = w13_weight.shape
+ _, hidden_size, _ = w2_weight.shape
+ epilogue_tile_m = 128
+
+ w13_interleaved = [
+ reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts)
+ ]
+ w13_scale_interleaved = [
+ reorder_rows_for_gated_act_gemm(w13_scale[i]) for i in range(num_experts)
+ ]
+
+ w13_shuffled = [
+ shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m)
+ for i in range(num_experts)
+ ]
+ w2_shuffled = [
+ shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
+ for i in range(num_experts)
+ ]
+ w13_scale_shuffled = [
+ shuffle_matrix_sf_a(
+ w13_scale_interleaved[i].view(torch.uint8).reshape(two_n, -1),
+ epilogue_tile_m,
+ )
+ for i in range(num_experts)
+ ]
+ w2_scale_shuffled = [
+ shuffle_matrix_sf_a(
+ w2_scale[i].view(torch.uint8).reshape(hidden_size, -1),
+ epilogue_tile_m,
+ )
+ for i in range(num_experts)
+ ]
+
+ # Keep parameter identities stable for CUDA graph capture reuse.
+ copy_or_rebind_param(
+ layer, "w13_weight", torch.stack(w13_shuffled).view(torch.float8_e4m3fn)
+ )
+ copy_or_rebind_param(
+ layer, "w2_weight", torch.stack(w2_shuffled).view(torch.float8_e4m3fn)
+ )
+ copy_or_rebind_param(
+ layer,
+ "w13_weight_scale_inv",
+ torch.stack(w13_scale_shuffled).reshape_as(w13_scale).contiguous(),
+ )
+ copy_or_rebind_param(
+ layer,
+ "w2_weight_scale_inv",
+ torch.stack(w2_scale_shuffled).reshape_as(w2_scale).contiguous(),
+ )
+ layer.w13_weight_scale_inv.format_ue8m0 = True
+ layer.w2_weight_scale_inv.format_ue8m0 = True
+
+
def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None:
"""Prepare FP4 MoE weights/scales for FlashInfer TRT-LLM kernels.
@@ -197,6 +268,7 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
# Block-quant path
block_quant: bool
+ use_mxfp8: bool = False
weight_block_k: int | None = None
w13_weight_scale_inv: torch.Tensor | None = None
w2_weight_scale_inv: torch.Tensor | None = None
@@ -209,13 +281,27 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo):
use_routing_scales_on_input: bool = False
+def _pack_topk_for_flashinfer_routed(
+ topk_ids: torch.Tensor, topk_weights: torch.Tensor
+) -> torch.Tensor:
+ """Pack routed top-k tensors into FlashInfer's int32 format."""
+ packed_ids = topk_ids.to(torch.int32)
+ packed_weights = topk_weights.to(torch.bfloat16)
+    packed = (packed_ids << 16) | (packed_weights.view(torch.int16).to(torch.int32) & 0xFFFF)
+ # SGLang can mark padded tokens with -1 expert ids.
+ return packed.masked_fill_(packed_ids < 0, 0)
+
+
def fused_experts_none_to_flashinfer_trtllm_fp8(
dispatch_output: StandardDispatchOutput,
quant_info: FlashInferTrtllmFp8MoeQuantInfo,
runner_config: MoeRunnerConfig,
+ use_routed_topk: bool = False,
) -> StandardCombineInput:
from flashinfer.fused_moe import (
+ Fp8QuantizationType,
trtllm_fp8_block_scale_moe,
+ trtllm_fp8_block_scale_routed_moe,
trtllm_fp8_per_tensor_scale_moe,
)
@@ -228,64 +314,132 @@ def fused_experts_none_to_flashinfer_trtllm_fp8(
hidden_states = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
- assert TopKOutputChecker.format_is_bypassed(topk_output)
-
- router_logits = topk_output.router_logits
- topk_config = topk_output.topk_config
- correction_bias = (
- None
- if topk_config.correction_bias is None
- else topk_config.correction_bias.to(hidden_states.dtype)
- )
+ if TopKOutputChecker.format_is_bypassed(topk_output):
+ router_logits = topk_output.router_logits
+ topk_config = topk_output.topk_config
+ correction_bias = (
+ None
+ if topk_config.correction_bias is None
+ else topk_config.correction_bias.to(hidden_states.dtype)
+ )
+ else:
+ router_logits = None
+ topk_config = None
+ correction_bias = None
routing_method_type = quant_info.routing_method_type
+ fp8_quantization_type = (
+ Fp8QuantizationType.MxFp8
+ if quant_info.use_mxfp8
+ else Fp8QuantizationType.DeepSeekFp8
+ )
+ use_shuffled_weight = quant_info.use_mxfp8
if quant_info.block_quant:
assert quant_info.weight_block_k is not None
assert quant_info.w13_weight_scale_inv is not None
assert quant_info.w2_weight_scale_inv is not None
- a_q, a_sf = per_token_group_quant_fp8(hidden_states, quant_info.weight_block_k)
- a_sf_t = a_sf.t().contiguous()
+ if quant_info.use_mxfp8:
+ assert quant_info.weight_block_k == 32
+ from flashinfer import mxfp8_quantize
+
+ a_q, a_sf = mxfp8_quantize(hidden_states, False)
+ # FlashInfer TRT-LLM MxFP8 expects token-major activation scales:
+ # [num_tokens, hidden_size // 32] (no transpose).
+ a_sf_t = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1)
+ else:
+ a_q, a_sf = per_token_group_quant_fp8(
+ hidden_states, quant_info.weight_block_k
+ )
+ a_sf_t = a_sf.t().contiguous()
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
- # FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
- # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
- # so we put the whole function under the ``use_symmetric_memory`` context manager.
- # If the bug is fixed, we can only put the output tensor allocation under the context manager.
- output = trtllm_fp8_block_scale_moe(
- routing_logits=(
- router_logits.to(torch.float32)
- if routing_method_type == RoutingMethodType.DeepSeekV3
- else router_logits
- ),
- routing_bias=correction_bias,
- hidden_states=a_q,
- hidden_states_scale=a_sf_t,
- gemm1_weights=quant_info.w13_weight,
- gemm1_weights_scale=quant_info.w13_weight_scale_inv,
- gemm2_weights=quant_info.w2_weight,
- gemm2_weights_scale=quant_info.w2_weight_scale_inv,
- num_experts=quant_info.global_num_experts,
- top_k=topk_config.top_k,
- n_group=(
- topk_config.num_expert_group if topk_config.num_expert_group else 0
- ),
- topk_group=topk_config.topk_group if topk_config.topk_group else 0,
- intermediate_size=quant_info.intermediate_size,
- local_expert_offset=quant_info.local_expert_offset,
- local_num_experts=quant_info.local_num_experts,
- routed_scaling_factor=(
- runner_config.routed_scaling_factor
- if runner_config.routed_scaling_factor is not None
- else 1.0
- ),
- routing_method_type=routing_method_type,
- use_shuffled_weight=False,
- tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
- )
+ if use_routed_topk:
+ assert (
+ runner_config.top_k is not None
+ ), "runner_config.top_k is required for flashinfer_trtllm_routed."
+ assert TopKOutputChecker.format_is_standard(topk_output)
+ packed_topk_ids = _pack_topk_for_flashinfer_routed(
+ topk_ids=topk_output.topk_ids,
+ topk_weights=topk_output.topk_weights,
+ )
+
+ output = trtllm_fp8_block_scale_routed_moe(
+ topk_ids=packed_topk_ids,
+ routing_bias=None,
+ hidden_states=a_q,
+ hidden_states_scale=a_sf_t,
+ gemm1_weights=quant_info.w13_weight,
+ gemm1_weights_scale=quant_info.w13_weight_scale_inv,
+ gemm2_weights=quant_info.w2_weight,
+ gemm2_weights_scale=quant_info.w2_weight_scale_inv,
+ num_experts=quant_info.global_num_experts,
+ top_k=runner_config.top_k,
+ n_group=None,
+ topk_group=None,
+ intermediate_size=quant_info.intermediate_size,
+ local_expert_offset=quant_info.local_expert_offset,
+ local_num_experts=quant_info.local_num_experts,
+ routed_scaling_factor=(
+ runner_config.routed_scaling_factor
+ if runner_config.routed_scaling_factor is not None
+ else 1.0
+ ),
+ routing_method_type=(
+ RoutingMethodType.TopK
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else routing_method_type
+ ),
+ use_shuffled_weight=use_shuffled_weight,
+ weight_layout=0,
+ tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
+ fp8_quantization_type=fp8_quantization_type,
+ )
+ else:
+ assert TopKOutputChecker.format_is_bypassed(topk_output)
+
+ # FIXME: there is a bug in the trtllm_fp8_block_scale_moe.
+ # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325
+ # so we put the whole function under the ``use_symmetric_memory`` context manager.
+ # If the bug is fixed, we can only put the output tensor allocation under the context manager.
+ output = trtllm_fp8_block_scale_moe(
+ routing_logits=(
+ router_logits.to(torch.float32)
+ if routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits
+ ),
+ routing_bias=correction_bias,
+ hidden_states=a_q,
+ hidden_states_scale=a_sf_t,
+ gemm1_weights=quant_info.w13_weight,
+ gemm1_weights_scale=quant_info.w13_weight_scale_inv,
+ gemm2_weights=quant_info.w2_weight,
+ gemm2_weights_scale=quant_info.w2_weight_scale_inv,
+ num_experts=quant_info.global_num_experts,
+ top_k=topk_config.top_k,
+ n_group=(
+ topk_config.num_expert_group
+ if topk_config.num_expert_group
+ else 0
+ ),
+ topk_group=topk_config.topk_group if topk_config.topk_group else 0,
+ intermediate_size=quant_info.intermediate_size,
+ local_expert_offset=quant_info.local_expert_offset,
+ local_num_experts=quant_info.local_num_experts,
+ routed_scaling_factor=(
+ runner_config.routed_scaling_factor
+ if runner_config.routed_scaling_factor is not None
+ else 1.0
+ ),
+ routing_method_type=routing_method_type,
+ use_shuffled_weight=use_shuffled_weight,
+ weight_layout=0,
+ tune_max_num_tokens=next_power_of_2(a_q.shape[0]),
+ fp8_quantization_type=fp8_quantization_type,
+ )
else:
assert quant_info.w13_input_scale is not None
assert quant_info.output1_scales_scalar is not None
@@ -577,3 +731,21 @@ def fused_experts_none_to_flashinfer_trtllm(
raise TypeError(
f"Unexpected quant_info type for flashinfer_trtllm: {type(quant_info)}"
)
+
+
+@register_fused_func("none", "flashinfer_trtllm_routed")
+def fused_experts_none_to_flashinfer_trtllm_routed(
+ dispatch_output: StandardDispatchOutput,
+ quant_info: MoeQuantInfo,
+ runner_config: MoeRunnerConfig,
+) -> StandardCombineInput:
+ if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo):
+ return fused_experts_none_to_flashinfer_trtllm_fp8(
+ dispatch_output,
+ quant_info,
+ runner_config,
+ use_routed_topk=True,
+ )
+ raise TypeError(
+ f"Unexpected quant_info type for flashinfer_trtllm_routed: {type(quant_info)}"
+ )
diff --git a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py
index 8b58cd3115bd..ee580e580586 100644
--- a/python/sglang/srt/layers/moe/moe_runner/runner.py
+++ b/python/sglang/srt/layers/moe/moe_runner/runner.py
@@ -39,7 +39,10 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.runner_core = DeepGemmRunnerCore(config)
elif runner_backend.is_marlin():
self.runner_core = None # Marlin only supports fused path
- elif runner_backend.is_flashinfer_trtllm():
+ elif (
+ runner_backend.is_flashinfer_trtllm()
+ or runner_backend.is_flashinfer_trtllm_routed()
+ ):
self.runner_core = None # FlashInfer TRT-LLM only supports fused path
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
index 8ec839991a90..b77c19f83a69 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -86,6 +86,9 @@ def __init__(self, moe_runner_config: MoeRunnerConfig):
self.enable_flashinfer_cutlass_moe = (
get_moe_runner_backend().is_flashinfer_cutlass()
)
+ self.enable_flashinfer_trtllm_routed_moe = (
+ get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ )
self.num_experts = moe_runner_config.num_experts
self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts
self.num_local_routed_experts = (
@@ -142,6 +145,7 @@ def dispatch(
if (
self.moe_ep_size > 1
and not self.enable_flashinfer_cutlass_moe
+ and not self.enable_flashinfer_trtllm_routed_moe
and TopKOutputChecker.format_is_standard(topk_output)
):
if self.local_expert_mapping is None:
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index 2523ef851711..068bc67cdaae 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -29,6 +29,7 @@
)
import torch
+import torch.nn.functional as F
try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
@@ -443,6 +444,25 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor:
return topk_weights, topk_ids
+def fused_topk_softmax_torch_raw_logits(
+ hidden_states: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ renormalize: bool,
+):
+ assert (
+ hidden_states.shape[0] == gating_output.shape[0]
+ ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
+
+ _, topk_ids = torch.topk(gating_output, k=topk, dim=-1, sorted=False)
+ logits = gating_output.float()
+ topk_weights = logits.gather(1, topk_ids)
+ if renormalize:
+ topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
+
+ return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
def fused_topk_cpu(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
@@ -1030,15 +1050,28 @@ def select_experts(
)
elif custom_routing_function is None:
assert not apply_routed_scaling_factor_on_output, "Not implemented"
- # Qwen3MOE uses fused_topk
- topk_weights, topk_ids = fused_topk(
- hidden_states=hidden_states,
- gating_output=router_logits,
- topk=num_routed_topk if _use_aiter else top_k,
- renormalize=renormalize,
- correction_bias=correction_bias,
- scoring_func=scoring_func,
- )
+ if (
+ get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ and scoring_func == "softmax"
+ and correction_bias is None
+ ):
+ # flashinfer_trtllm_routed uses raw-logits topk
+ topk_weights, topk_ids = fused_topk_softmax_torch_raw_logits(
+ hidden_states=hidden_states,
+ gating_output=router_logits,
+ topk=num_routed_topk if _use_aiter else top_k,
+ renormalize=renormalize,
+ )
+ else:
+ # Qwen3MOE uses fused_topk
+ topk_weights, topk_ids = fused_topk(
+ hidden_states=hidden_states,
+ gating_output=router_logits,
+ topk=num_routed_topk if _use_aiter else top_k,
+ renormalize=renormalize,
+ correction_bias=correction_bias,
+ scoring_func=scoring_func,
+ )
else:
assert (
num_token_non_padded is None
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index 70ad66e9c514..3a2c7f3e537b 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -61,6 +61,7 @@ class MoeRunnerBackend(Enum):
TRITON = "triton"
TRITON_KERNELS = "triton_kernel"
FLASHINFER_TRTLLM = "flashinfer_trtllm"
+ FLASHINFER_TRTLLM_ROUTED = "flashinfer_trtllm_routed"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
@@ -82,6 +83,9 @@ def is_triton_kernels(self):
def is_flashinfer_trtllm(self):
return self == MoeRunnerBackend.FLASHINFER_TRTLLM
+ def is_flashinfer_trtllm_routed(self):
+ return self == MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED
+
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 24757958430c..1735881f53f1 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -49,11 +49,12 @@
can_auto_enable_marlin_fp8,
cutlass_fp8_supported,
dispatch_w8a8_block_fp8_linear,
+ dispatch_w8a8_mxfp8_linear,
+ get_fp8_gemm_runner_backend,
input_to_float8,
mxfp8_group_quantize,
normalize_e4m3fn_to_e4m3fnuz,
requant_weight_ue8m0_inplace,
- triton_mxfp8_blockscaled_linear,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
@@ -71,6 +72,7 @@
per_tensor_dequantize,
requantize_with_max_scale,
)
+from sglang.srt.layers.utils import copy_or_rebind_param
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -268,7 +270,12 @@ def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]):
self.block_quant = (
self.use_mxfp8 or self.quant_config.weight_block_size is not None
)
- self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+ self.w8a8_block_fp8_linear = None
+ self.w8a8_mxfp8_linear = None
+ if self.use_mxfp8:
+ self.w8a8_mxfp8_linear = dispatch_w8a8_mxfp8_linear()
+ else:
+ self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
self.is_checkpoint_fp8_serialized = (
self.quant_config.is_checkpoint_fp8_serialized
)
@@ -441,6 +448,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
# Keep parameter object to preserve weight_loader attrs for hot reload.
layer.weight_scale_inv.requires_grad_(False)
layer.weight_scale_inv.format_ue8m0 = True
+ self._process_mxfp8_linear_weight_scale(layer)
return
else:
# For fp8 linear weights run with deepgemm, the weights and scales need be requantized to ue8m0
@@ -474,6 +482,25 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None:
layer.weight.data = weight.data
layer.weight_scale_inv.data = weight_scale.data
+ def _process_mxfp8_linear_weight_scale(self, layer: Module) -> None:
+ if not self.use_mxfp8:
+ return
+
+ if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
+ from flashinfer import block_scale_interleave
+
+ scale_u8 = layer.weight_scale_inv.data
+ new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous()
+ else:
+ # Triton path consumes canonical 2D UE8M0 scales directly.
+ return
+
+ copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled)
+ layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version
+ layer._weight_scale_inv_swizzled_src_data_ptr = (
+ layer.weight_scale_inv.data_ptr()
+ )
+
def _quantize_mxfp8_weights(self, layer: Module) -> None:
weight = layer.weight.data
qweight, weight_scale = mxfp8_group_quantize(weight)
@@ -489,6 +516,7 @@ def _quantize_mxfp8_weights(self, layer: Module) -> None:
"weight_scale_inv", Parameter(weight_scale, requires_grad=False)
)
layer.weight_scale_inv.format_ue8m0 = True
+ self._process_mxfp8_linear_weight_scale(layer)
layer.input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
@@ -621,18 +649,22 @@ def apply(
)
if self.use_mxfp8:
+ if get_fp8_gemm_runner_backend().is_flashinfer_trtllm():
+ weight_scale = layer.weight_scale_inv_swizzled
+ else:
+ weight_scale = layer.weight_scale_inv
if isinstance(x, tuple):
- return triton_mxfp8_blockscaled_linear(
+ return self.w8a8_mxfp8_linear(
input=x[0],
weight=layer.weight,
- weight_scale=layer.weight_scale_inv,
+ weight_scale=weight_scale,
input_scale=x[1],
bias=bias,
)
- return triton_mxfp8_blockscaled_linear(
+ return self.w8a8_mxfp8_linear(
input=x,
weight=layer.weight,
- weight_scale=layer.weight_scale_inv,
+ weight_scale=weight_scale,
input_scale=None,
bias=bias,
)
@@ -1105,6 +1137,19 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor):
scale = _swizzle_with_triton_kernel(weight.shape, scale)
return qweight, scale
+ def _quantize_with_flashinfer_trtllm(weight: torch.Tensor):
+ weight = weight.contiguous()
+ num_experts, m, k = weight.shape
+ assert k % 32 == 0, f"{k=} must be divisible by 32 for MXFP8"
+ from flashinfer import mxfp8_quantize
+
+ weight_flat = weight.view(-1, k).contiguous()
+ qweight, scale = mxfp8_quantize(weight_flat, False)
+ scale_u8 = (
+ scale.view(torch.uint8).contiguous().view(num_experts, m, k // 32)
+ )
+ return qweight.view_as(weight), scale_u8
+
if quantize:
if get_moe_runner_backend().is_cutlass():
w13_q, w13_s = _quantize_and_swizzle_with_cutlass_es_kernel(
@@ -1113,6 +1158,15 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor):
w2_q, w2_s = _quantize_and_swizzle_with_cutlass_es_kernel(
layer.w2_weight.data
)
+ elif (
+ get_moe_runner_backend().is_flashinfer_trtllm()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ ):
+ # Match FlashInfer TRT-LLM MoE test contracts:
+ # 1) quantize in canonical (non-swizzled) scale layout, and
+ # 2) do row/layout shuffling in align_mxfp8_moe_weights_for_flashinfer_trtllm.
+ w13_q, w13_s = _quantize_with_flashinfer_trtllm(layer.w13_weight.data)
+ w2_q, w2_s = _quantize_with_flashinfer_trtllm(layer.w2_weight.data)
else:
w13_q, w13_s = _quantize_and_swizzle_with_triton_kernel(
layer.w13_weight.data
@@ -1121,14 +1175,23 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor):
layer.w2_weight.data
)
else:
- w13_q = layer.w13_weight.data
- w2_q = layer.w2_weight.data
- w13_s = _swizzle_with_triton_kernel(
- layer.w13_weight.data.shape, layer.w13_weight_scale_inv.data
- )
- w2_s = _swizzle_with_triton_kernel(
- layer.w2_weight.data.shape, layer.w2_weight_scale_inv.data
- )
+ if (
+ get_moe_runner_backend().is_flashinfer_trtllm()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ ):
+ w13_q = layer.w13_weight.data
+ w2_q = layer.w2_weight.data
+ w13_s = layer.w13_weight_scale_inv.data
+ w2_s = layer.w2_weight_scale_inv.data
+ else:
+ w13_q = layer.w13_weight.data
+ w2_q = layer.w2_weight.data
+ w13_s = _swizzle_with_triton_kernel(
+ layer.w13_weight.data.shape, layer.w13_weight_scale_inv.data
+ )
+ w2_s = _swizzle_with_triton_kernel(
+ layer.w2_weight.data.shape, layer.w2_weight_scale_inv.data
+ )
# Keep parameter objects to preserve weight_loader attrs for hot reload.
# Prefer in-place copy; rebind only when shape/dtype changes (online quantize).
@@ -1154,6 +1217,16 @@ def _copy_or_rebind(param: Parameter, new_value: torch.Tensor) -> None:
layer.w13_input_scale = None
layer.w2_input_scale = None
+ if (
+ get_moe_runner_backend().is_flashinfer_trtllm()
+ or get_moe_runner_backend().is_flashinfer_trtllm_routed()
+ ):
+ from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import (
+ align_mxfp8_moe_weights_for_flashinfer_trtllm,
+ )
+
+ align_mxfp8_moe_weights_for_flashinfer_trtllm(layer)
+
def process_weights_after_loading(self, layer: Module) -> None:
if _is_hip and _use_hip_int4:
self.process_weights_hip_int4(layer)
@@ -1376,6 +1449,7 @@ def create_moe_runner(
moe_runner_backend.is_deep_gemm()
or moe_runner_backend.is_triton()
or moe_runner_backend.is_flashinfer_trtllm()
+ or moe_runner_backend.is_flashinfer_trtllm_routed()
):
self.runner = MoeRunner(moe_runner_backend, moe_runner_config)
else:
@@ -1504,9 +1578,14 @@ def apply(
w2_scale=w2_scale,
block_shape=block_shape,
)
- elif self.runner.runner_backend.is_flashinfer_trtllm():
+ elif (
+ self.runner.runner_backend.is_flashinfer_trtllm()
+ or self.runner.runner_backend.is_flashinfer_trtllm_routed()
+ ):
# FlashInfer TRT-LLM backend only supports fused execution and consumes
# router logits directly (no separate apply_with_router_logits needed).
+ # FlashInfer TRT-LLM routed backend consumes SGLang-computed
+ # top-k ids/weights (packed into int32) instead of router logits.
global_num_experts = int(getattr(layer, "num_experts"))
num_local_experts = int(getattr(layer, "num_local_experts"))
moe_ep_rank = int(getattr(layer, "moe_ep_rank"))
@@ -1522,6 +1601,7 @@ def apply(
getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3)
),
block_quant=self.block_quant,
+ use_mxfp8=getattr(self.quant_config, "use_mxfp8", False),
weight_block_k=(
None
if self.quant_config.weight_block_size is None
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index ce65b5a01e6a..f072264f6eef 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -185,6 +185,8 @@ def _check_cutlass_block_fp8_hardware_support() -> bool:
if is_blackwell_supported() and is_flashinfer_available():
+ from flashinfer import mm_mxfp8 as _raw_flashinfer_mm_mxfp8
+ from flashinfer import mxfp8_quantize as _raw_flashinfer_mxfp8_quantize
from flashinfer.gemm import gemm_fp8_nt_groupwise as _raw_gemm_fp8_nt_groupwise
from sglang.srt.utils.custom_op import register_custom_op
@@ -242,6 +244,62 @@ def gemm_fp8_nt_groupwise(
backend=backend,
)
+ # Wrap MXFP8 ops as custom ops so torch.compile does not trace into
+ # flashinfer's JIT compilation path (filesystem checks/cubin loader).
+ def _fake_flashinfer_mxfp8_quantize(
+ input: torch.Tensor,
+        is_sf_swizzled_layout: bool = True,
+ alignment: int = 32,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Fake mode only needs dtypes and output rank to propagate compile graph.
+ # The scale tensor shape is not consumed before the following fake mm op.
+ k_aligned = ((input.shape[1] + alignment - 1) // alignment) * alignment
+ q_input = input.new_empty(
+ (input.shape[0], k_aligned), dtype=torch.float8_e4m3fn
+ )
+ scale = input.new_empty((1,), dtype=torch.uint8)
+ return q_input, scale
+
+ @register_custom_op(
+ op_name="flashinfer_mxfp8_quantize",
+ mutates_args=[],
+ fake_impl=_fake_flashinfer_mxfp8_quantize,
+ )
+ def flashinfer_mxfp8_quantize(
+ input: torch.Tensor,
+ is_sf_swizzled_layout: bool = True,
+ alignment: int = 32,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ return _raw_flashinfer_mxfp8_quantize(
+ input,
+ is_sf_swizzled_layout=is_sf_swizzled_layout,
+ alignment=alignment,
+ )
+
+ @register_custom_op(
+ op_name="flashinfer_mm_mxfp8",
+ mutates_args=[],
+ fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, backend="auto": (
+ q_input.new_empty((q_input.shape[0], weight_t.shape[1]), dtype=out_dtype)
+ ),
+ )
+ def flashinfer_mm_mxfp8(
+ q_input: torch.Tensor,
+ weight_t: torch.Tensor,
+ x_scale_u8: torch.Tensor,
+ weight_scale_t: torch.Tensor,
+ out_dtype: torch.dtype,
+ backend: str = "auto",
+ ) -> torch.Tensor:
+ return _raw_flashinfer_mm_mxfp8(
+ q_input,
+ weight_t,
+ x_scale_u8,
+ weight_scale_t,
+ out_dtype=out_dtype,
+ backend=backend,
+ )
+
if is_sm90_supported() and is_flashinfer_available():
# FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M
@@ -266,6 +324,18 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
return _dispatch_auto_backend()
+def dispatch_w8a8_mxfp8_linear() -> Callable:
+ """Dispatch MXFP8 linear kernel by --fp8-gemm-backend.
+
+ For MXFP8, Triton remains the default path. We only route to FlashInfer
+ when backend is explicitly set to flashinfer_trtllm.
+ """
+ backend = get_fp8_gemm_runner_backend()
+ if backend.is_flashinfer_trtllm():
+ return flashinfer_mxfp8_blockscaled_linear
+ return triton_mxfp8_blockscaled_linear
+
+
def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable:
"""Dispatch based on explicitly selected backend."""
if backend.is_flashinfer_trtllm():
@@ -843,6 +913,61 @@ def triton_mxfp8_blockscaled_linear(
return output.to(dtype=output_dtype).view(*output_shape)
+def flashinfer_mxfp8_blockscaled_linear(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ weight_scale: torch.Tensor,
+ input_scale: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ output_dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+ """MXFP8 dense linear via FlashInfer mm_mxfp8."""
+ input_2d = input.view(-1, input.shape[-1]).contiguous()
+ output_shape = [*input.shape[:-1], weight.shape[0]]
+
+ m, k = input_2d.shape
+ n, k_w = weight.shape
+ if k != k_w:
+ raise ValueError(f"Input K={k} does not match weight K={k_w}.")
+ if k % 32 != 0:
+ raise ValueError(f"K={k} must be divisible by 32 for MXFP8.")
+ if weight.dtype != torch.float8_e4m3fn:
+ raise TypeError("MXFP8 weight must be FP8 E4M3.")
+
+ if input_scale is None:
+ q_input, x_scale_u8 = flashinfer_mxfp8_quantize(
+ input_2d, is_sf_swizzled_layout=True, alignment=32
+ )
+    else:
+        q_input, x_scale_u8 = input_2d, input_scale
+
+ if output_dtype is None:
+ if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32):
+ output_dtype = input_2d.dtype
+ else:
+ output_dtype = torch.bfloat16
+
+    # FlashInfer consumes [K, N] operands; .t() of a contiguous [N, K] tensor yields the column-major view its runner expects.
+ weight_t = weight.contiguous().t()
+ weight_scale_t = (
+ weight_scale.contiguous().t()
+ if weight_scale.ndim == 2
+ else weight_scale.contiguous()
+ )
+ output = flashinfer_mm_mxfp8(
+ q_input,
+ weight_t,
+ x_scale_u8,
+ weight_scale_t,
+ out_dtype=output_dtype,
+ backend="auto",
+ )
+
+ if bias is not None:
+ output += bias
+ return output.to(dtype=output_dtype).view(*output_shape)
+
+
def dequant_mxfp4(
w_block: torch.Tensor,
w_scale: torch.Tensor,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6bfa8ecb30ea..657c4c1b9870 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -179,6 +179,7 @@
"triton",
"triton_kernel",
"flashinfer_trtllm",
+ "flashinfer_trtllm_routed",
"flashinfer_cutlass",
"flashinfer_mxfp4",
"flashinfer_cutedsl",
@@ -2454,12 +2455,19 @@ def _handle_data_parallelism(self):
def _handle_moe_kernel_config(self):
if self.quantization == "mxfp8":
- if self.moe_runner_backend not in ["auto", "cutlass"]:
+ if self.moe_runner_backend == "auto":
+ self.moe_runner_backend = "flashinfer_trtllm"
+ elif self.moe_runner_backend not in [
+ "cutlass",
+ "flashinfer_trtllm",
+ "flashinfer_trtllm_routed",
+ ]:
logger.warning(
- "mxfp8 quantization forces --moe-runner-backend=cutlass. "
+ "mxfp8 quantization supports only cutlass, flashinfer_trtllm, "
+ "or flashinfer_trtllm_routed backends. "
f"Overriding {self.moe_runner_backend!r}."
)
- self.moe_runner_backend = "cutlass"
+ self.moe_runner_backend = "flashinfer_trtllm"
if self.moe_runner_backend == "flashinfer_cutlass":
assert self.quantization in [
@@ -2476,6 +2484,7 @@ def _handle_moe_kernel_config(self):
assert self.quantization in [
"modelopt_fp4",
"fp8",
+ "mxfp8",
"modelopt_fp8",
"compressed-tensors",
None,
@@ -2485,6 +2494,16 @@ def _handle_moe_kernel_config(self):
"FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
)
+ if self.moe_runner_backend == "flashinfer_trtllm_routed":
+ assert self.quantization in [
+ "fp8",
+ "mxfp8",
+        ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MoE supports only: 'fp8' or 'mxfp8'."
+ self.disable_shared_experts_fusion = True
+ logger.warning(
+ "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set."
+ )
+
if get_bool_env_var("SGLANG_CUTLASS_MOE"):
logger.warning(
"SGLANG_CUTLASS_MOE is deprecated, use --moe-runner-backend=cutlass and/or --speculative-moe-runner-backend=cutlass instead"
@@ -2691,7 +2710,8 @@ def _handle_speculative_decoding(self):
if self.speculative_moe_runner_backend is None:
self.speculative_moe_runner_backend = (
"auto"
- if self.moe_runner_backend == "flashinfer_trtllm"
+ if self.moe_runner_backend
+ in ["flashinfer_trtllm", "flashinfer_trtllm_routed"]
else self.moe_runner_backend
)
else:
diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
index 63db2b2ad1cc..df0e3af457cb 100644
--- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
+++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py
@@ -12,10 +12,12 @@
popen_launch_server,
)
-register_cuda_ci(est_time=300, suite="nightly-4-gpu-b200", nightly=True)
+register_cuda_ci(est_time=500, suite="nightly-4-gpu-b200", nightly=True)
-class TestFlashinferTrtllmGenMoeBackendFP8(CustomTestCase):
+class FlashinferTrtllmGenMoeBackendFP8Base:
+ backend = None
+
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8"
@@ -29,7 +31,7 @@ def setUpClass(cls):
"--attention-backend",
"triton",
"--moe-runner-backend",
- "flashinfer_trtllm",
+ cls.backend,
"--tp-size",
"4",
"--ep-size",
@@ -60,7 +62,9 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.93)
-class TestFlashinferTrtllmGenMoeBackendBF16(CustomTestCase):
+class FlashinferTrtllmGenMoeBackendBF16Base:
+ backend = None
+
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
@@ -73,7 +77,7 @@ def setUpClass(cls):
"--attention-backend",
"triton",
"--moe-runner-backend",
- "flashinfer_trtllm",
+ cls.backend,
"--cuda-graph-max-bs",
"512",
"--tp-size",
@@ -106,5 +110,82 @@ def test_gsm8k(self):
self.assertGreater(metrics["accuracy"], 0.93)
+class FlashinferTrtllmGenMoeBackendMXFP8Base:
+ backend = None
+
+ @classmethod
+ def setUpClass(cls):
+ cls.model = "Qwen/Qwen3-30B-A3B-Instruct-2507"
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"},
+ other_args=[
+ "--quantization",
+ "mxfp8",
+ "--fp8-gemm-backend",
+ "flashinfer_trtllm",
+ "--moe-runner-backend",
+ cls.backend,
+ "--tp-size",
+ "4",
+ "--ep-size",
+ "4",
+ "--mem-fraction-static",
+ "0.7",
+ ],
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_process_tree(cls.process.pid)
+
+ def test_gsm8k(self):
+ args = SimpleNamespace(
+ num_shots=5,
+ data_path=None,
+ num_questions=200,
+ max_new_tokens=512,
+ parallel=128,
+ host="http://127.0.0.1",
+ port=int(self.base_url.split(":")[-1]),
+ )
+ metrics = run_eval(args)
+ print(f"{metrics=}")
+ self.assertGreater(metrics["accuracy"], 0.93)
+
+
+class TestFlashinferTrtllmGenMoeBackendFP8(
+ FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase
+):
+ backend = "flashinfer_trtllm"
+
+
+class TestFlashinferTrtllmGenMoeBackendMXFP8(
+ FlashinferTrtllmGenMoeBackendMXFP8Base, CustomTestCase
+):
+ backend = "flashinfer_trtllm"
+
+
+class TestFlashinferTrtllmGenMoeBackendBF16(
+ FlashinferTrtllmGenMoeBackendBF16Base, CustomTestCase
+):
+ backend = "flashinfer_trtllm"
+
+
+class TestFlashinferTrtllmGenMoeBackendFP8Routed(
+ FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase
+):
+ backend = "flashinfer_trtllm_routed"
+
+
+class TestFlashinferTrtllmGenMoeBackendMXFP8Routed(
+ FlashinferTrtllmGenMoeBackendMXFP8Base, CustomTestCase
+):
+ backend = "flashinfer_trtllm_routed"
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/test/registered/quant/test_fp8_blockwise_gemm.py b/test/registered/quant/test_fp8_blockwise_gemm.py
index b832e4740668..af7600cb5380 100644
--- a/test/registered/quant/test_fp8_blockwise_gemm.py
+++ b/test/registered/quant/test_fp8_blockwise_gemm.py
@@ -12,9 +12,10 @@
try_cached_model,
)
-register_cuda_ci(est_time=280, suite="stage-c-test-4-gpu-b200")
+register_cuda_ci(est_time=420, suite="stage-c-test-4-gpu-b200")
MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507-FP8"
+BF16_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507"
class FP8BlockwiseGemmBase:
@@ -56,7 +57,51 @@ def test_gsm8k(self):
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
- self.assertGreaterEqual(metrics["accuracy"], 0.41)
+ self.assertGreaterEqual(metrics["accuracy"], 0.8)
+
+
+class MXFP8GemmBase:
+ backend = None
+
+ @classmethod
+ def setUpClass(cls):
+ if cls.backend is None:
+ raise NotImplementedError("Subclass must set 'backend' attribute")
+ cls.model = try_cached_model(BF16_MODEL_PATH)
+ cls.base_url = DEFAULT_URL_FOR_TEST
+ other_args = [
+ "--trust-remote-code",
+ "--quantization",
+ "mxfp8",
+ "--fp8-gemm-backend",
+ cls.backend,
+ ]
+ cls.process = popen_launch_server(
+ cls.model,
+ cls.base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=other_args,
+ )
+
+ @classmethod
+ def tearDownClass(cls):
+ kill_process_tree(cls.process.pid)
+
+ def test_gsm8k(self):
+ parsed_url = urlparse(self.base_url)
+ args = SimpleNamespace(
+ num_shots=8,
+ data_path=None,
+ num_questions=1319,
+ max_new_tokens=512,
+ parallel=200,
+ host=f"{parsed_url.scheme}://{parsed_url.hostname}",
+ port=parsed_url.port,
+ )
+ metrics = run_eval_few_shot_gsm8k(args)
+ print(metrics)
+
+ self.assertGreaterEqual(metrics["accuracy"], 0.8)
class TestFP8BlockwiseGemmTriton(FP8BlockwiseGemmBase, unittest.TestCase):
@@ -77,5 +122,15 @@ class TestFP8BlockwiseGemmFlashinferDeepGemm(FP8BlockwiseGemmBase, unittest.Test
backend = "flashinfer_deepgemm"
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestMXFP8GemmTriton(MXFP8GemmBase, unittest.TestCase):
+ backend = "triton"
+
+
+@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher")
+class TestMXFP8GemmFlashinferTrtllm(MXFP8GemmBase, unittest.TestCase):
+ backend = "flashinfer_trtllm"
+
+
if __name__ == "__main__":
unittest.main()