diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index b5663e44be18..bd406a903693 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -243,6 +243,18 @@ class BypassedTopKOutput(NamedTuple): def format(self) -> TopKOutputFormat: return TopKOutputFormat.BYPASSED + def to_standard(self, layer_id: Optional[int] = None) -> "StandardTopKOutput": + """Materialize routing tensors. Used by MoE kernels that need explicit + topk_ids / topk_weights rather than doing routing internally.""" + return select_experts( + hidden_states=self.hidden_states, + router_logits=self.router_logits, + topk_config=self.topk_config, + layer_id=layer_id, + num_token_non_padded=self.num_token_non_padded, + expert_location_dispatch_info=self.expert_location_dispatch_info, + ) + # -------------------------------- TopK --------------------------------------- diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 291e2c2de3bd..cd74a1c280e2 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -262,6 +262,15 @@ def get_quant_method( return Mxfp4MarlinMoEMethod(fp8_method, prefix=prefix) if self.is_fp4_experts and get_moe_runner_backend().is_flashinfer_mxfp4(): + # SM100 (Blackwell) -> trtllm-gen path. + # SM90 (Hopper) -> cutlass mixed-input path (FlashInfer #3084). 
+ if is_sm90_supported() and not is_sm100_supported(): + from sglang.srt.layers.quantization.mxfp4_flashinfer_cutlass_moe import ( + Mxfp4FlashinferCutlassMoEMethod, + ) + + return Mxfp4FlashinferCutlassMoEMethod(fp8_method, prefix=prefix) + from sglang.srt.layers.quantization.mxfp4_flashinfer_trtllm_moe import ( Mxfp4FlashinferTrtllmMoEMethod, ) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index d5ad4d403493..1beeefab895d 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -16,12 +16,18 @@ from __future__ import annotations +import os from dataclasses import replace from typing import TYPE_CHECKING, List, Optional import torch from torch.nn.parameter import Parameter +# Silence the TRT-LLM cutlass autotune trace embedded inside FlashInfer's +# cutlass_fused_moe. Its C++ logger reads TLLM_LOG_LEVEL on first kernel launch; +# setdefault preserves any explicit user override. +os.environ.setdefault("TLLM_LOG_LEVEL", "INFO") + from sglang.srt.distributed import get_tp_group from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, @@ -62,7 +68,27 @@ nvfp4_block_scale_interleave, trtllm_fp4_block_scale_moe, ) - from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache + from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe + from flashinfer.fused_moe.core import ( + ActivationType, + get_w2_permute_indices_with_cache, + ) + + # SM90 mixed-input helpers landed in FlashInfer #3084 (post-0.6.10). Older + # versions don't ship them; gate at import so unrelated code paths still load. 
+ try: + from flashinfer.fused_moe import ( + interleave_moe_scales_for_sm90_mixed_gemm, + interleave_moe_weights_for_sm90_mixed_gemm, + ) + + _FI_HAS_SM90_CUTLASS_MXFP4 = True + except ImportError: + interleave_moe_scales_for_sm90_mixed_gemm = None + interleave_moe_weights_for_sm90_mixed_gemm = None + _FI_HAS_SM90_CUTLASS_MXFP4 = False +else: + _FI_HAS_SM90_CUTLASS_MXFP4 = False _flashinfer_mxfp4_permute_indices_cache: dict[torch.Size, torch.Tensor] = {} _flashinfer_mxfp4_permute_indices_device_cache: dict[ @@ -318,6 +344,28 @@ def __init__( self.flashinfer_mxfp4_moe_precision = ( get_global_server_args().flashinfer_mxfp4_moe_precision ) + # When `flashinfer_mxfp4` is enabled, dispatch to one of two FlashInfer + # entry points depending on the GPU: + # - SM100 (Blackwell) -> trtllm_fp4_block_scale_moe (existing) + # - SM90 (Hopper) -> cutlass_fused_moe(use_w4_group_scaling=True) + # (FlashInfer PR #3084, post-0.6.10) + self._fi_kernel: Optional[str] = None + if self.use_flashinfer: + if is_sm100_supported(): + self._fi_kernel = "trtllm_sm100" + elif is_sm90_supported(): + if not _FI_HAS_SM90_CUTLASS_MXFP4: + raise RuntimeError( + "moe_runner_backend=flashinfer_mxfp4 on SM90 requires the " + "interleave_moe_{weights,scales}_for_sm90_mixed_gemm helpers " + "from FlashInfer PR #3084 (>= 0.6.11). Upgrade flashinfer-python " + "or pick a different backend (e.g. marlin / triton_kernel)." + ) + self._fi_kernel = "cutlass_sm90" + else: + raise NotImplementedError( + "moe_runner_backend=flashinfer_mxfp4 requires SM90 or SM100." + ) def create_weights( self, @@ -349,6 +397,26 @@ def create_weights( intermediate_size_per_partition_after_pad = round_up( intermediate_size_per_partition, triton_kernels_padding_alignment ) + elif self._fi_kernel == "cutlass_sm90": + # cutlass mixed-input GEMM contraction dim K must be % 128 == 0 + # (interleave factor for MXFP4 group_size=32 is 4). 
The kernel + # also expects ``fc1_expert_weights`` in halved ``[up; gate]`` + # layout, which means the padding boundary must fall on the + # gate / up split. + # + # The mxfp4 weight loader (FusedMoE.weight_loader fast path) does + # a NAIVE copy of HF's ``[2*intermediate_size, hidden_packed]`` + # tensor into the buffer's ``[:dim1, :dim2]`` slice. Padding the + # buffer here would push the gate/up boundary, so HF's "up" + # rows would land in the buffer's "gate" half and vice versa. + # Marlin sidesteps this by not padding; we do the same and + # rebuild a properly-padded buffer in + # ``_process_weights_for_sm90_cutlass`` after the load completes. + self._padded_intermediate = round_up(intermediate_size_per_partition, 128) + self._padded_hidden = round_up(hidden_size, 128) + # create_weights below uses the *unpadded* sizes so the loader's + # naive-copy fast path is correct. + intermediate_size_per_partition_after_pad = intermediate_size_per_partition elif _use_aiter: intermediate_size_per_partition_after_pad = round_up( @@ -438,6 +506,9 @@ def create_weights( set_weight_attrs(w2_weight_bias, extra_weight_attrs) def process_weights_after_loading(self, layer): + if self._fi_kernel == "cutlass_sm90": + self._process_weights_for_sm90_cutlass(layer) + return if self.use_flashinfer: # TODO: these values are hardcoded for now, we need to get them from the model layer.gemm1_alpha = Parameter( @@ -736,6 +807,133 @@ def swap_every_two_rows(x, axis=-1): layer.w2_weight = Parameter(w2_weight.data, requires_grad=False) torch.cuda.empty_cache() + def _process_weights_for_sm90_cutlass(self, layer): + """De-interleave + pad + halving-swap + byte-interleave MXFP4 weights + for FlashInfer's SM90 ``cutlass_fused_moe(use_w4_group_scaling=True)`` + path (PR #3084). 
+ + The cutlass kernel needs (a) K (contraction dim) % 128 == 0, and (b) + ``fc1_expert_weights`` in halved ``[up; gate]`` order -- the + ``compute_with_experts`` reference in FlashInfer's + ``test_trtllm_cutlass_fused_moe.py`` splits + ``w3, w1 = chunk(W, 2, dim=0)`` and uses w3 as up, w1 as gate. + + GPT-OSS's HF layout is *interleaved* ``[g_0, u_0, g_1, u_1, ..., g_{N-1}, u_{N-1}]`` + (each pair occupies two adjacent rows). The mxfp4 weight loader does + a naive copy, so our unpadded buffer is interleaved post-load. We + de-interleave (even rows -> gate, odd rows -> up), pad each half from + N_un to N_pad, concatenate as halved ``[up; gate]``, and then run + FlashInfer's byte / scale interleave helpers. + """ + sf_block_size = 32 # MXFP4 group size + + # Sizes from the unpadded loaded buffers. + N_un = layer.w13_weight.shape[1] // 2 # intermediate (unpadded) + K_un = ( + layer.w13_weight.shape[2] * 2 + ) # hidden (unpadded, *2 because packed 4-bit) + N_pad = self._padded_intermediate + K_pad = self._padded_hidden + # Use the local expert count (matches the existing buffer allocation in + # create_weights) so the SM90 cutlass path remains correct under + # Expert Parallelism. `self.num_experts` is the *global* count. + E = layer.num_local_experts + device = layer.w13_weight.device + bias_dtype = layer.w13_weight_bias.dtype + + # ---- De-interleave + pad w13 weight/scale/bias to halved [up; gate] + # Even rows of HF = gate, odd rows = up. After splitting we pad each + # half along its row dim (N) from N_un to N_pad with zeros, and along + # its last dim (K) from K_un (or K_un / sf_block_size) to K_pad. + + def _stack_up_gate_w13(unpadded_w13, last_pad, last_un): + # unpadded_w13: [E, 2*N_un, last_un] + # Returns: [E, 2*N_pad, last_pad] in [up_padded; gate_padded] order. 
+ gate_rows = unpadded_w13[:, 0::2, :] # [E, N_un, last_un] + up_rows = unpadded_w13[:, 1::2, :] # [E, N_un, last_un] + out = torch.zeros( + E, 2 * N_pad, last_pad, dtype=unpadded_w13.dtype, device=device + ) + # First half: up (with row + col padding zeros). + out[:, :N_un, :last_un] = up_rows + # Second half: gate. + out[:, N_pad : N_pad + N_un, :last_un] = gate_rows + return out + + w13_padded = _stack_up_gate_w13( + layer.w13_weight.data.view(torch.uint8), K_pad // 2, K_un // 2 + ) + w13_scale_padded = _stack_up_gate_w13( + layer.w13_weight_scale.data, + K_pad // sf_block_size, + K_un // sf_block_size, + ) + # Bias: same de-interleave on dim=-1. + w13_bias_gate = layer.w13_weight_bias.data[:, 0::2] # [E, N_un] + w13_bias_up = layer.w13_weight_bias.data[:, 1::2] # [E, N_un] + w13_bias_padded = torch.zeros(E, 2 * N_pad, dtype=bias_dtype, device=device) + w13_bias_padded[:, :N_un] = w13_bias_up + w13_bias_padded[:, N_pad : N_pad + N_un] = w13_bias_gate + + def _pad_w2_3d(unpadded, last_pad, last_un): + out = torch.zeros(E, K_pad, last_pad, dtype=unpadded.dtype, device=device) + out[:, :K_un, :last_un] = unpadded[:, :K_un, :] + return out + + # ---- w2 (no halving, just pad to [E, K_pad, N_pad/2]) ---------------- + w2_padded = _pad_w2_3d( + layer.w2_weight.data.view(torch.uint8), N_pad // 2, N_un // 2 + ) + w2_scale_padded = _pad_w2_3d( + layer.w2_weight_scale.data, + N_pad // sf_block_size, + N_un // sf_block_size, + ) + w2_bias_padded = torch.zeros(E, K_pad, dtype=bias_dtype, device=device) + w2_bias_padded[:, :K_un] = layer.w2_weight_bias.data + + # ---- Per-expert SwiGLU scalars (GPT-OSS defaults) ------------------ + layer.swiglu_alpha = Parameter( + torch.full((E,), 1.702, dtype=torch.float32, device=device), + requires_grad=False, + ) + layer.swiglu_beta = Parameter( + torch.full((E,), 1.0, dtype=torch.float32, device=device), + requires_grad=False, + ) + layer.swiglu_limit = Parameter( + torch.full((E,), 7.0, dtype=torch.float32, device=device), + 
requires_grad=False, + ) + + # ---- FlashInfer SM90 byte / scale interleave ----------------------- + # The padded buffers above are contiguous by construction (allocated + # via torch.zeros + slice assignment), so we feed them straight in. + layer.w13_weight = Parameter( + interleave_moe_weights_for_sm90_mixed_gemm(w13_padded, "fp4"), + requires_grad=False, + ) + layer.w2_weight = Parameter( + interleave_moe_weights_for_sm90_mixed_gemm(w2_padded, "fp4"), + requires_grad=False, + ) + layer.w13_weight_scale = Parameter( + interleave_moe_scales_for_sm90_mixed_gemm( + w13_scale_padded, group_size=sf_block_size + ), + requires_grad=False, + ) + layer.w2_weight_scale = Parameter( + interleave_moe_scales_for_sm90_mixed_gemm( + w2_scale_padded, group_size=sf_block_size + ), + requires_grad=False, + ) + layer.w13_weight_bias = Parameter(w13_bias_padded, requires_grad=False) + layer.w2_weight_bias = Parameter(w2_bias_padded, requires_grad=False) + + torch.cuda.empty_cache() + def create_moe_runner( self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig ): @@ -761,6 +959,74 @@ def create_moe_runner( # TODO(cwan): refactor other backends pass + def _apply_sm90_cutlass(self, layer, x, topk_output): + """SM90 (Hopper) MXFP4 x BF16 MoE via FlashInfer's cutlass mixed-input + path (PR #3084). The fused kernel does GEMM1 + SwiGLU + GEMM2 in one + call; weights/scales were pre-interleaved at load time.""" + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + # Under ``--moe-runner-backend flashinfer_mxfp4`` the SGLang TopK layer + # emits BypassedTopKOutput by default (the SM100 trtllm-gen kernel does + # routing internally). The cutlass kernel needs explicit topk_ids / + # topk_weights, so materialize them here when bypassed. 
+ if TopKOutputChecker.format_is_bypassed(topk_output): + topk_output = topk_output.to_standard() + topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids + + # Pad input hidden dim to the (already-padded) loaded weight width. + origin_hidden = x.shape[-1] + padded_hidden = self._padded_hidden + if padded_hidden != origin_hidden: + x = torch.nn.functional.pad( + x, + (0, padded_hidden - origin_hidden), + mode="constant", + value=0.0, + ) + + output_dtype = torch.bfloat16 + # Output is allocated at padded width (kernel writes padded_hidden + # columns), then trimmed back to origin_hidden before returning. + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + out_padded = torch.empty( + x.shape[0], padded_hidden, dtype=output_dtype, device=x.device + ) + + flashinfer_cutlass_fused_moe( + input=x, + token_selected_experts=topk_ids.to(torch.int), + token_final_scales=topk_weights, + fc1_expert_weights=layer.w13_weight, # uint8 [E, 2*N, K/2] interleaved + fc2_expert_weights=layer.w2_weight, # uint8 [E, K, N/2] interleaved + output_dtype=output_dtype, + quant_scales=[ + layer.w13_weight_scale.view(torch.int32), + layer.w2_weight_scale.view(torch.int32), + ], + fc1_expert_biases=layer.w13_weight_bias, # bf16 [E, 2*N] + fc2_expert_biases=layer.w2_weight_bias, # bf16 [E, K] + swiglu_alpha=layer.swiglu_alpha, + swiglu_beta=layer.swiglu_beta, + swiglu_limit=layer.swiglu_limit, + tp_size=layer.moe_tp_size, + tp_rank=layer.moe_tp_rank, + ep_size=layer.moe_ep_size, + ep_rank=layer.moe_ep_rank, + use_w4_group_scaling=True, + activation_type=ActivationType.Swiglu, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=out_padded, + ) + + if padded_hidden != origin_hidden: + out = out_padded[:, :origin_hidden].contiguous() + else: + out = out_padded + return StandardCombineInput(hidden_states=out) + def apply( self, layer: torch.nn.Module, @@ -773,6 +1039,8 @@ def apply( x = dispatch_output.hidden_states topk_output = 
dispatch_output.topk_output + if self._fi_kernel == "cutlass_sm90": + return self._apply_sm90_cutlass(layer, x, topk_output) if self.use_flashinfer: # When bf16 mode is enabled, we don't need to quantize the input, # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations, diff --git a/python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py b/python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py new file mode 100644 index 000000000000..7fce478e17a3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py @@ -0,0 +1,263 @@ +"""DeepSeek-V4 MXFP4 expert backend backed by FlashInfer's SM90 cutlass +mixed-input MoE GEMM (FlashInfer PR #3084). + +Sibling of :class:`Mxfp4MarlinMoEMethod` and :class:`Mxfp4FlashinferTrtllmMoEMethod`. +Wired into :func:`Fp8MoEConfig.get_quant_method` when +``is_fp4_experts=True`` and ``--moe-runner-backend flashinfer_mxfp4`` is +selected on a Hopper (SM90) device. SM100 still routes to +:class:`Mxfp4FlashinferTrtllmMoEMethod` (trtllm-gen). + +Performance trade-off vs Marlin (kernel-level on H100, GPT-OSS-like body): + - decode (M <= 64) : Marlin +12-15 % + - tie (M ~= 256) + - prefill (M >= 1024) : FlashInfer +24-36 % + +PD-disaggregated prefill workers are the natural fit; decode workers should +keep the Marlin default. 
class Mxfp4FlashinferCutlassMoEMethod:
    """DeepSeek-V4 W4A16 MXFP4 MoE via FlashInfer's SM90 mixed-input cutlass
    grouped GEMM.

    The fused kernel does GEMM1 + SwiGLU (clamped when the runner config
    carries a ``swiglu_limit``) + GEMM2 in one call after a one-shot
    weight/scale interleave at load time. Weight allocation and the first
    stage of post-load processing are delegated to the wrapped fp8 method.
    """

    def __init__(self, fp8_method, prefix: str):
        """Wrap ``fp8_method`` for the layer identified by ``prefix``.

        Raises:
            RuntimeError: If the installed FlashInfer lacks the SM90
                interleave helpers.
        """
        if not _FI_HAS_SM90_CUTLASS_MXFP4:
            raise RuntimeError(
                "Mxfp4FlashinferCutlassMoEMethod requires FlashInfer >= 0.6.11 "
                "(PR #3084 SM90 mixed-input helpers). Older builds lack "
                "interleave_moe_{weights,scales}_for_sm90_mixed_gemm; "
                "either upgrade flashinfer-python or fall back to "
                "--moe-runner-backend marlin."
            )
        self._fp8 = fp8_method
        self.prefix = prefix
        # Populated by ``create_moe_runner``; all three stay ``None`` when the
        # runner config has no ``swiglu_limit``.
        self._swiglu_alpha_tensor: torch.Tensor | None = None
        self._swiglu_beta_tensor: torch.Tensor | None = None
        self._swiglu_limit_tensor: torch.Tensor | None = None

    # --- Lifecycle ---------------------------------------------------------

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype,
        **extra_weight_attrs,
    ):
        """Validate the GEMM dimensions, then let the fp8 base method allocate.

        Raises:
            ValueError: If ``hidden_size`` or
                ``intermediate_size_per_partition`` is not a multiple of 128.
        """
        # SM90 mixed-input GEMM: contraction dim K must be a multiple of 128
        # (interleave factor = 128 / group_size = 4). For DSv4 (hidden=7168,
        # inter=2048) both are already multiples of 128; we raise rather than
        # silently pad here, since padding the FP8-base buffers in-place would
        # require deeper changes.
        if hidden_size % 128 != 0 or intermediate_size_per_partition % 128 != 0:
            raise ValueError(
                "Mxfp4FlashinferCutlassMoEMethod requires hidden_size and "
                "intermediate_size_per_partition to be multiples of 128 "
                f"(got hidden={hidden_size}, "
                f"intermediate={intermediate_size_per_partition})."
            )
        # Raw weight shapes match what the fp8 base method allocates for fp4
        # experts (uint8 4-bit packed weights, fp32 E8M0 scales). Delegate.
        self._fp8.create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def create_moe_runner(self, layer: Module, moe_runner_config) -> None:
        """Store the runner config and build the per-expert SwiGLU tensors.

        When ``moe_runner_config`` has a non-``None`` ``swiglu_limit``, three
        float32 tensors of length ``layer.num_local_experts`` are created on
        the device of ``layer.w13_weight``. Otherwise all three are ``None``.
        """
        self.moe_runner_config = moe_runner_config

        # DSv4 uses standard SwiGLU plus a config-driven activation clamp.
        # We pass all three (alpha, beta, limit) as explicit per-expert tensors
        # rather than mixing tensors with None: the cutlass SwiGLU kernel
        # branches on whether each is None, and partial-None inputs land in
        # less-tested code paths. ``alpha=1.0``, ``beta=0.0`` reproduce plain
        # ``silu(gate) * up``; ``limit`` enforces the activation clamp the
        # checkpoint was trained with.
        swiglu_limit = getattr(moe_runner_config, "swiglu_limit", None)
        if swiglu_limit is None:
            self._swiglu_alpha_tensor = None
            self._swiglu_beta_tensor = None
            self._swiglu_limit_tensor = None
            return

        E = layer.num_local_experts
        device = layer.w13_weight.device
        self._swiglu_alpha_tensor = torch.ones(E, dtype=torch.float32, device=device)
        self._swiglu_beta_tensor = torch.zeros(E, dtype=torch.float32, device=device)
        self._swiglu_limit_tensor = torch.full(
            (E,), float(swiglu_limit), dtype=torch.float32, device=device
        )

    def process_weights_after_loading(self, layer: Module) -> None:
        """Reorder, convert and interleave the loaded MXFP4 weights in place.

        Replaces ``w13_weight`` / ``w2_weight`` and the matching
        ``*_weight_scale_inv`` parameters on ``layer`` with the layouts the
        SM90 cutlass kernel consumes. Returns early, after the fp8 base hook,
        when ``layer._mega_moe_weights_built`` is truthy.
        """
        from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1

        # Run the fp8 base hook first (ROCm normalization, mxfp8 requant, ...).
        self._fp8.process_weights_after_loading(layer)

        if getattr(layer, "_mega_moe_weights_built", False):
            return

        # cutlass_fused_moe expects fc1 in [w3; w1] = [up; gate] order, just
        # like the trtllm-gen path. The HF / FP8 loader emits [w1; w3].
        w13, w13_s = reorder_w1w3_to_w3w1(
            layer.w13_weight.data, layer.w13_weight_scale_inv.data
        )
        layer.w13_weight = Parameter(w13, requires_grad=False)
        layer.w13_weight_scale_inv = Parameter(w13_s, requires_grad=False)

        log_info_on_rank0(
            logger,
            f"Preparing DSv4 MXFP4 experts for FlashInfer SM90 cutlass "
            f"(layer: {self.prefix})...",
        )

        # FP8 base stores scales as fp32 numerical values (= 2**e). The
        # FlashInfer SM90 helper reads raw E8M0 bytes (uint8 with the
        # exponent + 127 bias). Cast through float8_e8m0fnu to extract the
        # raw byte without losing the exponent.
        w13_scale_u8 = (
            layer.w13_weight_scale_inv.data.to(torch.float8_e8m0fnu)
            .view(torch.uint8)
            .contiguous()
        )
        w2_scale_u8 = (
            layer.w2_weight_scale_inv.data.to(torch.float8_e8m0fnu)
            .view(torch.uint8)
            .contiguous()
        )

        # C++ byte interleave on packed 4-bit weights.
        w13_il = interleave_moe_weights_for_sm90_mixed_gemm(
            layer.w13_weight.data.view(torch.uint8).contiguous(), "fp4"
        )
        w2_il = interleave_moe_weights_for_sm90_mixed_gemm(
            layer.w2_weight.data.view(torch.uint8).contiguous(), "fp4"
        )
        # Pure-PyTorch reshape+permute on E8M0 block scales.
        w13_s_il = interleave_moe_scales_for_sm90_mixed_gemm(
            w13_scale_u8, group_size=_GROUP_SIZE
        )
        w2_s_il = interleave_moe_scales_for_sm90_mixed_gemm(
            w2_scale_u8, group_size=_GROUP_SIZE
        )

        layer.w13_weight = Parameter(w13_il, requires_grad=False)
        layer.w2_weight = Parameter(w2_il, requires_grad=False)
        layer.w13_weight_scale_inv = Parameter(w13_s_il, requires_grad=False)
        layer.w2_weight_scale_inv = Parameter(w2_s_il, requires_grad=False)

        layer._dsv4_mxfp4_backend = "flashinfer_cutlass_sm90"
        torch.cuda.empty_cache()

    # --- Forward -----------------------------------------------------------

    def apply(
        self,
        layer: Module,
        dispatch_output: "DispatchOutput",
    ) -> "CombineInput":
        """Run the fused MoE kernel on ``dispatch_output``.

        Bypassed top-k output is first materialized into explicit
        ``topk_ids`` / ``topk_weights``, since the cutlass kernel does no
        routing of its own.

        Raises:
            ValueError: If the top-k output is neither standard nor bypassed.
        """
        topk_output = dispatch_output.topk_output
        # The TopK layer may hand over a bypassed output under the
        # flashinfer_mxfp4 backend; convert it instead of rejecting it.
        if TopKOutputChecker.format_is_bypassed(topk_output):
            topk_output = topk_output.to_standard(
                layer_id=getattr(layer, "layer_id", None)
            )
        if not TopKOutputChecker.format_is_standard(topk_output):
            raise ValueError(f"Unsupported topk output format: {topk_output.format}")

        x = dispatch_output.hidden_states
        topk_weights = topk_output.topk_weights
        topk_ids = topk_output.topk_ids

        output_dtype = torch.bfloat16
        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            out = torch.empty(
                x.shape[0], x.shape[-1], dtype=output_dtype, device=x.device
            )

        flashinfer_cutlass_fused_moe(
            input=x,
            token_selected_experts=topk_ids.to(torch.int),
            token_final_scales=topk_weights,
            fc1_expert_weights=layer.w13_weight,
            fc2_expert_weights=layer.w2_weight,
            output_dtype=output_dtype,
            quant_scales=[
                layer.w13_weight_scale_inv.view(torch.int32),
                layer.w2_weight_scale_inv.view(torch.int32),
            ],
            fc1_expert_biases=None,  # DSv4 has no MoE expert bias.
            fc2_expert_biases=None,
            # Ones / zeros / clamp value when a swiglu_limit is configured;
            # all three are None otherwise (see create_moe_runner).
            swiglu_alpha=self._swiglu_alpha_tensor,
            swiglu_beta=self._swiglu_beta_tensor,
            swiglu_limit=self._swiglu_limit_tensor,
            tp_size=layer.moe_tp_size,
            tp_rank=layer.moe_tp_rank,
            ep_size=layer.moe_ep_size,
            ep_rank=layer.moe_ep_rank,
            use_w4_group_scaling=True,
            activation_type=ActivationType.Swiglu,
            tune_max_num_tokens=next_power_of_2(x.shape[0]),
            output=out,
        )

        return StandardCombineInput(hidden_states=out)
