diff --git a/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py b/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py
new file mode 100644
index 000000000000..6f01e7fb1b33
--- /dev/null
+++ b/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py
@@ -0,0 +1,143 @@
+"""
+Pure-torch MoE ref for W4A16 acc investigation.
+
+The body of ``torch_ref_cutlass_fused_moe`` adapts
+``_compute_with_active_experts`` from flashinfer-sunrise PR #3084 at
+ tests/moe/test_trtllm_cutlass_fused_moe.py
+ (commit 77746b81, lines 2458-2491)
+into a drop-in replacement for ``flashinfer.fused_moe.cutlass_fused_moe``:
+identical signature, identical in-place output semantics. Only the subset of
+kwargs actually needed to reproduce DSv4 W4A16 numerics is consumed; the rest
+are accepted and ignored.
+
+Expectation on weights: both ``fc1_expert_weights`` and ``fc2_expert_weights``
+are already bf16 (caller dequanted FP4+UE8M0 up front — see
+``DeepSeekW4A16MoEMethod.process_weights_after_loading`` under
+``SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF=1``). ``quant_scales`` /
+``use_w4_group_scaling`` are accepted for signature parity but unused.
+
+Activation: we reproduce the kernel's behavior when only ``swiglu_limit`` is
+passed (no ``swiglu_alpha``/``swiglu_beta``) — SiLU on the clamped gate,
+symmetric clamp on the up. This matches the sglang ``w4a16_deepseek.py``
+apply() kernel call.
+"""
+from __future__ import annotations
+
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+
+def torch_ref_cutlass_fused_moe(
+ input: torch.Tensor,
+ token_selected_experts: torch.Tensor,
+ token_final_scales: torch.Tensor,
+ fc1_expert_weights: torch.Tensor,
+ fc2_expert_weights: torch.Tensor,
+ output_dtype: torch.dtype,
+ quant_scales: Optional[List[torch.Tensor]] = None,
+ fc1_expert_biases: Optional[torch.Tensor] = None,
+ fc2_expert_biases: Optional[torch.Tensor] = None,
+ input_sf: Optional[torch.Tensor] = None,
+ swiglu_alpha: Optional[torch.Tensor] = None,
+ swiglu_beta: Optional[torch.Tensor] = None,
+ swiglu_limit: Optional[torch.Tensor] = None,
+ tp_size: int = 1,
+ tp_rank: int = 0,
+ ep_size: int = 1,
+ ep_rank: int = 0,
+ cluster_size: int = 1,
+ cluster_rank: int = 0,
+ output: Optional[torch.Tensor] = None,
+ enable_alltoall: bool = False,
+ use_deepseek_fp8_block_scale: bool = False,
+ use_w4_group_scaling: bool = False,
+ use_mxfp8_act_scaling: bool = False,
+ min_latency_mode: bool = False,
+ use_packed_weights: bool = False,
+ tune_max_num_tokens: int = 8192,
+ enable_pdl: Optional[bool] = None,
+ activation_type=None,
+ swizzled_input_sf: bool = True,
+) -> torch.Tensor:
+ """Pure-torch drop-in for flashinfer ``cutlass_fused_moe`` (W4A16 path).
+
+ Consumed args: ``input``, ``token_selected_experts``,
+ ``token_final_scales``, ``fc1_expert_weights``, ``fc2_expert_weights``
+ (both bf16), ``output_dtype``, ``swiglu_limit``, ``ep_size``, ``ep_rank``,
+ ``output``. Everything else is ignored.
+ """
+ del (
+ quant_scales,
+ fc1_expert_biases,
+ fc2_expert_biases,
+ input_sf,
+ swiglu_alpha,
+ swiglu_beta,
+ tp_size,
+ tp_rank,
+ cluster_size,
+ cluster_rank,
+ enable_alltoall,
+ use_deepseek_fp8_block_scale,
+ use_w4_group_scaling,
+ use_mxfp8_act_scaling,
+ min_latency_mode,
+ use_packed_weights,
+ tune_max_num_tokens,
+ enable_pdl,
+ activation_type,
+ swizzled_input_sf,
+ )
+
+ assert fc1_expert_weights.dtype == torch.bfloat16, (
+ f"torch-ref expects bf16 weights, got {fc1_expert_weights.dtype}"
+ )
+ assert fc2_expert_weights.dtype == torch.bfloat16, (
+ f"torch-ref expects bf16 weights, got {fc2_expert_weights.dtype}"
+ )
+
+ num_tokens = input.shape[0]
+ hidden = fc2_expert_weights.shape[1]
+ num_local_experts = fc1_expert_weights.shape[0]
+ local_expert_offset = ep_rank * num_local_experts
+
+ if output is None:
+ output = torch.empty(
+ num_tokens, hidden, dtype=output_dtype, device=input.device
+ )
+ output.zero_()
+
+ topk_ids_local = token_selected_experts.long() - local_expert_offset
+ in_range = (topk_ids_local >= 0) & (topk_ids_local < num_local_experts)
+ if not in_range.any():
+ return output
+
+ active_local = torch.unique(topk_ids_local[in_range])
+ for eid_local in active_local.tolist():
+ mask = (topk_ids_local == eid_local) & in_range
+ tok_idx, nth = torch.where(mask)
+ if tok_idx.numel() == 0:
+ continue
+
+ w31 = fc1_expert_weights[eid_local]
+ w3, w1 = torch.chunk(w31, 2, dim=0)
+ w2 = fc2_expert_weights[eid_local]
+
+ expert_in = input[tok_idx]
+ x1 = expert_in @ w1.t()
+ x3 = expert_in @ w3.t()
+
+ if swiglu_limit is not None:
+ limit = swiglu_limit[eid_local]
+ x1 = x1.clamp(max=limit)
+ x3 = x3.clamp(min=-limit, max=limit)
+
+ inter = F.silu(x1) * x3
+ out = inter @ w2.t()
+
+ weight = token_final_scales[tok_idx, nth, None].to(out.dtype)
+ output.index_add_(0, tok_idx, (weight * out).to(output.dtype))
+
+ return output
diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py
index 4277aad31c54..ebe7be599e93 100644
--- a/python/sglang/srt/environ.py
+++ b/python/sglang/srt/environ.py
@@ -498,6 +498,9 @@ class Envs:
SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD = EnvBool(True)
SGLANG_OPT_MXFP4_STATIC_SCALE_ONES = EnvBool(True)
SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING = EnvBool(True)
+ SGLANG_HACK_DEBUG_W4A16_REMOVE_SWIGLU_LIMIT = EnvBool(False)
+ SGLANG_HACK_DEBUG_W4A16_USE_BF16_API = EnvBool(False)
+ SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF = EnvBool(False)
SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False)
SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(True)
SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE = EnvInt(-1)
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
index a47534ce7f4a..f47e522a8453 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -90,10 +90,20 @@ def __init__(self, moe_runner_config: MoeRunnerConfig):
self.enable_flashinfer_mxfp4_moe = (
get_moe_runner_backend().is_flashinfer_mxfp4()
)
+ self.enable_flashinfer_w4a16_moe = (
+ get_moe_runner_backend().is_flashinfer_w4a16()
+ )
self.skip_local_expert_mapping = (
self.enable_flashinfer_mxfp4_moe
and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get()
)
+ # Read env once at init; per-layer dispatcher instance is long-lived.
+ # w4a16 uses the same flashinfer-style "kernel expects global ids +
+ # ep_rank/ep_size for local filtering" contract as mxfp4, so the
+ # dispatcher must be a passthrough for both backends when the flag is on.
+ self.skip_local_expert_mapping = (
+ self.enable_flashinfer_mxfp4_moe or self.enable_flashinfer_w4a16_moe
+ ) and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get()
self.num_experts = moe_runner_config.num_experts
self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts
self.num_local_routed_experts = (
diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py
index ba6ca01ff140..d119d90f7335 100644
--- a/python/sglang/srt/layers/moe/utils.py
+++ b/python/sglang/srt/layers/moe/utils.py
@@ -60,6 +60,7 @@ class MoeRunnerBackend(Enum):
FLASHINFER_TRTLLM = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+ FLASHINFER_W4A16 = "flashinfer_w4a16"
FLASHINFER_CUTEDSL = "flashinfer_cutedsl"
CUTLASS = "cutlass"
MARLIN = "marlin"
@@ -88,6 +89,9 @@ def is_flashinfer_cutedsl(self):
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
+ def is_flashinfer_w4a16(self):
+ return self == MoeRunnerBackend.FLASHINFER_W4A16
+
def is_cutlass(self):
return self == MoeRunnerBackend.CUTLASS
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index ed74b1d47e52..1371742bc5f3 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -196,6 +196,16 @@ def get_quant_method(
)
return DeepSeekMxfp4MoEMethod(fp8_method, prefix=prefix)
+ if (
+ envs.SGLANG_DSV4_MODE.get() == "2604"
+ and envs.SGLANG_DSV4_FP4_EXPERTS.get()
+ and get_moe_runner_backend().is_flashinfer_w4a16()
+ ):
+ from sglang.srt.layers.quantization.w4a16_deepseek import (
+ DeepSeekW4A16MoEMethod,
+ )
+
+ return DeepSeekW4A16MoEMethod(fp8_method, prefix=prefix)
return fp8_method
elif isinstance(layer, RadixAttention):
return Fp8KVCacheMethod(self)
diff --git a/python/sglang/srt/layers/quantization/w4a16_deepseek.py b/python/sglang/srt/layers/quantization/w4a16_deepseek.py
new file mode 100644
index 000000000000..6c3517ccd782
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/w4a16_deepseek.py
@@ -0,0 +1,451 @@
+"""
+DeepSeek W4A16 MoE quantization method (SM90 / H200).
+
+Wraps Fp8MoEMethod to reuse the FP4-expert weight creation/loading, then
+overrides process_weights_after_loading to pre-interleave FP4 weights and
+MXFP4 block scales for the SM90 mixed-input CUTLASS kernel exposed by
+flashinfer-ai/flashinfer PR #3084, and overrides apply to call
+cutlass_fused_moe(..., use_w4_group_scaling=True) directly.
+
+This is the H200 counterpart to mxfp4_deepseek.py. The two share the same
+DSv4 FP4 checkpoint (SGLANG_DSV4_MODE=2604 SGLANG_DSV4_FP4_EXPERTS=1): weight
+shapes and dtypes are identical; only the post-load layout and the kernel
+call differ.
+
+Usage: --moe-runner-backend flashinfer_w4a16 --moe-a2a-backend none
+ with SGLANG_DSV4_MODE=2604 SGLANG_DSV4_FP4_EXPERTS=1
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+)
+from sglang.srt.layers.dp_attention import is_allocation_symmetric
+from sglang.srt.server_args import get_global_server_args
+from sglang.srt.utils import (
+ is_flashinfer_available,
+ log_info_on_rank0,
+ set_weight_attrs,
+)
+from sglang.srt.utils.common import next_power_of_2
+
+if is_flashinfer_available():
+ from flashinfer.fused_moe import (
+ cutlass_fused_moe,
+ interleave_moe_scales_for_sm90_mixed_gemm,
+ interleave_moe_weights_for_sm90_mixed_gemm,
+ )
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput
+
+
+from sglang.srt.debug_utils.sunrise_debug_utils import sunrise_moe_code_path_checker
+from sglang.srt.environ import envs
+
+
+def _fp32_to_ue8m0(x: torch.Tensor) -> torch.Tensor:
+ """Convert float32 → UE8M0 (float8_e8m0fnu) and assert lossless round-trip.
+
+ UE8M0 stores only the 8-bit biased exponent (no mantissa), so only exact
+ powers of 2 round-trip bit-exactly. DSv4 MXFP4 block scales should already
+ be powers of 2 per the MXFP4 spec; if this assert fires, the checkpoint
+ isn't actually MXFP4-quantized and we must feed the kernel the raw E8M0
+ bytes instead of round-tripping through fp32.
+ """
+ assert x.dtype == torch.float32, f"expected float32 input, got {x.dtype}"
+ ans = x.to(torch.float8_e8m0fnu)
+ rt = ans.float()
+ mismatch = rt != x
+ if mismatch.any():
+ bad_orig = x[mismatch][:5].tolist()
+ bad_rt = rt[mismatch][:5].tolist()
+ raise AssertionError(
+ f"fp32→UE8M0 lossy: {int(mismatch.sum())}/{x.numel()} elements "
+ f"changed; min/max input = {x.min().item()}/{x.max().item()}; "
+ f"first 5 (orig → round-trip): {list(zip(bad_orig, bad_rt))}"
+ )
+ return ans
+
+
+# MXFP4 4-bit codebook and dequant helper, copied verbatim (module-level
+# constant `_MXFP4_LUT` + body of `_dequant_mxfp4_on_device`) from flashinfer
+# PR #3084 branch at
+# flashinfer-sunrise/tests/moe/test_trtllm_cutlass_fused_moe.py
+# (commit 77746b81, lines 2419-2452)
+# so the bf16 API path sees weights bit-equivalent to what the SM90
+# mixed-input kernel dequants inside itself.
+_MXFP4_LUT = (
+ 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+ -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
+)
+
+
+def _dequant_mxfp4(
+ w_fp4_u8: torch.Tensor, w_scale_ue8m0_u8: torch.Tensor
+) -> torch.Tensor:
+ """[E, rows, K/2] uint8 FP4 packed + [E, rows, K/32] uint8 UE8M0 → [E, rows, K] bf16."""
+ lut = torch.tensor(_MXFP4_LUT, dtype=torch.float32, device=w_fp4_u8.device)
+ lo = w_fp4_u8 & 0x0F
+ hi = (w_fp4_u8 >> 4) & 0x0F
+ nib = torch.stack([lo, hi], dim=-1).reshape(*w_fp4_u8.shape[:-1], -1)
+ values = lut[nib.long()]
+ scale = torch.exp2(w_scale_ue8m0_u8.to(torch.float32) - 127.0)
+ scale = scale.repeat_interleave(32, dim=-1)
+ return (values * scale).to(torch.bfloat16)
+
+
+class DeepSeekW4A16MoEMethod:
+ """W4A16 MoE method for DeepSeek-family models with FP4 expert weights on SM90.
+
+ Wraps Fp8MoEMethod for weight creation/loading, but overrides
+ post-loading processing to pre-interleave FP4 weights and MXFP4 scales
+ for the flashinfer SM90 mixed-input CUTLASS kernel, and directly calls
+ cutlass_fused_moe(..., use_w4_group_scaling=True) in apply().
+ """
+
+ def __init__(self, fp8_method, prefix: str):
+ self._fp8 = fp8_method
+ self.prefix = prefix
+ # Kept for parity with mxfp4_deepseek; unused by cutlass_fused_moe.
+ self.flashinfer_mxfp4_moe_precision = (
+ get_global_server_args().flashinfer_mxfp4_moe_precision
+ )
+
+ def create_moe_runner(self, layer, moe_runner_config):
+ self.moe_runner_config = moe_runner_config
+
+ # Sanity check: v5 (260415) ckpt's HF config has swiglu_limit=10.0;
+ # v4 (260409) does not. Same check as mxfp4_deepseek.
+ swiglu_limit = moe_runner_config.swiglu_limit
+ is_260415 = envs.SGLANG_DSV4_2604_SUBMODE.get() == "260415"
+ assert is_260415 == (swiglu_limit is not None), (
+ f"swiglu_limit must be non-None iff submode=260415 "
+ f"(got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, "
+ f"swiglu_limit={swiglu_limit!r})"
+ )
+ self._swiglu_limit_tensor = (
+ torch.full(
+ (layer.num_local_experts,),
+ swiglu_limit,
+ dtype=torch.float32,
+ device=layer.w13_weight.device,
+ )
+ if swiglu_limit is not None
+ else None
+ )
+
+ def create_weights(
+ self,
+ layer,
+ num_experts: int,
+ hidden_size: int,
+ intermediate_size_per_partition: int,
+ params_dtype,
+ **extra_weight_attrs,
+ ):
+ """Create FP4-packed weights with TP-aware shapes.
+
+ Shapes and dtypes are identical to mxfp4_deepseek (same checkpoint);
+ the only difference is the post-load layout produced by
+ process_weights_after_loading.
+ """
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+ fp4_block_k = 32
+
+ # FP4 packed weights: 2 values per byte -> physical K = logical K / 2
+ w13_weight = Parameter(
+ torch.empty(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ hidden_size // 2,
+ dtype=torch.int8,
+ ),
+ requires_grad=False,
+ )
+ w2_weight = Parameter(
+ torch.empty(
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition // 2,
+ dtype=torch.int8,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_weight", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+ layer.register_parameter("w2_weight", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+
+ # Block scales: one float32 scale per fp4_block_k FP4 elements along K
+ w13_weight_scale = Parameter(
+ torch.ones(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ hidden_size // fp4_block_k,
+ dtype=torch.float32,
+ ),
+ requires_grad=False,
+ )
+ w2_weight_scale = Parameter(
+ torch.ones(
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition // fp4_block_k,
+ dtype=torch.float32,
+ ),
+ requires_grad=False,
+ )
+ w13_weight_scale.format_ue8m0 = False
+ w2_weight_scale.format_ue8m0 = False
+ scale_attrs = dict(extra_weight_attrs)
+ scale_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value
+ layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
+ set_weight_attrs(w13_weight_scale, scale_attrs)
+ layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
+ set_weight_attrs(w2_weight_scale, scale_attrs)
+
+ def process_weights_after_loading(self, layer: Module) -> None:
+ from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1
+
+ # Let Fp8MoEMethod do its processing first (FP4 view conversion, etc.).
+ # When SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE=1 is set, that path builds the
+ # mega-MoE weight tuples on the layer; we must then skip the
+ # reorder/interleave below since mega wants the checkpoint's
+ # [gate(w1), up(w3)] row order.
+ self._fp8.process_weights_after_loading(layer)
+
+ if getattr(layer, "_mega_moe_weights_built", False):
+ return
+
+ # Reorder w13 from checkpoint [w1(gate), w3(up)] to kernel [w3(up), w1(gate)].
+ # flashinfer's SM90 W4A16 test (`test_moe_bf16_mxfp4`) computes its
+ # reference as `w3, w1 = torch.chunk(w31, 2, dim=0)` — i.e. w3 (up) is
+ # the first half along dim -2, w1 (gate) is the second. Same row order
+ # as the B200 TRT-LLM routed kernel.
+ w13_w, w13_s = reorder_w1w3_to_w3w1(
+ layer.w13_weight.data, layer.w13_weight_scale_inv.data
+ )
+ layer.w13_weight = Parameter(w13_w, requires_grad=False)
+ layer.w13_weight_scale_inv = Parameter(w13_s, requires_grad=False)
+
+ log_info_on_rank0(
+ logger,
+ f"Interleaving FP4 expert weights/scales for SM90 W4A16 kernel "
+ f"(layer: {self.prefix})...",
+ )
+
+ w13 = layer.w13_weight.data
+ w2 = layer.w2_weight.data
+ w13_scale = layer.w13_weight_scale_inv.data
+ w2_scale = layer.w2_weight_scale_inv.data
+
+ # Convert float32 block scales to UE8M0 (1 byte per element) before the
+ # layout interleave. The SM90 kernel reads E8M0 uint8 bytes.
+ if w13_scale.dtype == torch.float32:
+ w13_scale = _fp32_to_ue8m0(w13_scale)
+ w2_scale = _fp32_to_ue8m0(w2_scale)
+
+ # bf16-weight debug path: dequant FP4+UE8M0 → bf16 once. Two downstream
+ # consumers share this:
+ # - BF16_API: apply() calls plain bf16 cutlass_fused_moe (skips the
+ # SM90 mixed-input kernel and the SM90 interleave).
+ # - TORCH_REF: apply() calls a pure-torch MoE forward (skips the
+ # flashinfer bf16 grouped GEMM too).
+ # Both are independent numerical references for W4A16 acc drops.
+ use_bf16_api = envs.SGLANG_HACK_DEBUG_W4A16_USE_BF16_API.get()
+ use_torch_ref = envs.SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF.get()
+ assert not (use_bf16_api and use_torch_ref), (
+ "SGLANG_HACK_DEBUG_W4A16_USE_BF16_API and "
+ "SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF are mutually exclusive"
+ )
+ if use_bf16_api or use_torch_ref:
+ consumer = "bf16-API" if use_bf16_api else "torch-ref"
+ log_info_on_rank0(
+ logger,
+ f"Dequant FP4 → bf16 for {consumer} path (layer: {self.prefix})...",
+ )
+ w13_bf16 = _dequant_mxfp4(
+ w13.contiguous().view(torch.uint8),
+ w13_scale.contiguous().view(torch.uint8),
+ )
+ w2_bf16 = _dequant_mxfp4(
+ w2.contiguous().view(torch.uint8),
+ w2_scale.contiguous().view(torch.uint8),
+ )
+ layer.w13_weight = Parameter(w13_bf16, requires_grad=False)
+ layer.w2_weight = Parameter(w2_bf16, requires_grad=False)
+ # Drop scale parameters — bf16 path does not read them. Replace
+ # with zero-size placeholders to keep any attribute-existence
+ # checks happy.
+ layer.w13_weight_scale_inv = Parameter(
+ torch.empty(0, device=w13_bf16.device), requires_grad=False
+ )
+ layer.w2_weight_scale_inv = Parameter(
+ torch.empty(0, device=w2_bf16.device), requires_grad=False
+ )
+ torch.cuda.empty_cache()
+ return
+
+ # Pre-interleave MXFP4 weights and scales (runs once at load time).
+ # Shapes after interleave:
+ # weights: same as input (byte-permutation only).
+ # scales: [E, rows, K/32] -> [E, K/(32*4), rows*4] uint8.
+ w13_u8 = w13.contiguous().view(torch.uint8)
+ w2_u8 = w2.contiguous().view(torch.uint8)
+ w13_scale_u8 = w13_scale.contiguous().view(torch.uint8)
+ w2_scale_u8 = w2_scale.contiguous().view(torch.uint8)
+
+ w13_il = interleave_moe_weights_for_sm90_mixed_gemm(w13_u8, "fp4")
+ w2_il = interleave_moe_weights_for_sm90_mixed_gemm(w2_u8, "fp4")
+ w13_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(
+ w13_scale_u8, group_size=32
+ )
+ w2_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(
+ w2_scale_u8, group_size=32
+ )
+
+ layer.w13_weight = Parameter(w13_il, requires_grad=False)
+ layer.w2_weight = Parameter(w2_il, requires_grad=False)
+ # Keep interleaved scales as uint8 — .view(torch.int32) at apply-time.
+ layer.w13_weight_scale_inv = Parameter(w13_scale_il, requires_grad=False)
+ layer.w2_weight_scale_inv = Parameter(w2_scale_il, requires_grad=False)
+
+ torch.cuda.empty_cache()
+
+ def apply(
+ self,
+ layer: Module,
+ dispatch_output: DispatchOutput,
+ ) -> CombineInput:
+ from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+ from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+ hidden_states = dispatch_output.hidden_states
+ topk_output = dispatch_output.topk_output
+
+ # --- Step A: Prepare weights and sizes ---
+ w13 = layer.w13_weight
+ w2 = layer.w2_weight
+ use_bf16_api = envs.SGLANG_HACK_DEBUG_W4A16_USE_BF16_API.get()
+ use_torch_ref = envs.SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF.get()
+ if use_bf16_api or use_torch_ref:
+ # bf16 weights path: weights already dequanted to bf16 in
+ # process_weights_after_loading; no scale tensors to pass.
+ quant_scales_arg = None
+ else:
+ quant_scales_arg = [
+ layer.w13_weight_scale_inv.view(torch.int32),
+ layer.w2_weight_scale_inv.view(torch.int32),
+ ]
+
+ # w13/w2 are pre-interleaved uint8 (W4A16) or plain bf16 (bf16-API);
+ # logical shapes come from the layer-configured sizes rather than
+ # tensor dims (interleave preserves numel but the 3D view no longer
+ # maps 1:1 to [E, 2*I, H/2]).
+ hidden_size = layer.hidden_size
+
+ # --- Step B: Determine routing ---
+ if TopKOutputChecker.format_is_standard(topk_output):
+ topk_ids = topk_output.topk_ids
+ topk_weights = topk_output.topk_weights
+ else:
+ raise ValueError(
+ f"Unsupported topk output format for W4A16 MoE: {topk_output.format}"
+ )
+
+ # Undo StandardDispatcher's global->local+sentinel mapping so the
+ # flashinfer kernel (which expects global expert ids plus ep_rank/ep_size
+ # for local filtering) gets what it wants. Mirror the mxfp4_deepseek
+ # logic gated on SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.
+ if not envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get():
+ local_expert_offset = layer.moe_ep_rank * layer.num_local_experts
+ topk_ids = torch.where(
+ topk_ids >= 0,
+ topk_ids + local_expert_offset,
+ topk_ids,
+ )
+
+ # --- Step C: Activations ---
+ # W4A16 path: bf16 activations, no quantization needed.
+ assert hidden_states.dtype == torch.bfloat16, (
+ f"W4A16 expects bf16 activations, got {hidden_states.dtype}"
+ )
+ x = hidden_states
+ origin_dim = x.shape[-1]
+ if hidden_size != origin_dim:
+ x = torch.nn.functional.pad(
+ x, (0, hidden_size - origin_dim), mode="constant", value=0.0
+ )
+
+ # --- Step D: Allocate output with symmetric memory for TP all-reduce ---
+ with use_symmetric_memory(
+ get_tp_group(), disabled=not is_allocation_symmetric()
+ ):
+ num_tokens = x.shape[0]
+ symm_output = torch.empty(
+ num_tokens, hidden_size, dtype=torch.bfloat16, device=x.device
+ )
+
+ # --- Step E: Call kernel ---
+ # DSv4 260415 ships a per-MoE-layer sanity counter that deepseek_v4.py
+ # asserts is bumped exactly once per forward (see deepseek_v4.py:2014).
+ # Mirror the mxfp4_deepseek bump so the checker is satisfied.
+ if envs.SGLANG_DSV4_2604_SUBMODE.get() == "260415" and (
+ self._swiglu_limit_tensor is not None
+ ):
+ sunrise_moe_code_path_checker.observed += 1
+
+ swiglu_limit_arg = (
+ None
+ if envs.SGLANG_HACK_DEBUG_W4A16_REMOVE_SWIGLU_LIMIT.get()
+ else self._swiglu_limit_tensor
+ )
+
+ _moe_fn = cutlass_fused_moe
+ if use_torch_ref:
+ from sglang.srt.debug_utils.w4a16_moe_ref_related import (
+ torch_ref_cutlass_fused_moe as _moe_fn,
+ )
+
+ _moe_fn(
+ input=x,
+ token_selected_experts=topk_ids.to(torch.int32).contiguous(),
+ token_final_scales=topk_weights.to(torch.float32).contiguous(),
+ fc1_expert_weights=w13,
+ fc2_expert_weights=w2,
+ output_dtype=torch.bfloat16,
+ quant_scales=quant_scales_arg,
+ swiglu_limit=swiglu_limit_arg,
+ ep_size=layer.moe_ep_size,
+ ep_rank=layer.moe_ep_rank,
+ tp_size=1,
+ tp_rank=0,
+ use_w4_group_scaling=not use_bf16_api,
+ tune_max_num_tokens=next_power_of_2(x.shape[0]),
+ output=symm_output,
+ )
+ output = symm_output
+
+ # Apply routed_scaling_factor (DSv4 = 1.5). cutlass_fused_moe has no
+ # routed_scaling_factor parameter, so unless we hand this to the fused
+ # shared-add fast path, we multiply post-hoc. See mxfp4_deepseek for
+ # the same rationale.
+ if not envs.SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD.get():
+ rsf = layer.moe_runner_config.routed_scaling_factor
+ if rsf is not None and rsf != 1.0:
+ output.mul_(rsf)
+
+ return StandardCombineInput(hidden_states=output)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 28a3c92e1b55..77b833568077 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -179,6 +179,7 @@
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
+ "flashinfer_w4a16",
"flashinfer_cutedsl",
"cutlass",
]
diff --git a/sunrise/aime25_q6/__init__.py b/sunrise/aime25_q6/__init__.py
new file mode 100644
index 000000000000..b8d7ce3fb16e
--- /dev/null
+++ b/sunrise/aime25_q6/__init__.py
@@ -0,0 +1,3 @@
+DATASET_GROUP = "math"
+METRICS_TYPE = "math"
+GENERATION_ARGS = "++prompt_config=generic/math ++eval_type=math"
diff --git a/sunrise/aime25_q6/generate.sh b/sunrise/aime25_q6/generate.sh
new file mode 100755
index 000000000000..5bf2a84a2fdd
--- /dev/null
+++ b/sunrise/aime25_q6/generate.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+# Regenerate sunrise/aime25_q6/test.jsonl from the canonical nemo_skills aime25 dataset.
+#
+# Run from inside an rcli container that has the nemo_skills venv set up (see
+# journal 2026-04-21-024 §Step 6a), or adjust --source-jsonl to point at any
+# local copy of nemo_skills/dataset/aime25/test.jsonl.
+#
+# Must be run after `ns prepare_data aime25` so the source jsonl exists.
+set -euo pipefail
+HERE="$(cd "$(dirname "$0")" && pwd)"
+REPO_ROOT="$(cd "$HERE/../.." && pwd)"
+
+SOURCE="${SOURCE:-/workspace/nemo_skills-venv/lib/python3.12/site-packages/nemo_skills/dataset/aime25/test.jsonl}"
+
+cd "$REPO_ROOT"
+python3 sunrise/filter_nemo_skills_questions.py \
+ --question-ids aime25-6 \
+ --source-jsonl "$SOURCE" \
+ --output-dir sunrise
+echo "Wrote $HERE/test.jsonl"
diff --git a/sunrise/aime25_q6/test.jsonl b/sunrise/aime25_q6/test.jsonl
new file mode 100644
index 000000000000..f81781ebc7c1
--- /dev/null
+++ b/sunrise/aime25_q6/test.jsonl
@@ -0,0 +1 @@
+{"id": "aime25-6", "problem": "The twelve letters $A$ , $B$ , $C$ , $D$ , $E$ , $F$ , $G$ , $H$ , $I$ , $J$ , $K$ , and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and then those six words are listed alphabetically. For example, a possible result is $AB$ , $CJ$ , $DG$ , $EK$ , $FL$ , $HI$ . The probability that the last word listed contains $G$ is $\\frac mn$ , where $m$ and $n$ are relatively prime positive integers. Find $m+n$ .", "expected_answer": "821", "reference_solution": "Splitting up into $2$ cases: $G$ is the first letter or the second letter of the last word. Case $1:$ $G$ in first letter Notice that $A$ must take the first letter of first word, one of the letters $B$ - $F$ needs to be the second letter of a word and the rest being the first letter of a word. \nThe combinations will be $1 + 2 + 3 + 4 + 5 = 15.$ After the first $7$ letters has been decided then the last $5$ will just fill by $5!.$ This case will have $15 \\cdot 5!$ outcomes.
\nCase $2:$ $G$ in last letter Notice that $A$ - $G$ has been arranged by $A? B? C? D? E? FG,$ where the $?$ is undecided. We have another $5!$ to fill out the possible outcomes. In total, there are $16 \\cdot 5!.$ The total case will be $11 \\cdot 9 \\cdot 7 \\cdot 5 \\cdot 3 \\cdot 1$ (Consider A must be in the first letter of first word, then you have $11$ choices, then you must take the next letter in alphabetical order as mandatory, then you have a free choice of $9$ and so on). Answer:\n \\[= \\frac{16 \\cdot 5 \\cdot 4 \\cdot 3 \\cdot 2 \\cdot 1}{ 11 \\cdot 9 \\cdot 7 \\cdot 5 \\cdot 3 \\cdot 1}\\] \n \\[= \\frac{16 \\cdot 4 \\cdot 2}{11 \\cdot 9 \\cdot 7}\\] \n \\[= \\frac{128}{ 693}\\] \nTherefore it gives us the answer of ${128 + 693 = \\boxed{821}.}$ ~Mitsuihisashi14"}
diff --git a/sunrise/verify_dequant_mxfp4.py b/sunrise/verify_dequant_mxfp4.py
new file mode 100644
index 000000000000..2e4f410a7346
--- /dev/null
+++ b/sunrise/verify_dequant_mxfp4.py
@@ -0,0 +1,173 @@
+"""Bitwise equivalence check between flashinfer's _dequant_mxfp4_on_device and
+sglang's locally-copied _dequant_mxfp4 (in w4a16_deepseek.py).
+
+Neither file is imported as a module (the flashinfer test file has heavy deps,
+the sglang file imports CUDA-specific bits). We textually extract the two
+symbols from each source file and exec them into isolated namespaces.
+
+Run (CPU-only, no CUDA required):
+ uv run python sunrise/verify_dequant_mxfp4.py
+
+Override reference path with:
+ FLASHINFER_SUNRISE_TEST_FILE=/some/other/path python sunrise/verify_dequant_mxfp4.py
+"""
+
+from __future__ import annotations
+
+import ast
+import os
+import sys
+from pathlib import Path
+from typing import Callable
+
+import torch
+
+
+_SCRIPT = Path(__file__).resolve()
+_SGLANG_REPO_ROOT = _SCRIPT.parent.parent # sunrise/ -> repo root
+OUR_FILE = (
+ _SGLANG_REPO_ROOT / "python/sglang/srt/layers/quantization/w4a16_deepseek.py"
+)
+# Default assumes the standard NDA workspace layout where flashinfer-sunrise
+# sits as a sibling of the sglang worktrees directory
+# (ws_nda/flashinfer-sunrise and ws_nda/worktrees/).
+_DEFAULT_REF = (
+ _SGLANG_REPO_ROOT.parent.parent
+ / "flashinfer-sunrise/tests/moe/test_trtllm_cutlass_fused_moe.py"
+)
+REF_FILE = Path(os.environ.get("FLASHINFER_SUNRISE_TEST_FILE", str(_DEFAULT_REF)))
+
+
+def _extract_symbols(path: Path, lut_name: str, fn_name: str) -> dict:
+ """Parse `path` with ast, pull out the `lut_name` Assign and the `fn_name`
+ FunctionDef, exec them in an isolated namespace, and return the namespace."""
+ source = path.read_text()
+ tree = ast.parse(source)
+ wanted_nodes: list[ast.stmt] = []
+ for node in tree.body:
+ if isinstance(node, ast.Assign):
+ for target in node.targets:
+ if isinstance(target, ast.Name) and target.id == lut_name:
+ wanted_nodes.append(node)
+ elif isinstance(node, ast.FunctionDef) and node.name == fn_name:
+ wanted_nodes.append(node)
+ if len(wanted_nodes) != 2:
+ raise RuntimeError(
+ f"Expected to find 1 Assign for {lut_name!r} and 1 FunctionDef for {fn_name!r} in {path}; "
+ f"got {len(wanted_nodes)} matching nodes."
+ )
+ module = ast.Module(body=wanted_nodes, type_ignores=[])
+ code = compile(module, filename=str(path), mode="exec")
+ ns: dict = {"torch": torch}
+ exec(code, ns)
+ return ns
+
+
+def main() -> int:
+ print(f"REF: {REF_FILE}")
+ print(f"OUR: {OUR_FILE}")
+ if not REF_FILE.exists():
+ print(f"ERROR: reference file not found: {REF_FILE}", file=sys.stderr)
+ return 2
+ if not OUR_FILE.exists():
+ print(f"ERROR: local file not found: {OUR_FILE}", file=sys.stderr)
+ return 2
+
+ ref_ns = _extract_symbols(REF_FILE, "_MXFP4_LUT", "_dequant_mxfp4_on_device")
+ our_ns = _extract_symbols(OUR_FILE, "_MXFP4_LUT", "_dequant_mxfp4")
+
+ ref_lut = ref_ns["_MXFP4_LUT"]
+ our_lut = our_ns["_MXFP4_LUT"]
+ assert ref_lut == our_lut, f"LUT mismatch: ref={ref_lut} our={our_lut}"
+ print(f"LUT equal: {ref_lut == our_lut} (len={len(ref_lut)})", flush=True)
+
+ ref_fn: Callable = ref_ns["_dequant_mxfp4_on_device"]
+ our_fn: Callable = our_ns["_dequant_mxfp4"]
+
+ torch.manual_seed(0)
+
+ # Shapes: last dim = K/2 (so K = 2*last_dim); K must be divisible by 32 for
+ # the UE8M0 scale derivation. The (4, 1024, 2048) shape is a DSv4-realistic
+ # slice (K=4096 = DSv4 hidden dim); full (256, 4096, 2048) OOMs a laptop
+ # during fp32 intermediate allocation.
+ shapes = [
+ (2, 2, 16), # minimal: K=32
+ (4, 8, 32), # K=64
+ (8, 256, 64), # K=128
+ (256, 64, 256), # full e=256, K=512
+ (4, 1024, 2048), # DSv4-shaped slice: K=4096
+ ]
+
+ all_ok = True
+ mismatches: list[tuple] = []
+ for shape in shapes:
+ K_half = shape[-1]
+ K = 2 * K_half
+ assert K % 32 == 0, f"K={K} not divisible by 32 for shape {shape}"
+ scale_shape = shape[:-1] + (K // 32,)
+
+ w_fp4 = torch.randint(0, 256, shape, dtype=torch.uint8)
+ # Scale UE8M0 byte 255 produces exp2(128)=inf, and inf*0 (FP4 zero
+ # nibbles) → NaN, which torch.equal treats as unequal-to-itself. Cap
+ # to <255 in this loop; NaN-position agreement is verified separately
+ # below.
+ w_scale = torch.randint(0, 255, scale_shape, dtype=torch.uint8)
+
+ ref_out = ref_fn(w_fp4, w_scale)
+ our_out = our_fn(w_fp4, w_scale_ue8m0_u8=w_scale)
+
+ ok = torch.equal(ref_out, our_out)
+ assert not ref_out.isnan().any(), "ref_out has NaN — scale cap failed"
+ assert not our_out.isnan().any(), "our_out has NaN — scale cap failed"
+ if not ok:
+ diff = (ref_out.float() - our_out.float()).abs()
+ num_mismatch = int((ref_out != our_out).sum().item())
+ max_abs_diff = float(diff.max().item())
+ mismatches.append((shape, num_mismatch, max_abs_diff))
+ all_ok = False
+ print(
+ f"MISMATCH shape={shape} numel={ref_out.numel()} "
+ f"num_mismatch={num_mismatch} max_abs_diff={max_abs_diff} "
+ f"ref_dtype={ref_out.dtype} our_dtype={our_out.dtype} "
+ f"ref_shape={tuple(ref_out.shape)} our_shape={tuple(our_out.shape)}"
+ )
+ else:
+ print(
+ f"OK shape={shape} numel={ref_out.numel()} "
+ f"dtype={ref_out.dtype} out_shape={tuple(ref_out.shape)}",
+ flush=True,
+ )
+ del ref_out, our_out, w_fp4, w_scale
+
+ # NaN-edge: scale=255 path (inf*0 → NaN). Both fns should produce NaN at
+ # identical positions and identical finite values elsewhere.
+ print("--- NaN-edge shape (scale includes 255) ---", flush=True)
+ shape_nan = (4, 8, 32)
+ scale_shape_nan = shape_nan[:-1] + (shape_nan[-1] * 2 // 32,)
+ w_fp4 = torch.randint(0, 256, shape_nan, dtype=torch.uint8)
+ w_scale = torch.randint(0, 256, scale_shape_nan, dtype=torch.uint8)
+ ref_out = ref_fn(w_fp4, w_scale)
+ our_out = our_fn(w_fp4, w_scale_ue8m0_u8=w_scale)
+ both_nan = ref_out.isnan() & our_out.isnan()
+ eq_or_both_nan = (ref_out == our_out) | both_nan
+ nan_ok = bool(eq_or_both_nan.all().item()) and torch.equal(
+ ref_out.isnan(), our_out.isnan()
+ )
+ print(
+ f"NaN-edge agree (incl NaN positions): {nan_ok} "
+ f"(ref_nans={int(ref_out.isnan().sum())}, our_nans={int(our_out.isnan().sum())})",
+ flush=True,
+ )
+ if not nan_ok:
+ all_ok = False
+
+ if all_ok:
+ print("ALL SHAPES BITWISE EQUAL")
+ return 0
+ else:
+ print(f"FAIL: {len(mismatches)} shape(s) mismatched: {mismatches}")
+ return 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/sunrise/verify_torch_ref_w4a16_moe.py b/sunrise/verify_torch_ref_w4a16_moe.py
new file mode 100644
index 000000000000..2ce830164fb3
--- /dev/null
+++ b/sunrise/verify_torch_ref_w4a16_moe.py
@@ -0,0 +1,149 @@
+"""Element-wise equivalence check between
+ sglang's torch_ref_w4a16_moe_forward (debug_utils/w4a16_moe_ref_related.py)
+and
+ flashinfer's cutlass_fused_moe(use_w4_group_scaling=True, swiglu_limit=...)
+on tiny random MXFP4 weights.
+
+Purpose: before running a full DSv4 bench with
+SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF=1 (which takes many hours per seed), we
+want a fast smoke that the torch ref matches the kernel on tiny shapes. If
+this smoke diverges, the bench-scale acc numbers are not comparable.
+
+Run (needs CUDA + flashinfer PR #3084 installed):
+ uv run python sunrise/verify_torch_ref_w4a16_moe.py
+"""
+
+from __future__ import annotations
+
+import torch
+from flashinfer.fused_moe import (
+ cutlass_fused_moe,
+ interleave_moe_scales_for_sm90_mixed_gemm,
+ interleave_moe_weights_for_sm90_mixed_gemm,
+)
+
+from sglang.srt.debug_utils.w4a16_moe_ref_related import torch_ref_cutlass_fused_moe
+from sglang.srt.layers.quantization.w4a16_deepseek import _dequant_mxfp4
+
+# (batch_size, hidden_size, num_experts, top_k, intermediate_size, swiglu_limit)
+# Shapes borrowed from flashinfer's W4A16_CORRECTNESS_CONFIGS so we stay inside
+# the kernel's supported envelope. Swiglu limit is the DSv4 260415 value.
+SHAPES = [
+ (4, 128, 4, 2, 128, 10.0),
+ (4, 768, 8, 2, 512, 10.0),
+ (4, 2048, 8, 4, 1024, 10.0),
+ (4, 4096, 8, 4, 2048, 10.0),
+]
+
+
+def _compute_routing(
+ router_logits: torch.Tensor, top_k: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+ probs = torch.softmax(router_logits.float(), dim=-1)
+ topk_weights, topk_ids = probs.topk(top_k, dim=-1)
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ return topk_weights.to(router_logits.dtype), topk_ids.to(torch.int64)
+
+
+def _run_one(
+ batch_size: int,
+ hidden_size: int,
+ num_experts: int,
+ top_k: int,
+ intermediate_size: int,
+ swiglu_limit_val: float,
+) -> None:
+ torch.manual_seed(42)
+ device = torch.device("cuda")
+ m, k, e, n = batch_size, hidden_size, num_experts, intermediate_size
+
+ x = torch.randn(m, k, dtype=torch.bfloat16, device=device)
+
+ w13_fp4 = torch.randint(
+ 0, 256, (e, 2 * n, k // 2), device=device, dtype=torch.uint8
+ )
+ w2_fp4 = torch.randint(
+ 0, 256, (e, k, n // 2), device=device, dtype=torch.uint8
+ )
+ w13_scale = torch.randint(
+ 118, 123, (e, 2 * n, k // 32), device=device, dtype=torch.uint8
+ )
+ w2_scale = torch.randint(
+ 118, 123, (e, k, n // 32), device=device, dtype=torch.uint8
+ )
+
+ router_logits = torch.randn(m, e, dtype=torch.bfloat16, device=device)
+ topk_weights, topk_ids = _compute_routing(router_logits, top_k)
+
+ swiglu_limit_tensor = torch.full(
+ (e,), swiglu_limit_val, dtype=torch.float32, device=device
+ )
+
+ # --- Flashinfer kernel path ---
+ w13_il = interleave_moe_weights_for_sm90_mixed_gemm(w13_fp4, "fp4")
+ w2_il = interleave_moe_weights_for_sm90_mixed_gemm(w2_fp4, "fp4")
+ w13_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(w13_scale, group_size=32)
+ w2_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(w2_scale, group_size=32)
+
+ flash_output = torch.zeros_like(x)
+ cutlass_fused_moe(
+ input=x,
+ token_selected_experts=topk_ids.to(torch.int32).contiguous(),
+ token_final_scales=topk_weights.to(torch.float32).contiguous(),
+ fc1_expert_weights=w13_il,
+ fc2_expert_weights=w2_il,
+ output_dtype=torch.bfloat16,
+ quant_scales=[w13_scale_il.view(torch.int32), w2_scale_il.view(torch.int32)],
+ swiglu_limit=swiglu_limit_tensor,
+ ep_size=1,
+ ep_rank=0,
+ tp_size=1,
+ tp_rank=0,
+ use_w4_group_scaling=True,
+ output=flash_output,
+ )
+
+ # --- Torch ref path ---
+ # Kernel and ref see the same raw FP4 tensor; ref's chunk(dim=0) gives
+ # (w3, w1) which matches flashinfer's own reference convention in
+ # _run_w4a16_moe_hopper. No explicit reorder needed on either side.
+ w13_bf16 = _dequant_mxfp4(w13_fp4, w13_scale)
+ w2_bf16 = _dequant_mxfp4(w2_fp4, w2_scale)
+
+ ref_output = torch.zeros_like(x)
+ torch_ref_cutlass_fused_moe(
+ input=x,
+ token_selected_experts=topk_ids.to(torch.int32).contiguous(),
+ token_final_scales=topk_weights.to(torch.float32).contiguous(),
+ fc1_expert_weights=w13_bf16,
+ fc2_expert_weights=w2_bf16,
+ output_dtype=torch.bfloat16,
+ swiglu_limit=swiglu_limit_tensor,
+ ep_size=1,
+ ep_rank=0,
+ output=ref_output,
+ )
+
+ # Compare
+ diff = (ref_output.float() - flash_output.float()).abs()
+ tol = 0.1 + 1e-1 * ref_output.float().abs()
+ close_pct = (diff <= tol).float().mean().item()
+ max_abs = diff.max().item()
+ print(
+ f"m={m} k={k} e={e} top_k={top_k} n={n} limit={swiglu_limit_val} "
+ f"close%={close_pct:.4%} max_abs={max_abs:.4f}"
+ )
+ assert close_pct >= 0.99, (
+ f"torch-ref vs kernel mismatch: only {close_pct:.4%} within tol; "
+ f"max_abs={max_abs:.4f}"
+ )
+
+
+def main() -> None:
+ for cfg in SHAPES:
+ _run_one(*cfg)
+ print("ALL PASS")
+
+
+if __name__ == "__main__":
+ main()