diff --git a/docs/advanced_features/expert_parallelism.md b/docs/advanced_features/expert_parallelism.md index 9abe69b5ac04..73bab333c1c2 100644 --- a/docs/advanced_features/expert_parallelism.md +++ b/docs/advanced_features/expert_parallelism.md @@ -32,6 +32,7 @@ Currently, DeepEP, Mooncake, `ascend_fuseep` and MORI only support cases where ` | `deep_gemm` | DeepGEMM backend optimized for MoE matrix multiplications, supporting contiguous layouts for prefill and masked layouts for decode; often JIT-compiled for performance. | Large-scale EP deployments with FP8 block-wise quantization. | | `cutlass` | CUTLASS-based backend for efficient GEMMs. | NVIDIA architectures with CUTLASS support. | | `flashinfer_trtllm` | FlashInfer integrated with TensorRT-LLM for accelerated MoE computations, supporting FP4 communication operators and high-performance GEMMs. | Blackwell with TRT-LLM. | +| `flashinfer_trtllm_routed` | FlashInfer integrated with TensorRT-LLM for accelerated routed MoE computations, consuming SGLang-computed top-k expert assignments and weights. | Blackwell with TRT-LLM. | | `flashinfer_cutlass` | FlashInfer combined with CUTLASS for high-performance grouped GEMMs in MoE layers, handling FP4/FP8 quantization efficiently. | Blackwell with FP4/FP8 models. | | `flashinfer_mxfp4` | FlashInfer variant optimized for MXFP4 (mixed FP4) quantization in MoE runners, focusing on memory-efficient low-precision inference. | Low-precision models with MXFP4. | | `flashinfer_cutedsl` | FlashInfer with a custom DSL for flexible and efficient MoE kernel generation, integrated with ModelOpt FP4 quantization. | Low-precision models with NVFP4. 
| diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index b8d89c208e07..61fc9059b595 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -312,7 +312,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s | --- | --- | --- | --- | | `--expert-parallel-size`
`--ep-size`
`--ep` | The expert parallelism size. | `1` | Type: int | | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | `none` | `none`, `deepep`, `mooncake`, `mori`, `ascend_fuseep`| -| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` | +| `--moe-runner-backend` | Choose the runner backend for MoE. | `auto` | `auto`, `deep_gemm`, `triton`, `triton_kernel`, `flashinfer_trtllm`, `flashinfer_trtllm_routed`, `flashinfer_cutlass`, `flashinfer_mxfp4`, `flashinfer_cutedsl`, `cutlass` | | `--flashinfer-mxfp4-moe-precision` | Choose the computation precision of flashinfer mxfp4 moe | `default` | `default`, `bf16` | | `--enable-flashinfer-allreduce-fusion` | Enable FlashInfer allreduce fusion with Residual RMSNorm. | `False` | bool flag (set to enable) | | `--enable-aiter-allreduce-fusion` | Enable aiter allreduce fusion with Residual RMSNorm. 
| `False` | bool flag (set to enable) | diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 5fb874efc56d..43e8bcce08a1 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -763,10 +763,11 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]): elif ( quant_config is None or quant_config.get_name() == "fp8" + or quant_config.get_name() == "mxfp8" or quant_config.get_name() == "modelopt_fp8" or quant_config.get_name() == "compressed_tensors" ): - # FlashInferFusedMoE support bf16, fp8 and compressed_tensors + # FlashInferFusedMoE supports bf16, fp8, mxfp8 and compressed_tensors return FusedMoE if get_moe_runner_backend().is_flashinfer_cutlass(): diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8da7d8eef330..144b050922f5 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -220,7 +220,10 @@ def __init__( self.use_presharded_weights = use_presharded_weights self.use_triton_kernels = get_moe_runner_backend().is_triton_kernels() - self.use_flashinfer_trtllm_moe = get_moe_runner_backend().is_flashinfer_trtllm() + self.use_flashinfer_trtllm_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ) # flashinfer_trtllm kernel requires intermediate_size to be a multiple of 128 # Pad the intermediate_size_per_partition if necessary @@ -302,7 +305,10 @@ def __init__( self.quant_method, ModelOptNvFp4FusedMoEMethod ) or ( isinstance(self.quant_method, Fp8MoEMethod) - and get_moe_runner_backend().is_cutlass() + and ( + get_moe_runner_backend().is_cutlass() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ) ) self.routing_method_type = routing_method_type diff --git 
a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py index 5d6fb78c4fe8..9c3eac87c664 100644 --- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py +++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py @@ -119,6 +119,77 @@ def align_fp8_moe_weights_for_flashinfer_trtllm( layer.output2_scales_scalar = Parameter(output2_scales_scalar, requires_grad=False) +def align_mxfp8_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: + """Prepare MXFP8 MoE weights/scales for FlashInfer TRT-LLM kernels.""" + from flashinfer import ( + reorder_rows_for_gated_act_gemm, + shuffle_matrix_a, + shuffle_matrix_sf_a, + ) + + w13_weight = cast(torch.Tensor, layer.w13_weight).contiguous() + w2_weight = cast(torch.Tensor, layer.w2_weight).contiguous() + w13_scale = cast(torch.Tensor, layer.w13_weight_scale_inv).contiguous() + w2_scale = cast(torch.Tensor, layer.w2_weight_scale_inv).contiguous() + + assert w13_scale.dtype == torch.uint8 + assert w2_scale.dtype == torch.uint8 + + num_experts, two_n, _ = w13_weight.shape + _, hidden_size, _ = w2_weight.shape + epilogue_tile_m = 128 + + w13_interleaved = [ + reorder_rows_for_gated_act_gemm(w13_weight[i]) for i in range(num_experts) + ] + w13_scale_interleaved = [ + reorder_rows_for_gated_act_gemm(w13_scale[i]) for i in range(num_experts) + ] + + w13_shuffled = [ + shuffle_matrix_a(w13_interleaved[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + w2_shuffled = [ + shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m) + for i in range(num_experts) + ] + w13_scale_shuffled = [ + shuffle_matrix_sf_a( + w13_scale_interleaved[i].view(torch.uint8).reshape(two_n, -1), + epilogue_tile_m, + ) + for i in range(num_experts) + ] + w2_scale_shuffled = [ + shuffle_matrix_sf_a( + w2_scale[i].view(torch.uint8).reshape(hidden_size, -1), + epilogue_tile_m, + ) + for i in range(num_experts) + ] + + # Keep parameter 
identities stable for CUDA graph capture reuse. + copy_or_rebind_param( + layer, "w13_weight", torch.stack(w13_shuffled).view(torch.float8_e4m3fn) + ) + copy_or_rebind_param( + layer, "w2_weight", torch.stack(w2_shuffled).view(torch.float8_e4m3fn) + ) + copy_or_rebind_param( + layer, + "w13_weight_scale_inv", + torch.stack(w13_scale_shuffled).reshape_as(w13_scale).contiguous(), + ) + copy_or_rebind_param( + layer, + "w2_weight_scale_inv", + torch.stack(w2_scale_shuffled).reshape_as(w2_scale).contiguous(), + ) + layer.w13_weight_scale_inv.format_ue8m0 = True + layer.w2_weight_scale_inv.format_ue8m0 = True + + def align_fp4_moe_weights_for_flashinfer_trtllm(layer: Module) -> None: """Prepare FP4 MoE weights/scales for FlashInfer TRT-LLM kernels. @@ -197,6 +268,7 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): # Block-quant path block_quant: bool + use_mxfp8: bool = False weight_block_k: int | None = None w13_weight_scale_inv: torch.Tensor | None = None w2_weight_scale_inv: torch.Tensor | None = None @@ -209,13 +281,27 @@ class FlashInferTrtllmFp8MoeQuantInfo(MoeQuantInfo): use_routing_scales_on_input: bool = False + +def _pack_topk_for_flashinfer_routed( + topk_ids: torch.Tensor, topk_weights: torch.Tensor +) -> torch.Tensor: + """Pack routed top-k tensors into FlashInfer's int32 format.""" + packed_ids = topk_ids.to(torch.int32) + packed_weights = topk_weights.to(torch.bfloat16) + packed = (packed_ids << 16) | (packed_weights.view(torch.int16).to(torch.int32) & 0xFFFF) + # SGLang can mark padded tokens with -1 expert ids. 
+ return packed.masked_fill_(packed_ids < 0, 0) + + def fused_experts_none_to_flashinfer_trtllm_fp8( dispatch_output: StandardDispatchOutput, quant_info: FlashInferTrtllmFp8MoeQuantInfo, runner_config: MoeRunnerConfig, + use_routed_topk: bool = False, ) -> StandardCombineInput: from flashinfer.fused_moe import ( + Fp8QuantizationType, trtllm_fp8_block_scale_moe, + trtllm_fp8_block_scale_routed_moe, trtllm_fp8_per_tensor_scale_moe, ) @@ -228,64 +314,132 @@ def fused_experts_none_to_flashinfer_trtllm_fp8( hidden_states = dispatch_output.hidden_states topk_output = dispatch_output.topk_output - assert TopKOutputChecker.format_is_bypassed(topk_output) - - router_logits = topk_output.router_logits - topk_config = topk_output.topk_config - correction_bias = ( - None - if topk_config.correction_bias is None - else topk_config.correction_bias.to(hidden_states.dtype) - ) + if TopKOutputChecker.format_is_bypassed(topk_output): + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(hidden_states.dtype) + ) + else: + router_logits = None + topk_config = None + correction_bias = None routing_method_type = quant_info.routing_method_type + fp8_quantization_type = ( + Fp8QuantizationType.MxFp8 + if quant_info.use_mxfp8 + else Fp8QuantizationType.DeepSeekFp8 + ) + use_shuffled_weight = quant_info.use_mxfp8 if quant_info.block_quant: assert quant_info.weight_block_k is not None assert quant_info.w13_weight_scale_inv is not None assert quant_info.w2_weight_scale_inv is not None - a_q, a_sf = per_token_group_quant_fp8(hidden_states, quant_info.weight_block_k) - a_sf_t = a_sf.t().contiguous() + if quant_info.use_mxfp8: + assert quant_info.weight_block_k == 32 + from flashinfer import mxfp8_quantize + + a_q, a_sf = mxfp8_quantize(hidden_states, False) + # FlashInfer TRT-LLM MxFP8 expects token-major activation scales: + # [num_tokens, hidden_size // 
32] (no transpose). + a_sf_t = a_sf.view(torch.uint8).reshape(hidden_states.shape[0], -1) + else: + a_q, a_sf = per_token_group_quant_fp8( + hidden_states, quant_info.weight_block_k + ) + a_sf_t = a_sf.t().contiguous() with use_symmetric_memory( get_tp_group(), disabled=not is_allocation_symmetric() ): - # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. - # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 - # so we put the whole function under the ``use_symmetric_memory`` context manager. - # If the bug is fixed, we can only put the output tensor allocation under the context manager. - output = trtllm_fp8_block_scale_moe( - routing_logits=( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ), - routing_bias=correction_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=quant_info.w13_weight, - gemm1_weights_scale=quant_info.w13_weight_scale_inv, - gemm2_weights=quant_info.w2_weight, - gemm2_weights_scale=quant_info.w2_weight_scale_inv, - num_experts=quant_info.global_num_experts, - top_k=topk_config.top_k, - n_group=( - topk_config.num_expert_group if topk_config.num_expert_group else 0 - ), - topk_group=topk_config.topk_group if topk_config.topk_group else 0, - intermediate_size=quant_info.intermediate_size, - local_expert_offset=quant_info.local_expert_offset, - local_num_experts=quant_info.local_num_experts, - routed_scaling_factor=( - runner_config.routed_scaling_factor - if runner_config.routed_scaling_factor is not None - else 1.0 - ), - routing_method_type=routing_method_type, - use_shuffled_weight=False, - tune_max_num_tokens=next_power_of_2(a_q.shape[0]), - ) + if use_routed_topk: + assert ( + runner_config.top_k is not None + ), "runner_config.top_k is required for flashinfer_trtllm_routed." 
+ assert TopKOutputChecker.format_is_standard(topk_output) + packed_topk_ids = _pack_topk_for_flashinfer_routed( + topk_ids=topk_output.topk_ids, + topk_weights=topk_output.topk_weights, + ) + + output = trtllm_fp8_block_scale_routed_moe( + topk_ids=packed_topk_ids, + routing_bias=None, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale_inv, + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale_inv, + num_experts=quant_info.global_num_experts, + top_k=runner_config.top_k, + n_group=None, + topk_group=None, + intermediate_size=quant_info.intermediate_size, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=( + runner_config.routed_scaling_factor + if runner_config.routed_scaling_factor is not None + else 1.0 + ), + routing_method_type=( + RoutingMethodType.TopK + if routing_method_type == RoutingMethodType.DeepSeekV3 + else routing_method_type + ), + use_shuffled_weight=use_shuffled_weight, + weight_layout=0, + tune_max_num_tokens=next_power_of_2(a_q.shape[0]), + fp8_quantization_type=fp8_quantization_type, + ) + else: + assert TopKOutputChecker.format_is_bypassed(topk_output) + + # FIXME: there is a bug in the trtllm_fp8_block_scale_moe. + # It ignored the `output` argument. https://github.com/flashinfer-ai/flashinfer/blob/da01b1bd8f9f22aec8c0eea189ad54860b034947/flashinfer/fused_moe/core.py#L1323-L1325 + # so we put the whole function under the ``use_symmetric_memory`` context manager. + # If the bug is fixed, we can only put the output tensor allocation under the context manager. 
+ output = trtllm_fp8_block_scale_moe( + routing_logits=( + router_logits.to(torch.float32) + if routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ), + routing_bias=correction_bias, + hidden_states=a_q, + hidden_states_scale=a_sf_t, + gemm1_weights=quant_info.w13_weight, + gemm1_weights_scale=quant_info.w13_weight_scale_inv, + gemm2_weights=quant_info.w2_weight, + gemm2_weights_scale=quant_info.w2_weight_scale_inv, + num_experts=quant_info.global_num_experts, + top_k=topk_config.top_k, + n_group=( + topk_config.num_expert_group + if topk_config.num_expert_group + else 0 + ), + topk_group=topk_config.topk_group if topk_config.topk_group else 0, + intermediate_size=quant_info.intermediate_size, + local_expert_offset=quant_info.local_expert_offset, + local_num_experts=quant_info.local_num_experts, + routed_scaling_factor=( + runner_config.routed_scaling_factor + if runner_config.routed_scaling_factor is not None + else 1.0 + ), + routing_method_type=routing_method_type, + use_shuffled_weight=use_shuffled_weight, + weight_layout=0, + tune_max_num_tokens=next_power_of_2(a_q.shape[0]), + fp8_quantization_type=fp8_quantization_type, + ) else: assert quant_info.w13_input_scale is not None assert quant_info.output1_scales_scalar is not None @@ -577,3 +731,21 @@ def fused_experts_none_to_flashinfer_trtllm( raise TypeError( f"Unexpected quant_info type for flashinfer_trtllm: {type(quant_info)}" ) + + +@register_fused_func("none", "flashinfer_trtllm_routed") +def fused_experts_none_to_flashinfer_trtllm_routed( + dispatch_output: StandardDispatchOutput, + quant_info: MoeQuantInfo, + runner_config: MoeRunnerConfig, +) -> StandardCombineInput: + if isinstance(quant_info, FlashInferTrtllmFp8MoeQuantInfo): + return fused_experts_none_to_flashinfer_trtllm_fp8( + dispatch_output, + quant_info, + runner_config, + use_routed_topk=True, + ) + raise TypeError( + f"Unexpected quant_info type for flashinfer_trtllm_routed: {type(quant_info)}" + ) diff --git 
a/python/sglang/srt/layers/moe/moe_runner/runner.py b/python/sglang/srt/layers/moe/moe_runner/runner.py index 8b58cd3115bd..ee580e580586 100644 --- a/python/sglang/srt/layers/moe/moe_runner/runner.py +++ b/python/sglang/srt/layers/moe/moe_runner/runner.py @@ -39,7 +39,10 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig): self.runner_core = DeepGemmRunnerCore(config) elif runner_backend.is_marlin(): self.runner_core = None # Marlin only supports fused path - elif runner_backend.is_flashinfer_trtllm(): + elif ( + runner_backend.is_flashinfer_trtllm() + or runner_backend.is_flashinfer_trtllm_routed() + ): self.runner_core = None # FlashInfer TRT-LLM only supports fused path else: raise NotImplementedError(f"Unsupported runner backend: {runner_backend}") diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 8ec839991a90..b77c19f83a69 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -86,6 +86,9 @@ def __init__(self, moe_runner_config: MoeRunnerConfig): self.enable_flashinfer_cutlass_moe = ( get_moe_runner_backend().is_flashinfer_cutlass() ) + self.enable_flashinfer_trtllm_routed_moe = ( + get_moe_runner_backend().is_flashinfer_trtllm_routed() + ) self.num_experts = moe_runner_config.num_experts self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts self.num_local_routed_experts = ( @@ -142,6 +145,7 @@ def dispatch( if ( self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe + and not self.enable_flashinfer_trtllm_routed_moe and TopKOutputChecker.format_is_standard(topk_output) ): if self.local_expert_mapping is None: diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 2523ef851711..068bc67cdaae 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -29,6 +29,7 
@@ ) import torch +import torch.nn.functional as F try: from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing @@ -443,6 +444,25 @@ def scoring_func_impl(gating_output: torch.Tensor) -> torch.Tensor: return topk_weights, topk_ids +def fused_topk_softmax_torch_raw_logits( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" + + _, topk_ids = torch.topk(gating_output, k=topk, dim=-1, sorted=False) + logits = gating_output.float() + topk_weights = logits.gather(1, topk_ids) + if renormalize: + topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + def fused_topk_cpu( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -1030,15 +1050,28 @@ def select_experts( ) elif custom_routing_function is None: assert not apply_routed_scaling_factor_on_output, "Not implemented" - # Qwen3MOE uses fused_topk - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=num_routed_topk if _use_aiter else top_k, - renormalize=renormalize, - correction_bias=correction_bias, - scoring_func=scoring_func, - ) + if ( + get_moe_runner_backend().is_flashinfer_trtllm_routed() + and scoring_func == "softmax" + and correction_bias is None + ): + # flashinfer_trtllm_routed uses raw-logits topk + topk_weights, topk_ids = fused_topk_softmax_torch_raw_logits( + hidden_states=hidden_states, + gating_output=router_logits, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + ) + else: + # Qwen3MOE uses fused_topk + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=num_routed_topk if _use_aiter else top_k, + renormalize=renormalize, + correction_bias=correction_bias, + 
scoring_func=scoring_func, + ) else: assert ( num_token_non_padded is None diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index 70ad66e9c514..3a2c7f3e537b 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -61,6 +61,7 @@ class MoeRunnerBackend(Enum): TRITON = "triton" TRITON_KERNELS = "triton_kernel" FLASHINFER_TRTLLM = "flashinfer_trtllm" + FLASHINFER_TRTLLM_ROUTED = "flashinfer_trtllm_routed" FLASHINFER_CUTLASS = "flashinfer_cutlass" FLASHINFER_MXFP4 = "flashinfer_mxfp4" FLASHINFER_CUTEDSL = "flashinfer_cutedsl" @@ -82,6 +83,9 @@ def is_triton_kernels(self): def is_flashinfer_trtllm(self): return self == MoeRunnerBackend.FLASHINFER_TRTLLM + def is_flashinfer_trtllm_routed(self): + return self == MoeRunnerBackend.FLASHINFER_TRTLLM_ROUTED + def is_flashinfer_cutlass(self): return self == MoeRunnerBackend.FLASHINFER_CUTLASS diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 24757958430c..1735881f53f1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -49,11 +49,12 @@ can_auto_enable_marlin_fp8, cutlass_fp8_supported, dispatch_w8a8_block_fp8_linear, + dispatch_w8a8_mxfp8_linear, + get_fp8_gemm_runner_backend, input_to_float8, mxfp8_group_quantize, normalize_e4m3fn_to_e4m3fnuz, requant_weight_ue8m0_inplace, - triton_mxfp8_blockscaled_linear, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod from sglang.srt.layers.quantization.marlin_utils_fp8 import ( @@ -71,6 +72,7 @@ per_tensor_dequantize, requantize_with_max_scale, ) +from sglang.srt.layers.utils import copy_or_rebind_param from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -268,7 +270,12 @@ def __init__(self, quant_config: Union[Fp8Config, W4AFp8Config]): self.block_quant = ( self.use_mxfp8 or self.quant_config.weight_block_size is not None ) - 
self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() + self.w8a8_block_fp8_linear = None + self.w8a8_mxfp8_linear = None + if self.use_mxfp8: + self.w8a8_mxfp8_linear = dispatch_w8a8_mxfp8_linear() + else: + self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() self.is_checkpoint_fp8_serialized = ( self.quant_config.is_checkpoint_fp8_serialized ) @@ -441,6 +448,7 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: # Keep parameter object to preserve weight_loader attrs for hot reload. layer.weight_scale_inv.requires_grad_(False) layer.weight_scale_inv.format_ue8m0 = True + self._process_mxfp8_linear_weight_scale(layer) return else: # For fp8 linear weights run with deepgemm, the weights and scales need be requantized to ue8m0 @@ -474,6 +482,25 @@ def process_weights_after_loading_block_quant(self, layer: Module) -> None: layer.weight.data = weight.data layer.weight_scale_inv.data = weight_scale.data + def _process_mxfp8_linear_weight_scale(self, layer: Module) -> None: + if not self.use_mxfp8: + return + + if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + from flashinfer import block_scale_interleave + + scale_u8 = layer.weight_scale_inv.data + new_swizzled = block_scale_interleave(scale_u8.contiguous()).contiguous() + else: + # Triton path consumes canonical 2D UE8M0 scales directly. 
+ return + + copy_or_rebind_param(layer, "weight_scale_inv_swizzled", new_swizzled) + layer._weight_scale_inv_swizzled_src_version = layer.weight_scale_inv._version + layer._weight_scale_inv_swizzled_src_data_ptr = ( + layer.weight_scale_inv.data_ptr() + ) + def _quantize_mxfp8_weights(self, layer: Module) -> None: weight = layer.weight.data qweight, weight_scale = mxfp8_group_quantize(weight) @@ -489,6 +516,7 @@ def _quantize_mxfp8_weights(self, layer: Module) -> None: "weight_scale_inv", Parameter(weight_scale, requires_grad=False) ) layer.weight_scale_inv.format_ue8m0 = True + self._process_mxfp8_linear_weight_scale(layer) layer.input_scale = None def process_weights_after_loading(self, layer: Module) -> None: @@ -621,18 +649,22 @@ def apply( ) if self.use_mxfp8: + if get_fp8_gemm_runner_backend().is_flashinfer_trtllm(): + weight_scale = layer.weight_scale_inv_swizzled + else: + weight_scale = layer.weight_scale_inv if isinstance(x, tuple): - return triton_mxfp8_blockscaled_linear( + return self.w8a8_mxfp8_linear( input=x[0], weight=layer.weight, - weight_scale=layer.weight_scale_inv, + weight_scale=weight_scale, input_scale=x[1], bias=bias, ) - return triton_mxfp8_blockscaled_linear( + return self.w8a8_mxfp8_linear( input=x, weight=layer.weight, - weight_scale=layer.weight_scale_inv, + weight_scale=weight_scale, input_scale=None, bias=bias, ) @@ -1105,6 +1137,19 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor): scale = _swizzle_with_triton_kernel(weight.shape, scale) return qweight, scale + def _quantize_with_flashinfer_trtllm(weight: torch.Tensor): + weight = weight.contiguous() + num_experts, m, k = weight.shape + assert k % 32 == 0, f"{k=} must be divisible by 32 for MXFP8" + from flashinfer import mxfp8_quantize + + weight_flat = weight.view(-1, k).contiguous() + qweight, scale = mxfp8_quantize(weight_flat, False) + scale_u8 = ( + scale.view(torch.uint8).contiguous().view(num_experts, m, k // 32) + ) + return qweight.view_as(weight), 
scale_u8 + if quantize: if get_moe_runner_backend().is_cutlass(): w13_q, w13_s = _quantize_and_swizzle_with_cutlass_es_kernel( @@ -1113,6 +1158,15 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor): w2_q, w2_s = _quantize_and_swizzle_with_cutlass_es_kernel( layer.w2_weight.data ) + elif ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): + # Match FlashInfer TRT-LLM MoE test contracts: + # 1) quantize in canonical (non-swizzled) scale layout, and + # 2) do row/layout shuffling in align_mxfp8_moe_weights_for_flashinfer_trtllm. + w13_q, w13_s = _quantize_with_flashinfer_trtllm(layer.w13_weight.data) + w2_q, w2_s = _quantize_with_flashinfer_trtllm(layer.w2_weight.data) else: w13_q, w13_s = _quantize_and_swizzle_with_triton_kernel( layer.w13_weight.data @@ -1121,14 +1175,23 @@ def _quantize_and_swizzle_with_triton_kernel(weight: torch.Tensor): layer.w2_weight.data ) else: - w13_q = layer.w13_weight.data - w2_q = layer.w2_weight.data - w13_s = _swizzle_with_triton_kernel( - layer.w13_weight.data.shape, layer.w13_weight_scale_inv.data - ) - w2_s = _swizzle_with_triton_kernel( - layer.w2_weight.data.shape, layer.w2_weight_scale_inv.data - ) + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): + w13_q = layer.w13_weight.data + w2_q = layer.w2_weight.data + w13_s = layer.w13_weight_scale_inv.data + w2_s = layer.w2_weight_scale_inv.data + else: + w13_q = layer.w13_weight.data + w2_q = layer.w2_weight.data + w13_s = _swizzle_with_triton_kernel( + layer.w13_weight.data.shape, layer.w13_weight_scale_inv.data + ) + w2_s = _swizzle_with_triton_kernel( + layer.w2_weight.data.shape, layer.w2_weight_scale_inv.data + ) # Keep parameter objects to preserve weight_loader attrs for hot reload. # Prefer in-place copy; rebind only when shape/dtype changes (online quantize). 
@@ -1154,6 +1217,16 @@ def _copy_or_rebind(param: Parameter, new_value: torch.Tensor) -> None: layer.w13_input_scale = None layer.w2_input_scale = None + if ( + get_moe_runner_backend().is_flashinfer_trtllm() + or get_moe_runner_backend().is_flashinfer_trtllm_routed() + ): + from sglang.srt.layers.moe.moe_runner.flashinfer_trtllm import ( + align_mxfp8_moe_weights_for_flashinfer_trtllm, + ) + + align_mxfp8_moe_weights_for_flashinfer_trtllm(layer) + def process_weights_after_loading(self, layer: Module) -> None: if _is_hip and _use_hip_int4: self.process_weights_hip_int4(layer) @@ -1376,6 +1449,7 @@ def create_moe_runner( moe_runner_backend.is_deep_gemm() or moe_runner_backend.is_triton() or moe_runner_backend.is_flashinfer_trtllm() + or moe_runner_backend.is_flashinfer_trtllm_routed() ): self.runner = MoeRunner(moe_runner_backend, moe_runner_config) else: @@ -1504,9 +1578,14 @@ def apply( w2_scale=w2_scale, block_shape=block_shape, ) - elif self.runner.runner_backend.is_flashinfer_trtllm(): + elif ( + self.runner.runner_backend.is_flashinfer_trtllm() + or self.runner.runner_backend.is_flashinfer_trtllm_routed() + ): # FlashInfer TRT-LLM backend only supports fused execution and consumes # router logits directly (no separate apply_with_router_logits needed). + # FlashInfer TRT-LLM routed backend consumes SGLang-computed + # top-k ids/weights (packed into int32) instead of router logits. 
global_num_experts = int(getattr(layer, "num_experts")) num_local_experts = int(getattr(layer, "num_local_experts")) moe_ep_rank = int(getattr(layer, "moe_ep_rank")) @@ -1522,6 +1601,7 @@ def apply( getattr(layer, "routing_method_type", RoutingMethodType.DeepSeekV3) ), block_quant=self.block_quant, + use_mxfp8=getattr(self.quant_config, "use_mxfp8", False), weight_block_k=( None if self.quant_config.weight_block_size is None diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index ce65b5a01e6a..f072264f6eef 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -185,6 +185,8 @@ def _check_cutlass_block_fp8_hardware_support() -> bool: if is_blackwell_supported() and is_flashinfer_available(): + from flashinfer import mm_mxfp8 as _raw_flashinfer_mm_mxfp8 + from flashinfer import mxfp8_quantize as _raw_flashinfer_mxfp8_quantize from flashinfer.gemm import gemm_fp8_nt_groupwise as _raw_gemm_fp8_nt_groupwise from sglang.srt.utils.custom_op import register_custom_op @@ -242,6 +244,62 @@ def gemm_fp8_nt_groupwise( backend=backend, ) + # Wrap MXFP8 ops as custom ops so torch.compile does not trace into + # flashinfer's JIT compilation path (filesystem checks/cubin loader). + def _fake_flashinfer_mxfp8_quantize( + input: torch.Tensor, + _is_sf_swizzled_layout: bool = True, + alignment: int = 32, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Fake mode only needs dtypes and output rank to propagate compile graph. + # The scale tensor shape is not consumed before the following fake mm op. 
+ k_aligned = ((input.shape[1] + alignment - 1) // alignment) * alignment + q_input = input.new_empty( + (input.shape[0], k_aligned), dtype=torch.float8_e4m3fn + ) + scale = input.new_empty((1,), dtype=torch.uint8) + return q_input, scale + + @register_custom_op( + op_name="flashinfer_mxfp8_quantize", + mutates_args=[], + fake_impl=_fake_flashinfer_mxfp8_quantize, + ) + def flashinfer_mxfp8_quantize( + input: torch.Tensor, + is_sf_swizzled_layout: bool = True, + alignment: int = 32, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return _raw_flashinfer_mxfp8_quantize( + input, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment, + ) + + @register_custom_op( + op_name="flashinfer_mm_mxfp8", + mutates_args=[], + fake_impl=lambda q_input, weight_t, x_scale_u8, weight_scale_t, out_dtype, backend="auto": ( + q_input.new_empty((q_input.shape[0], weight_t.shape[1]), dtype=out_dtype) + ), + ) + def flashinfer_mm_mxfp8( + q_input: torch.Tensor, + weight_t: torch.Tensor, + x_scale_u8: torch.Tensor, + weight_scale_t: torch.Tensor, + out_dtype: torch.dtype, + backend: str = "auto", + ) -> torch.Tensor: + return _raw_flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=out_dtype, + backend=backend, + ) + if is_sm90_supported() and is_flashinfer_available(): # FlashInfer SM90 DeepGEMM with automatic swapAB optimization for small M @@ -266,6 +324,18 @@ def dispatch_w8a8_block_fp8_linear() -> Callable: return _dispatch_auto_backend() +def dispatch_w8a8_mxfp8_linear() -> Callable: + """Dispatch MXFP8 linear kernel by --fp8-gemm-backend. + + For MXFP8, Triton remains the default path. We only route to FlashInfer + when backend is explicitly set to flashinfer_trtllm. 
+ """ + backend = get_fp8_gemm_runner_backend() + if backend.is_flashinfer_trtllm(): + return flashinfer_mxfp8_blockscaled_linear + return triton_mxfp8_blockscaled_linear + + def _dispatch_explicit_backend(backend: Fp8GemmRunnerBackend) -> Callable: """Dispatch based on explicitly selected backend.""" if backend.is_flashinfer_trtllm(): @@ -843,6 +913,61 @@ def triton_mxfp8_blockscaled_linear( return output.to(dtype=output_dtype).view(*output_shape) +def flashinfer_mxfp8_blockscaled_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """MXFP8 dense linear via FlashInfer mm_mxfp8.""" + input_2d = input.view(-1, input.shape[-1]).contiguous() + output_shape = [*input.shape[:-1], weight.shape[0]] + + m, k = input_2d.shape + n, k_w = weight.shape + if k != k_w: + raise ValueError(f"Input K={k} does not match weight K={k_w}.") + if k % 32 != 0: + raise ValueError(f"K={k} must be divisible by 32 for MXFP8.") + if weight.dtype != torch.float8_e4m3fn: + raise TypeError("MXFP8 weight must be FP8 E4M3.") + + if input_scale is None: + q_input, x_scale_u8 = flashinfer_mxfp8_quantize( + input_2d, is_sf_swizzled_layout=True, alignment=32 + ) + else: + q_input = input_2d + + if output_dtype is None: + if input_2d.dtype in (torch.float16, torch.bfloat16, torch.float32): + output_dtype = input_2d.dtype + else: + output_dtype = torch.bfloat16 + + # Ensure transposed tensors are contiguous for FlashInfer's internal runner. 
+ weight_t = weight.contiguous().t() + weight_scale_t = ( + weight_scale.contiguous().t() + if weight_scale.ndim == 2 + else weight_scale.contiguous() + ) + output = flashinfer_mm_mxfp8( + q_input, + weight_t, + x_scale_u8, + weight_scale_t, + out_dtype=output_dtype, + backend="auto", + ) + + if bias is not None: + output += bias + return output.to(dtype=output_dtype).view(*output_shape) + + def dequant_mxfp4( w_block: torch.Tensor, w_scale: torch.Tensor, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6bfa8ecb30ea..657c4c1b9870 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -179,6 +179,7 @@ "triton", "triton_kernel", "flashinfer_trtllm", + "flashinfer_trtllm_routed", "flashinfer_cutlass", "flashinfer_mxfp4", "flashinfer_cutedsl", @@ -2454,12 +2455,19 @@ def _handle_data_parallelism(self): def _handle_moe_kernel_config(self): if self.quantization == "mxfp8": - if self.moe_runner_backend not in ["auto", "cutlass"]: + if self.moe_runner_backend == "auto": + self.moe_runner_backend = "flashinfer_trtllm" + elif self.moe_runner_backend not in [ + "cutlass", + "flashinfer_trtllm", + "flashinfer_trtllm_routed", + ]: logger.warning( - "mxfp8 quantization forces --moe-runner-backend=cutlass. " + "mxfp8 quantization supports only cutlass, flashinfer_trtllm, " + "or flashinfer_trtllm_routed backends. " f"Overriding {self.moe_runner_backend!r}." ) - self.moe_runner_backend = "cutlass" + self.moe_runner_backend = "flashinfer_trtllm" if self.moe_runner_backend == "flashinfer_cutlass": assert self.quantization in [ @@ -2476,6 +2484,7 @@ def _handle_moe_kernel_config(self): assert self.quantization in [ "modelopt_fp4", "fp8", + "mxfp8", "modelopt_fp8", "compressed-tensors", None, @@ -2485,6 +2494,16 @@ def _handle_moe_kernel_config(self): "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set." 
) + if self.moe_runner_backend == "flashinfer_trtllm_routed": + assert self.quantization in [ + "fp8", + "mxfp8", + ], f"Invalid quantization '{self.quantization}'. \nFlashInfer TRTLLM routed MOE supports only: 'fp8' or 'mxfp8'." + self.disable_shared_experts_fusion = True + logger.warning( + "FlashInfer TRTLLM routed MoE is enabled. --disable-shared-experts-fusion is automatically set." + ) + if get_bool_env_var("SGLANG_CUTLASS_MOE"): logger.warning( "SGLANG_CUTLASS_MOE is deprecated, use --moe-runner-backend=cutlass and/or --speculative-moe-runner-backend=cutlass instead" @@ -2691,7 +2710,8 @@ def _handle_speculative_decoding(self): if self.speculative_moe_runner_backend is None: self.speculative_moe_runner_backend = ( "auto" - if self.moe_runner_backend == "flashinfer_trtllm" + if self.moe_runner_backend + in ["flashinfer_trtllm", "flashinfer_trtllm_routed"] else self.moe_runner_backend ) else: diff --git a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py index 63db2b2ad1cc..df0e3af457cb 100644 --- a/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py +++ b/test/registered/backends/test_flashinfer_trtllm_gen_moe_backend.py @@ -12,10 +12,12 @@ popen_launch_server, ) -register_cuda_ci(est_time=300, suite="nightly-4-gpu-b200", nightly=True) +register_cuda_ci(est_time=500, suite="nightly-4-gpu-b200", nightly=True) -class TestFlashinferTrtllmGenMoeBackendFP8(CustomTestCase): +class FlashinferTrtllmGenMoeBackendFP8Base: + backend = None + @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8" @@ -29,7 +31,7 @@ def setUpClass(cls): "--attention-backend", "triton", "--moe-runner-backend", - "flashinfer_trtllm", + cls.backend, "--tp-size", "4", "--ep-size", @@ -60,7 +62,9 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.93) -class TestFlashinferTrtllmGenMoeBackendBF16(CustomTestCase): +class 
FlashinferTrtllmGenMoeBackendBF16Base: + backend = None + @classmethod def setUpClass(cls): cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct" @@ -73,7 +77,7 @@ def setUpClass(cls): "--attention-backend", "triton", "--moe-runner-backend", - "flashinfer_trtllm", + cls.backend, "--cuda-graph-max-bs", "512", "--tp-size", @@ -106,5 +110,82 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.93) +class FlashinferTrtllmGenMoeBackendMXFP8Base: + backend = None + + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen3-30B-A3B-Instruct-2507" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + env={**os.environ, "SGLANG_ENABLE_JIT_DEEPGEMM": "False"}, + other_args=[ + "--quantization", + "mxfp8", + "--fp8-gemm-backend", + "flashinfer_trtllm", + "--moe-runner-backend", + cls.backend, + "--tp-size", + "4", + "--ep-size", + "4", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.93) + + +class TestFlashinferTrtllmGenMoeBackendFP8( + FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + +class TestFlashinferTrtllmGenMoeBackendMXFP8( + FlashinferTrtllmGenMoeBackendMXFP8Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + +class TestFlashinferTrtllmGenMoeBackendBF16( + FlashinferTrtllmGenMoeBackendBF16Base, CustomTestCase +): + backend = "flashinfer_trtllm" + + +class TestFlashinferTrtllmGenMoeBackendFP8Routed( + FlashinferTrtllmGenMoeBackendFP8Base, CustomTestCase +): + backend = "flashinfer_trtllm_routed" + + +class 
TestFlashinferTrtllmGenMoeBackendMXFP8Routed( + FlashinferTrtllmGenMoeBackendMXFP8Base, CustomTestCase +): + backend = "flashinfer_trtllm_routed" + + if __name__ == "__main__": unittest.main() diff --git a/test/registered/quant/test_fp8_blockwise_gemm.py b/test/registered/quant/test_fp8_blockwise_gemm.py index b832e4740668..af7600cb5380 100644 --- a/test/registered/quant/test_fp8_blockwise_gemm.py +++ b/test/registered/quant/test_fp8_blockwise_gemm.py @@ -12,9 +12,10 @@ try_cached_model, ) -register_cuda_ci(est_time=280, suite="stage-c-test-4-gpu-b200") +register_cuda_ci(est_time=420, suite="stage-c-test-4-gpu-b200") MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507-FP8" +BF16_MODEL_PATH = "Qwen/Qwen3-4B-Instruct-2507" class FP8BlockwiseGemmBase: @@ -56,7 +57,51 @@ def test_gsm8k(self): metrics = run_eval_few_shot_gsm8k(args) print(metrics) - self.assertGreaterEqual(metrics["accuracy"], 0.41) + self.assertGreaterEqual(metrics["accuracy"], 0.8) + + +class MXFP8GemmBase: + backend = None + + @classmethod + def setUpClass(cls): + if cls.backend is None: + raise NotImplementedError("Subclass must set 'backend' attribute") + cls.model = try_cached_model(BF16_MODEL_PATH) + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--quantization", + "mxfp8", + "--fp8-gemm-backend", + cls.backend, + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + parsed_url = urlparse(self.base_url) + args = SimpleNamespace( + num_shots=8, + data_path=None, + num_questions=1319, + max_new_tokens=512, + parallel=200, + host=f"{parsed_url.scheme}://{parsed_url.hostname}", + port=parsed_url.port, + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreaterEqual(metrics["accuracy"], 0.8) class TestFP8BlockwiseGemmTriton(FP8BlockwiseGemmBase, 
unittest.TestCase): @@ -77,5 +122,15 @@ class TestFP8BlockwiseGemmFlashinferDeepGemm(FP8BlockwiseGemmBase, unittest.Test backend = "flashinfer_deepgemm" +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestMXFP8GemmTriton(MXFP8GemmBase, unittest.TestCase): + backend = "triton" + + +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestMXFP8GemmFlashinferTrtllm(MXFP8GemmBase, unittest.TestCase): + backend = "flashinfer_trtllm" + + if __name__ == "__main__": unittest.main()