@dataclass
class Shape:
    """One benchmark problem size for the MoE kernels."""

    tokens: int  # number of activation rows
    hidden: int  # activation width
    inter: int  # expert intermediate size
    num_experts: int
    top_k: int

    def label(self) -> str:
        """Return the one-line description printed as the section header."""
        fragments = (
            f"m={self.tokens:>4}",
            f"h={self.hidden}",
            f"i={self.inter}",
            f"E={self.num_experts}",
            f"k={self.top_k}",
        )
        return " ".join(fragments)
Token range covers decode (4-256) and prefill chunks +# (1024-8192). +_BODY = dict(hidden=4096, inter=2048, num_experts=256, top_k=6) +DEFAULT_SHAPES: List[Shape] = [ + Shape(tokens=4, **_BODY), + Shape(tokens=16, **_BODY), + Shape(tokens=64, **_BODY), + Shape(tokens=256, **_BODY), + Shape(tokens=1024, **_BODY), + Shape(tokens=2048, **_BODY), + Shape(tokens=4096, **_BODY), + Shape(tokens=8192, **_BODY), +] + + +def _make_random_mxfp4(shape: Shape, seed: int = 0): + g = torch.Generator(device="cuda").manual_seed(seed) + e = shape.num_experts + n = shape.inter + k = shape.hidden + w13 = torch.randint( + 0, 256, (e, 2 * n, k // 2), dtype=torch.uint8, device="cuda", generator=g + ) + w2 = torch.randint( + 0, 256, (e, k, n // 2), dtype=torch.uint8, device="cuda", generator=g + ) + # Narrow E8M0 band so dequant magnitudes stay sane. + w13_s = torch.randint( + 125, + 130, + (e, 2 * n, k // GROUP_SIZE), + dtype=torch.uint8, + device="cuda", + generator=g, + ) + w2_s = torch.randint( + 125, + 130, + (e, k, n // GROUP_SIZE), + dtype=torch.uint8, + device="cuda", + generator=g, + ) + w13_b = ( + torch.randn(e, 2 * n, dtype=torch.float32, device="cuda", generator=g).to( + torch.bfloat16 + ) + * 0.01 + ) + w2_b = ( + torch.randn(e, k, dtype=torch.float32, device="cuda", generator=g).to( + torch.bfloat16 + ) + * 0.01 + ) + return w13, w2, w13_s, w2_s, w13_b, w2_b + + +def _make_topk(shape: Shape, seed: int = 1): + g = torch.Generator(device="cuda").manual_seed(seed) + logits = torch.randn( + shape.tokens, + shape.num_experts, + dtype=torch.float32, + device="cuda", + generator=g, + ) + weights, ids = torch.topk(torch.softmax(logits, dim=-1), shape.top_k, dim=-1) + weights = weights / weights.sum(dim=-1, keepdim=True) + return logits, weights.to(torch.float32), ids.to(torch.int32) + + +# --------------------------------------------------------------------------- +# FlashInfer cutlass path +# --------------------------------------------------------------------------- + + +def 
def make_flashinfer_runner(
    shape: Shape, prep, x, topk_w, topk_i, autotuned: bool, with_bias: bool = True
):
    """Return a zero-argument callable that launches one FlashInfer MoE call.

    Args:
        shape: Supplies ``tokens`` and ``hidden`` for the output buffer.
        prep: Mapping holding the prepared weights, scales, biases and
            per-expert SwiGLU tensors.
        x: Activations passed to the kernel as ``input``.
        topk_w: Routing weights (``token_final_scales``).
        topk_i: Selected expert ids (``token_selected_experts``).
        autotuned: When true, the call is executed once inside
            ``autotune(True)`` before the callable is returned.
        with_bias: When false, both expert biases are passed as ``None``.
    """
    # One output buffer is shared by every invocation of the returned callable.
    result_buf = torch.empty(
        shape.tokens, shape.hidden, dtype=torch.bfloat16, device="cuda"
    )
    if with_bias:
        fc1_bias = prep["w13_b"]
        fc2_bias = prep["w2_b"]
    else:
        fc1_bias = None
        fc2_bias = None

    def _launch():
        cutlass_fused_moe(
            input=x,
            token_selected_experts=topk_i,
            token_final_scales=topk_w,
            fc1_expert_weights=prep["w13"],
            fc2_expert_weights=prep["w2"],
            output_dtype=torch.bfloat16,
            quant_scales=prep["quant_scales"],
            fc1_expert_biases=fc1_bias,
            fc2_expert_biases=fc2_bias,
            swiglu_alpha=prep["swiglu_alpha"],
            swiglu_beta=prep["swiglu_beta"],
            swiglu_limit=prep["swiglu_limit"],
            use_w4_group_scaling=True,
            activation_type=ActivationType.Swiglu,
            output=result_buf,
        )

    if not autotuned:
        return _launch

    # Populate FlashInfer's tactic cache once before timing.
    with autotune(True):
        _launch()
    return _launch
def run_one_shape(shape: Shape, run_marlin: bool):
    """Benchmark one MoE problem shape and print the timings.

    Three configurations are timed with ``time_call``:

    1. FlashInfer cutlass, autotuned, with expert biases.
    2. FlashInfer cutlass, autotuned, without biases (isolates the bias
       epilogue cost).
    3. SGLang Marlin, only when ``run_marlin`` is true. Any exception raised
       while preparing or timing Marlin is reported as ``SKIPPED``.

    Args:
        shape: Problem shape; ``label()``, ``tokens`` and ``hidden`` are read
            here and the object is forwarded to the input builders.
        run_marlin: Whether to also time the Marlin path.
    """
    print(f"\n=== {shape.label()} ===")
    w13, w2, w13_s, w2_s, w13_b, w2_b = _make_random_mxfp4(shape, seed=0)
    router_logits, topk_w, topk_i = _make_topk(shape, seed=1)
    x = (
        torch.randn(shape.tokens, shape.hidden, dtype=torch.bfloat16, device="cuda")
        * 0.1
    )

    # FlashInfer cutlass (autotune ON, with bias).
    fi_prep = build_flashinfer_inputs(shape, w13, w2, w13_s, w2_s, w13_b, w2_b)
    fi_at_call = make_flashinfer_runner(
        shape, fi_prep, x, topk_w, topk_i, autotuned=True, with_bias=True
    )
    fi_at_med, fi_at_min = time_call(fi_at_call)
    print(
        f" FlashInfer cutlass (autotune): median={fi_at_med:.3f} ms "
        f"min={fi_at_min:.3f} ms"
    )

    # FlashInfer cutlass (autotune ON, no bias) -- isolate bias epilogue cost.
    fi_at_nb_call = make_flashinfer_runner(
        shape, fi_prep, x, topk_w, topk_i, autotuned=True, with_bias=False
    )
    fi_at_nb_med, fi_at_nb_min = time_call(fi_at_nb_call)
    print(
        f" FlashInfer cutlass (AT, no-bias): median={fi_at_nb_med:.3f} ms "
        f"min={fi_at_nb_min:.3f} ms "
        f"(bias overhead = {fi_at_med - fi_at_nb_med:+.3f} ms / "
        f"{(fi_at_med / fi_at_nb_med - 1) * 100:+.1f}%)"
    )

    if not run_marlin:
        return

    # Marlin is best-effort: a failure here must not abort the remaining
    # shapes, so the error is reported and swallowed.
    try:
        ml_prep = build_marlin_inputs(shape, w13, w2, w13_s, w2_s)
        ml_call = make_marlin_runner(shape, ml_prep, x, router_logits, topk_w, topk_i)
        ml_med, ml_min = time_call(ml_call)
        print(
            f" SGLang Marlin: median={ml_med:.3f} ms "
            f"min={ml_min:.3f} ms"
        )
        print(f" speedup (Marlin / FI autotune): {ml_med / fi_at_med:.2f}x")
        print(f" speedup (Marlin / FI AT no-bias): {ml_med / fi_at_nb_med:.2f}x")
    except Exception as exc:  # pylint: disable=broad-except
        print(f" SGLang Marlin: SKIPPED ({type(exc).__name__}: {exc})")
@unittest.skipUnless(
    _flashinfer_has_sm90_cutlass_mxfp4(),
    "FlashInfer build lacks SM90 mixed-input MXFP4 helpers (PR #3084, >= 0.6.11)",
)
class TestDSV4FlashFP4H200FlashInferCutlass(ServerSanityMixin, CustomTestCase):
    """End-to-end check of the ``flashinfer_mxfp4`` MoE backend with TP=4 + EAGLE.

    Launches a server for ``MODEL`` with ``--moe-runner-backend
    flashinfer_mxfp4`` and EAGLE speculative decoding, then requires a GSM8K
    score above 0.93. The whole class is skipped when the installed FlashInfer
    does not export ``interleave_moe_weights_for_sm90_mixed_gemm``.
    """

    @classmethod
    def setUpClass(cls):
        """Start one server that every test in the class shares."""
        cls.model = try_cached_model(MODEL)
        cls.base_url = DEFAULT_URL_FOR_TEST

        parallel_flags = ["--trust-remote-code", "--tp", "4"]
        backend_flags = ["--moe-runner-backend", "flashinfer_mxfp4"]
        eagle_flags = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-num-steps",
            "3",
            "--speculative-eagle-topk",
            "1",
            "--speculative-num-draft-tokens",
            "4",
        ]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=SERVER_LAUNCH_TIMEOUT,
            other_args=parallel_flags + backend_flags + eagle_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process tree if the launch got far enough to set it."""
        server = getattr(cls, "process", None)
        if server:
            kill_process_tree(server.pid)

    def test_gsm8k(self):
        """Run 200 GSM8K examples and assert the accuracy floor."""
        eval_args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="gsm8k",
            api="completion",
            max_tokens=512,
            num_examples=200,
            num_threads=128,
        )
        # The local must stay named ``metrics``: the ``{metrics=}`` specifier
        # below prints the variable name as part of the output.
        metrics = run_eval(eval_args)
        print(f"[DSV4 Flash FP4 FlashInfer Cutlass H200] GSM8K {metrics=}")
        self.assertGreater(metrics["score"], 0.93)
+ +Run on H100/H200: + + python -m pytest test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py -v +""" + +from __future__ import annotations + +from contextlib import nullcontext + +import pytest +import torch + +from sglang.test.ci.ci_register import register_cuda_ci + +register_cuda_ci(est_time=120, suite="stage-b-test-1-gpu-large") + +flashinfer_fused_moe = pytest.importorskip("flashinfer.fused_moe") + +if not hasattr(flashinfer_fused_moe, "interleave_moe_weights_for_sm90_mixed_gemm"): + pytest.skip( + "FlashInfer build does not include PR #3084 SM90 mixed-input helpers", + allow_module_level=True, + ) + +if not torch.cuda.is_available(): + pytest.skip("CUDA required", allow_module_level=True) + +from sglang.srt.utils import is_sm90_supported, is_sm100_supported + +if not is_sm90_supported() or is_sm100_supported(): + pytest.skip( + "SM90-only path; require Hopper without SM100 promotion", + allow_module_level=True, + ) + +from flashinfer.fused_moe import ( + cutlass_fused_moe, + interleave_moe_scales_for_sm90_mixed_gemm, + interleave_moe_weights_for_sm90_mixed_gemm, +) +from flashinfer.fused_moe.core import ActivationType + +GROUP_SIZE = 32 # MXFP4 block size + + +class _MockLayer: + """Stand-in for ``FusedMoE`` carrying the attributes the SM90 helpers read. + + We construct one by hand so the test stays out of SGLang's distributed init + path (``get_tp_group`` etc.). 
def _make_topk(tokens, num_experts, top_k, seed=1, device="cuda"):
    """Generate deterministic, renormalized top-k routing for a fake MoE batch.

    Random logits are drawn from a dedicated generator seeded with ``seed``,
    turned into probabilities with a softmax over the expert axis, and the
    ``top_k`` largest entries per token are kept. The kept weights are then
    rescaled so that each token's weights sum to 1.

    Args:
        tokens: Number of rows (tokens) to route.
        num_experts: Number of experts to choose from.
        top_k: Experts selected per token.
        seed: Seed for the private generator; the global RNG state is not used.
        device: Device for the generator and all tensors. Defaults to
            ``"cuda"``, which is what existing callers rely on.

    Returns:
        ``(weights, ids)``: float32 weights and int32 expert ids, both of shape
        ``(tokens, top_k)``.
    """
    gen = torch.Generator(device=device).manual_seed(seed)
    logits = torch.randn(
        tokens, num_experts, dtype=torch.float32, device=device, generator=gen
    )
    probs = torch.softmax(logits, dim=-1)
    weights, ids = torch.topk(probs, top_k, dim=-1)
    # float32 / float32 stays float32, so no extra cast is needed on weights.
    weights = weights / weights.sum(dim=-1, keepdim=True)
    return weights, ids.to(torch.int32)
def _round_up(x, base):
    """Round ``x`` up to the nearest multiple of ``base`` (integer arithmetic).

    For a positive ``base`` this is the smallest multiple of ``base`` that is
    greater than or equal to ``x``.
    """
    # a - a % b equals (a // b) * b for Python integers.
    bumped = x + base - 1
    return bumped - bumped % base
def _expected_w2_processed(w2_un, w2_s_un, w2_b_un, N_pad, K_pad, group_size):
    """Reference processing for w2: zero-pad, then apply the SM90 interleaves.

    Unlike w13, no row de-interleaving or halving is involved. Weights, scales
    and bias are copied into the top-left corner of zero-filled buffers whose
    row dimension is ``K_pad``; the weight and scale buffers are then passed
    through the FlashInfer interleave helpers.

    Returns:
        ``(interleaved_weights, interleaved_scales, padded_bias)``.
    """
    num_experts, rows_un, packed_cols_un = w2_un.shape
    N_un = packed_cols_un * 2  # two 4-bit values per packed byte

    def _zero_extend(src, cols_pad, cols_un):
        # Zero buffer with src's dtype/device; src fills the leading block.
        buf = src.new_zeros((num_experts, K_pad, cols_pad))
        buf[:, :rows_un, :cols_un].copy_(src)
        return buf

    weights_padded = _zero_extend(w2_un.view(torch.uint8), N_pad // 2, N_un // 2)
    scales_padded = _zero_extend(w2_s_un, N_pad // group_size, N_un // group_size)

    bias_padded = w2_b_un.new_zeros((num_experts, K_pad))
    bias_padded[:, :rows_un].copy_(w2_b_un)

    weights_il = interleave_moe_weights_for_sm90_mixed_gemm(weights_padded, "fp4")
    scales_il = interleave_moe_scales_for_sm90_mixed_gemm(
        scales_padded, group_size=group_size
    )
    return weights_il, scales_il, bias_padded
@pytest.mark.parametrize(
    "tokens,num_experts,hidden,inter,top_k",
    [
        # Aligned shapes: no input padding, no output trim.
        (4, 4, 256, 256, 2),
        (16, 8, 768, 384, 2),
        (32, 8, 1024, 1024, 4),
        # Non-aligned shape: exercises input padding and output trimming.
        (8, 4, 192, 192, 2),
    ],
)
def test_apply_sm90_cutlass_matches_flashinfer_direct(
    tokens, num_experts, hidden, inter, top_k, monkeypatch
):
    """``_apply_sm90_cutlass`` must equal a direct ``cutlass_fused_moe`` call.

    Both sides consume the weights, scales and biases produced by
    ``_process_weights_for_sm90_cutlass`` on the same mock layer, so this test
    checks only how ``apply`` invokes the kernel: argument wiring, padding of
    the input to ``method._padded_hidden`` and trimming of the output back to
    ``hidden``. Equality is bit-exact (``torch.equal``).
    """
    import sglang.srt.layers.quantization.mxfp4 as mxfp4_mod

    # Symmetric-memory and TP-group plumbing does not affect numerics; replace
    # it with inert stand-ins so no distributed init is required.
    stubs = (
        ("use_symmetric_memory", lambda *a, **kw: nullcontext()),
        ("is_allocation_symmetric", lambda: False),
        ("get_tp_group", lambda: None),
    )
    for attr_name, stub in stubs:
        monkeypatch.setattr(mxfp4_mod, attr_name, stub)

    raw_tensors = _make_random_mxfp4(num_experts, hidden, inter)
    x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda") * 0.1
    topk_w, topk_i = _make_topk(tokens, num_experts, top_k)

    # ---- SGLang path ----
    layer = _build_mock_layer(num_experts, hidden, inter, *raw_tensors)
    method = _build_method(num_experts, hidden, inter)
    method._process_weights_for_sm90_cutlass(layer)

    routed = _MockTopKOutput(topk_w, topk_i)
    out_sglang = method._apply_sm90_cutlass(layer, x.clone(), routed).hidden_states

    # ---- FlashInfer-direct reference using the same processed weights ----
    padded_hidden = method._padded_hidden
    extra_cols = padded_hidden - hidden  # non-zero exactly when padding applies
    ref_input = x.clone()
    if extra_cols:
        ref_input = torch.nn.functional.pad(
            ref_input, (0, extra_cols), mode="constant", value=0.0
        )

    ref_buf = torch.empty(tokens, padded_hidden, dtype=torch.bfloat16, device="cuda")
    cutlass_fused_moe(
        input=ref_input,
        token_selected_experts=topk_i.to(torch.int),
        token_final_scales=topk_w,
        fc1_expert_weights=layer.w13_weight,
        fc2_expert_weights=layer.w2_weight,
        output_dtype=torch.bfloat16,
        quant_scales=[
            layer.w13_weight_scale.view(torch.int32),
            layer.w2_weight_scale.view(torch.int32),
        ],
        fc1_expert_biases=layer.w13_weight_bias,
        fc2_expert_biases=layer.w2_weight_bias,
        swiglu_alpha=layer.swiglu_alpha,
        swiglu_beta=layer.swiglu_beta,
        swiglu_limit=layer.swiglu_limit,
        use_w4_group_scaling=True,
        activation_type=ActivationType.Swiglu,
        output=ref_buf,
    )
    out_ref = ref_buf[:, :hidden].contiguous() if extra_cols else ref_buf

    assert torch.equal(out_sglang, out_ref), (
        f"SGLang vs FlashInfer-direct mismatch; "
        f"max abs diff = {(out_sglang.float() - out_ref.float()).abs().max().item():.4g}"
    )
@pytest.mark.parametrize(
    "tokens,num_experts,hidden,inter,top_k",
    [
        (4, 4, 256, 256, 2),
        (16, 8, 768, 384, 2),
        (256, 8, 1024, 1024, 4),
    ],
)
def test_dsv4_apply_matches_flashinfer_direct(
    tokens, num_experts, hidden, inter, top_k, monkeypatch
):
    """Check ``Mxfp4FlashinferCutlassMoEMethod.apply`` against FlashInfer directly.

    The SGLang side runs ``process_weights_after_loading`` followed by
    ``apply`` on a hand-built method and mock layer. The reference side applies
    the w1/w3 reorder, the fp32 -> E8M0 scale cast and the SM90 interleave
    helpers manually, then calls ``cutlass_fused_moe`` with no biases and no
    SwiGLU scalars. The two outputs must be bit-identical (``torch.equal``).
    """
    from types import SimpleNamespace

    import sglang.srt.layers.quantization.mxfp4_flashinfer_cutlass_moe as ds_mod
    from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1

    # Bypass symmetric-memory / TP-group stack -- not relevant to numerics.
    monkeypatch.setattr(ds_mod, "use_symmetric_memory", lambda *a, **kw: nullcontext())
    monkeypatch.setattr(ds_mod, "is_allocation_symmetric", lambda: False)
    monkeypatch.setattr(ds_mod, "get_tp_group", lambda: None)

    w13, w2, w13_s, w2_s = _make_random_dsv4_mxfp4(num_experts, hidden, inter)
    x = torch.randn(tokens, hidden, dtype=torch.bfloat16, device="cuda") * 0.1
    topk_w, topk_i = _make_topk(tokens, num_experts, top_k)

    # ---- SGLang DSv4 path ----
    # ``__new__`` skips ``__init__``; the attributes set below are assigned by
    # hand instead.
    method = ds_mod.Mxfp4FlashinferCutlassMoEMethod.__new__(
        ds_mod.Mxfp4FlashinferCutlassMoEMethod
    )
    # No-op stand-in for the wrapped method's weight post-processing
    # (presumably the FP8 base method -- confirm against the class).
    method._fp8 = SimpleNamespace(
        process_weights_after_loading=lambda layer: None,
    )
    method.prefix = "test"
    # Plain SiLU * up: all three SwiGLU scalars are None (no clamp configured).
    method._swiglu_alpha_tensor = None
    method._swiglu_beta_tensor = None
    method._swiglu_limit_tensor = None

    # The layer gets clones so the reference path below can reuse the pristine
    # w13 / w2 / scale tensors.
    layer = _MockLayer()
    layer.w13_weight = torch.nn.Parameter(w13.clone(), requires_grad=False)
    layer.w2_weight = torch.nn.Parameter(w2.clone(), requires_grad=False)
    layer.w13_weight_scale_inv = torch.nn.Parameter(w13_s.clone(), requires_grad=False)
    layer.w2_weight_scale_inv = torch.nn.Parameter(w2_s.clone(), requires_grad=False)
    layer.num_local_experts = num_experts
    layer.moe_tp_size = 1
    layer.moe_tp_rank = 0
    layer.moe_ep_size = 1
    layer.moe_ep_rank = 0

    method.process_weights_after_loading(layer)

    out_sglang = method.apply(
        layer, _MockDispatchOutput(x.clone(), topk_w, topk_i)
    ).hidden_states

    # ---- Direct FlashInfer reference ----
    # Reorder w13 weights and scales together, cast the fp32 scales to E8M0 and
    # reinterpret them as raw bytes, then run the SM90 interleave helpers.
    w13_re, w13_s_re = reorder_w1w3_to_w3w1(w13, w13_s)
    w13_s_u8 = w13_s_re.to(torch.float8_e8m0fnu).view(torch.uint8).contiguous()
    w2_s_u8 = w2_s.to(torch.float8_e8m0fnu).view(torch.uint8).contiguous()
    ref_w13 = interleave_moe_weights_for_sm90_mixed_gemm(
        w13_re.view(torch.uint8).contiguous(), "fp4"
    )
    ref_w2 = interleave_moe_weights_for_sm90_mixed_gemm(
        w2.view(torch.uint8).contiguous(), "fp4"
    )
    ref_w13_s = interleave_moe_scales_for_sm90_mixed_gemm(
        w13_s_u8, group_size=GROUP_SIZE
    )
    ref_w2_s = interleave_moe_scales_for_sm90_mixed_gemm(w2_s_u8, group_size=GROUP_SIZE)

    out_ref = torch.empty(tokens, hidden, dtype=torch.bfloat16, device="cuda")
    cutlass_fused_moe(
        input=x.clone(),
        token_selected_experts=topk_i,
        token_final_scales=topk_w,
        fc1_expert_weights=ref_w13,
        fc2_expert_weights=ref_w2,
        output_dtype=torch.bfloat16,
        quant_scales=[ref_w13_s.view(torch.int32), ref_w2_s.view(torch.int32)],
        fc1_expert_biases=None,
        fc2_expert_biases=None,
        swiglu_alpha=None,
        swiglu_beta=None,
        swiglu_limit=None,
        use_w4_group_scaling=True,
        activation_type=ActivationType.Swiglu,
        output=out_ref,
    )

    # The max-abs-diff message is only evaluated when the equality check fails.
    assert torch.equal(out_sglang, out_ref), (
        f"DSv4 SGLang vs FlashInfer-direct mismatch; "
        f"max abs diff = "
        f"{(out_sglang.float() - out_ref.float()).abs().max().item():.4g}"
    )