diff --git a/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py b/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py new file mode 100644 index 000000000000..6f01e7fb1b33 --- /dev/null +++ b/python/sglang/srt/debug_utils/w4a16_moe_ref_related.py @@ -0,0 +1,143 @@ +""" +Pure-torch MoE ref for W4A16 acc investigation. + +The body of ``torch_ref_cutlass_fused_moe`` adapts +``_compute_with_active_experts`` from flashinfer-sunrise PR #3084 at + tests/moe/test_trtllm_cutlass_fused_moe.py + (commit 77746b81, lines 2458-2491) +into a drop-in replacement for ``flashinfer.fused_moe.cutlass_fused_moe``: +identical signature, identical in-place output semantics. Only the subset of +kwargs actually needed to reproduce DSv4 W4A16 numerics is consumed; the rest +are accepted and ignored. + +Expectation on weights: both ``fc1_expert_weights`` and ``fc2_expert_weights`` +are already bf16 (caller dequanted FP4+UE8M0 up front — see +``DeepSeekW4A16MoEMethod.process_weights_after_loading`` under +``SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF=1``). ``quant_scales`` / +``use_w4_group_scaling`` are accepted for signature parity but unused. + +Activation: we reproduce the kernel's behavior when only ``swiglu_limit`` is +passed (no ``swiglu_alpha``/``swiglu_beta``) — SiLU on the clamped gate, +symmetric clamp on the up. This matches the sglang ``w4a16_deepseek.py`` +apply() kernel call. 
+""" +from __future__ import annotations + +from typing import List, Optional + +import torch +import torch.nn.functional as F + + +def torch_ref_cutlass_fused_moe( + input: torch.Tensor, + token_selected_experts: torch.Tensor, + token_final_scales: torch.Tensor, + fc1_expert_weights: torch.Tensor, + fc2_expert_weights: torch.Tensor, + output_dtype: torch.dtype, + quant_scales: Optional[List[torch.Tensor]] = None, + fc1_expert_biases: Optional[torch.Tensor] = None, + fc2_expert_biases: Optional[torch.Tensor] = None, + input_sf: Optional[torch.Tensor] = None, + swiglu_alpha: Optional[torch.Tensor] = None, + swiglu_beta: Optional[torch.Tensor] = None, + swiglu_limit: Optional[torch.Tensor] = None, + tp_size: int = 1, + tp_rank: int = 0, + ep_size: int = 1, + ep_rank: int = 0, + cluster_size: int = 1, + cluster_rank: int = 0, + output: Optional[torch.Tensor] = None, + enable_alltoall: bool = False, + use_deepseek_fp8_block_scale: bool = False, + use_w4_group_scaling: bool = False, + use_mxfp8_act_scaling: bool = False, + min_latency_mode: bool = False, + use_packed_weights: bool = False, + tune_max_num_tokens: int = 8192, + enable_pdl: Optional[bool] = None, + activation_type=None, + swizzled_input_sf: bool = True, +) -> torch.Tensor: + """Pure-torch drop-in for flashinfer ``cutlass_fused_moe`` (W4A16 path). + + Consumed args: ``input``, ``token_selected_experts``, + ``token_final_scales``, ``fc1_expert_weights``, ``fc2_expert_weights`` + (both bf16), ``output_dtype``, ``swiglu_limit``, ``ep_size``, ``ep_rank``, + ``output``. Everything else is ignored. 
+ """ + del ( + quant_scales, + fc1_expert_biases, + fc2_expert_biases, + input_sf, + swiglu_alpha, + swiglu_beta, + tp_size, + tp_rank, + cluster_size, + cluster_rank, + enable_alltoall, + use_deepseek_fp8_block_scale, + use_w4_group_scaling, + use_mxfp8_act_scaling, + min_latency_mode, + use_packed_weights, + tune_max_num_tokens, + enable_pdl, + activation_type, + swizzled_input_sf, + ) + + assert fc1_expert_weights.dtype == torch.bfloat16, ( + f"torch-ref expects bf16 weights, got {fc1_expert_weights.dtype}" + ) + assert fc2_expert_weights.dtype == torch.bfloat16, ( + f"torch-ref expects bf16 weights, got {fc2_expert_weights.dtype}" + ) + + num_tokens = input.shape[0] + hidden = fc2_expert_weights.shape[1] + num_local_experts = fc1_expert_weights.shape[0] + local_expert_offset = ep_rank * num_local_experts + + if output is None: + output = torch.empty( + num_tokens, hidden, dtype=output_dtype, device=input.device + ) + output.zero_() + + topk_ids_local = token_selected_experts.long() - local_expert_offset + in_range = (topk_ids_local >= 0) & (topk_ids_local < num_local_experts) + if not in_range.any(): + return output + + active_local = torch.unique(topk_ids_local[in_range]) + for eid_local in active_local.tolist(): + mask = (topk_ids_local == eid_local) & in_range + tok_idx, nth = torch.where(mask) + if tok_idx.numel() == 0: + continue + + w31 = fc1_expert_weights[eid_local] + w3, w1 = torch.chunk(w31, 2, dim=0) + w2 = fc2_expert_weights[eid_local] + + expert_in = input[tok_idx] + x1 = expert_in @ w1.t() + x3 = expert_in @ w3.t() + + if swiglu_limit is not None: + limit = swiglu_limit[eid_local] + x1 = x1.clamp(max=limit) + x3 = x3.clamp(min=-limit, max=limit) + + inter = F.silu(x1) * x3 + out = inter @ w2.t() + + weight = token_final_scales[tok_idx, nth, None].to(out.dtype) + output.index_add_(0, tok_idx, (weight * out).to(output.dtype)) + + return output diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 4277aad31c54..ebe7be599e93 
100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -498,6 +498,9 @@ class Envs: SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD = EnvBool(True) SGLANG_OPT_MXFP4_STATIC_SCALE_ONES = EnvBool(True) SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING = EnvBool(True) + SGLANG_HACK_DEBUG_W4A16_REMOVE_SWIGLU_LIMIT = EnvBool(False) + SGLANG_HACK_DEBUG_W4A16_USE_BF16_API = EnvBool(False) + SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF = EnvBool(False) SGLANG_OPT_USE_JIT_INDEXER_METADATA = EnvBool(False) SGLANG_OPT_SWIGLU_CLAMP_FUSION = EnvBool(True) SGLANG_OPT_DG_PAGED_MQA_LOGITS_CHUNK_SIZE = EnvInt(-1) diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index a47534ce7f4a..f47e522a8453 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -90,10 +90,20 @@ def __init__(self, moe_runner_config: MoeRunnerConfig): self.enable_flashinfer_mxfp4_moe = ( get_moe_runner_backend().is_flashinfer_mxfp4() ) + self.enable_flashinfer_w4a16_moe = ( + get_moe_runner_backend().is_flashinfer_w4a16() + ) self.skip_local_expert_mapping = ( self.enable_flashinfer_mxfp4_moe and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get() ) + # Read env once at init; per-layer dispatcher instance is long-lived. + # w4a16 uses the same flashinfer-style "kernel expects global ids + + # ep_rank/ep_size for local filtering" contract as mxfp4, so the + # dispatcher must be a passthrough for both backends when the flag is on. 
+ self.skip_local_expert_mapping = ( + self.enable_flashinfer_mxfp4_moe or self.enable_flashinfer_w4a16_moe + ) and envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get() self.num_experts = moe_runner_config.num_experts self.num_local_shared_experts = moe_runner_config.num_fused_shared_experts self.num_local_routed_experts = ( diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index ba6ca01ff140..d119d90f7335 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -60,6 +60,7 @@ class MoeRunnerBackend(Enum): FLASHINFER_TRTLLM = "flashinfer_trtllm" FLASHINFER_CUTLASS = "flashinfer_cutlass" FLASHINFER_MXFP4 = "flashinfer_mxfp4" + FLASHINFER_W4A16 = "flashinfer_w4a16" FLASHINFER_CUTEDSL = "flashinfer_cutedsl" CUTLASS = "cutlass" MARLIN = "marlin" @@ -88,6 +89,9 @@ def is_flashinfer_cutedsl(self): def is_flashinfer_mxfp4(self): return self == MoeRunnerBackend.FLASHINFER_MXFP4 + def is_flashinfer_w4a16(self): + return self == MoeRunnerBackend.FLASHINFER_W4A16 + def is_cutlass(self): return self == MoeRunnerBackend.CUTLASS diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ed74b1d47e52..1371742bc5f3 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -196,6 +196,16 @@ def get_quant_method( ) return DeepSeekMxfp4MoEMethod(fp8_method, prefix=prefix) + if ( + envs.SGLANG_DSV4_MODE.get() == "2604" + and envs.SGLANG_DSV4_FP4_EXPERTS.get() + and get_moe_runner_backend().is_flashinfer_w4a16() + ): + from sglang.srt.layers.quantization.w4a16_deepseek import ( + DeepSeekW4A16MoEMethod, + ) + + return DeepSeekW4A16MoEMethod(fp8_method, prefix=prefix) return fp8_method elif isinstance(layer, RadixAttention): return Fp8KVCacheMethod(self) diff --git a/python/sglang/srt/layers/quantization/w4a16_deepseek.py b/python/sglang/srt/layers/quantization/w4a16_deepseek.py new file mode 100644 
index 000000000000..6c3517ccd782 --- /dev/null +++ b/python/sglang/srt/layers/quantization/w4a16_deepseek.py @@ -0,0 +1,451 @@ +""" +DeepSeek W4A16 MoE quantization method (SM90 / H200). + +Wraps Fp8MoEMethod to reuse the FP4-expert weight creation/loading, then +overrides process_weights_after_loading to pre-interleave FP4 weights and +MXFP4 block scales for the SM90 mixed-input CUTLASS kernel exposed by +flashinfer-ai/flashinfer PR #3084, and overrides apply to call +cutlass_fused_moe(..., use_w4_group_scaling=True) directly. + +This is the H200 counterpart to mxfp4_deepseek.py. The two share the same +DSv4 FP4 checkpoint (SGLANG_DSV4_MODE=2604 SGLANG_DSV4_FP4_EXPERTS=1): weight +shapes and dtypes are identical; only the post-load layout and the kernel +call differ. + +Usage: --moe-runner-backend flashinfer_w4a16 --moe-a2a-backend none + with SGLANG_DSV4_MODE=2604 SGLANG_DSV4_FP4_EXPERTS=1 +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from sglang.srt.distributed import get_tp_group +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import ( + is_flashinfer_available, + log_info_on_rank0, + set_weight_attrs, +) +from sglang.srt.utils.common import next_power_of_2 + +if is_flashinfer_available(): + from flashinfer.fused_moe import ( + cutlass_fused_moe, + interleave_moe_scales_for_sm90_mixed_gemm, + interleave_moe_weights_for_sm90_mixed_gemm, + ) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import CombineInput, DispatchOutput + + +from sglang.srt.debug_utils.sunrise_debug_utils import sunrise_moe_code_path_checker +from sglang.srt.environ import envs + + +def 
_fp32_to_ue8m0(x: torch.Tensor) -> torch.Tensor: + """Convert float32 → UE8M0 (float8_e8m0fnu) and assert lossless round-trip. + + UE8M0 stores only the 8-bit biased exponent (no mantissa), so only exact + powers of 2 round-trip bit-exactly. DSv4 MXFP4 block scales should already + be powers of 2 per the MXFP4 spec; if this assert fires, the checkpoint + isn't actually MXFP4-quantized and we must feed the kernel the raw E8M0 + bytes instead of round-tripping through fp32. + """ + assert x.dtype == torch.float32, f"expected float32 input, got {x.dtype}" + ans = x.to(torch.float8_e8m0fnu) + rt = ans.float() + mismatch = rt != x + if mismatch.any(): + bad_orig = x[mismatch][:5].tolist() + bad_rt = rt[mismatch][:5].tolist() + raise AssertionError( + f"fp32→UE8M0 lossy: {int(mismatch.sum())}/{x.numel()} elements " + f"changed; min/max input = {x.min().item()}/{x.max().item()}; " + f"first 5 (orig → round-trip): {list(zip(bad_orig, bad_rt))}" + ) + return ans + + +# MXFP4 4-bit codebook and dequant helper, copied verbatim (module-level +# constant `_MXFP4_LUT` + body of `_dequant_mxfp4_on_device`) from flashinfer +# PR #3084 branch at +# flashinfer-sunrise/tests/moe/test_trtllm_cutlass_fused_moe.py +# (commit 77746b81, lines 2419-2452) +# so the bf16 API path sees weights bit-equivalent to what the SM90 +# mixed-input kernel dequants inside itself. 
+_MXFP4_LUT = ( + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +) + + +def _dequant_mxfp4( + w_fp4_u8: torch.Tensor, w_scale_ue8m0_u8: torch.Tensor +) -> torch.Tensor: + """[E, rows, K/2] uint8 FP4 packed + [E, rows, K/32] uint8 UE8M0 → [E, rows, K] bf16.""" + lut = torch.tensor(_MXFP4_LUT, dtype=torch.float32, device=w_fp4_u8.device) + lo = w_fp4_u8 & 0x0F + hi = (w_fp4_u8 >> 4) & 0x0F + nib = torch.stack([lo, hi], dim=-1).reshape(*w_fp4_u8.shape[:-1], -1) + values = lut[nib.long()] + scale = torch.exp2(w_scale_ue8m0_u8.to(torch.float32) - 127.0) + scale = scale.repeat_interleave(32, dim=-1) + return (values * scale).to(torch.bfloat16) + + +class DeepSeekW4A16MoEMethod: + """W4A16 MoE method for DeepSeek-family models with FP4 expert weights on SM90. + + Wraps Fp8MoEMethod for weight creation/loading, but overrides + post-loading processing to pre-interleave FP4 weights and MXFP4 scales + for the flashinfer SM90 mixed-input CUTLASS kernel, and directly calls + cutlass_fused_moe(..., use_w4_group_scaling=True) in apply(). + """ + + def __init__(self, fp8_method, prefix: str): + self._fp8 = fp8_method + self.prefix = prefix + # Kept for parity with mxfp4_deepseek; unused by cutlass_fused_moe. + self.flashinfer_mxfp4_moe_precision = ( + get_global_server_args().flashinfer_mxfp4_moe_precision + ) + + def create_moe_runner(self, layer, moe_runner_config): + self.moe_runner_config = moe_runner_config + + # Sanity check: v5 (260415) ckpt's HF config has swiglu_limit=10.0; + # v4 (260409) does not. Same check as mxfp4_deepseek. 
+ swiglu_limit = moe_runner_config.swiglu_limit + is_260415 = envs.SGLANG_DSV4_2604_SUBMODE.get() == "260415" + assert is_260415 == (swiglu_limit is not None), ( + f"swiglu_limit must be non-None iff submode=260415 " + f"(got submode={envs.SGLANG_DSV4_2604_SUBMODE.get()!r}, " + f"swiglu_limit={swiglu_limit!r})" + ) + self._swiglu_limit_tensor = ( + torch.full( + (layer.num_local_experts,), + swiglu_limit, + dtype=torch.float32, + device=layer.w13_weight.device, + ) + if swiglu_limit is not None + else None + ) + + def create_weights( + self, + layer, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype, + **extra_weight_attrs, + ): + """Create FP4-packed weights with TP-aware shapes. + + Shapes and dtypes are identical to mxfp4_deepseek (same checkpoint); + the only difference is the post-load layout produced by + process_weights_after_loading. + """ + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + fp4_block_k = 32 + + # FP4 packed weights: 2 values per byte -> physical K = logical K / 2 + w13_weight = Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + w2_weight = Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // 2, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Block scales: one float32 scale per fp4_block_k FP4 elements along K + w13_weight_scale = Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = Parameter( + torch.ones( + num_experts, + hidden_size, + intermediate_size_per_partition // 
fp4_block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w13_weight_scale.format_ue8m0 = False + w2_weight_scale.format_ue8m0 = False + scale_attrs = dict(extra_weight_attrs) + scale_attrs["quant_method"] = FusedMoeWeightScaleSupported.BLOCK.value + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + set_weight_attrs(w13_weight_scale, scale_attrs) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + set_weight_attrs(w2_weight_scale, scale_attrs) + + def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.quantization.utils import reorder_w1w3_to_w3w1 + + # Let Fp8MoEMethod do its processing first (FP4 view conversion, etc.). + # When SGLANG_OPT_USE_DEEPGEMM_MEGA_MOE=1 is set, that path builds the + # mega-MoE weight tuples on the layer; we must then skip the + # reorder/interleave below since mega wants the checkpoint's + # [gate(w1), up(w3)] row order. + self._fp8.process_weights_after_loading(layer) + + if getattr(layer, "_mega_moe_weights_built", False): + return + + # Reorder w13 from checkpoint [w1(gate), w3(up)] to kernel [w3(up), w1(gate)]. + # flashinfer's SM90 W4A16 test (`test_moe_bf16_mxfp4`) computes its + # reference as `w3, w1 = torch.chunk(w31, 2, dim=0)` — i.e. w3 (up) is + # the first half along dim -2, w1 (gate) is the second. Same row order + # as the B200 TRT-LLM routed kernel. 
+ w13_w, w13_s = reorder_w1w3_to_w3w1( + layer.w13_weight.data, layer.w13_weight_scale_inv.data + ) + layer.w13_weight = Parameter(w13_w, requires_grad=False) + layer.w13_weight_scale_inv = Parameter(w13_s, requires_grad=False) + + log_info_on_rank0( + logger, + f"Interleaving FP4 expert weights/scales for SM90 W4A16 kernel " + f"(layer: {self.prefix})...", + ) + + w13 = layer.w13_weight.data + w2 = layer.w2_weight.data + w13_scale = layer.w13_weight_scale_inv.data + w2_scale = layer.w2_weight_scale_inv.data + + # Convert float32 block scales to UE8M0 (1 byte per element) before the + # layout interleave. The SM90 kernel reads E8M0 uint8 bytes. + if w13_scale.dtype == torch.float32: + w13_scale = _fp32_to_ue8m0(w13_scale) + w2_scale = _fp32_to_ue8m0(w2_scale) + + # bf16-weight debug path: dequant FP4+UE8M0 → bf16 once. Two downstream + # consumers share this: + # - BF16_API: apply() calls plain bf16 cutlass_fused_moe (skips the + # SM90 mixed-input kernel and the SM90 interleave). + # - TORCH_REF: apply() calls a pure-torch MoE forward (skips the + # flashinfer bf16 grouped GEMM too). + # Both are independent numerical references for W4A16 acc drops. 
+ use_bf16_api = envs.SGLANG_HACK_DEBUG_W4A16_USE_BF16_API.get() + use_torch_ref = envs.SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF.get() + assert not (use_bf16_api and use_torch_ref), ( + "SGLANG_HACK_DEBUG_W4A16_USE_BF16_API and " + "SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF are mutually exclusive" + ) + if use_bf16_api or use_torch_ref: + consumer = "bf16-API" if use_bf16_api else "torch-ref" + log_info_on_rank0( + logger, + f"Dequant FP4 → bf16 for {consumer} path (layer: {self.prefix})...", + ) + w13_bf16 = _dequant_mxfp4( + w13.contiguous().view(torch.uint8), + w13_scale.contiguous().view(torch.uint8), + ) + w2_bf16 = _dequant_mxfp4( + w2.contiguous().view(torch.uint8), + w2_scale.contiguous().view(torch.uint8), + ) + layer.w13_weight = Parameter(w13_bf16, requires_grad=False) + layer.w2_weight = Parameter(w2_bf16, requires_grad=False) + # Drop scale parameters — bf16 path does not read them. Replace + # with zero-size placeholders to keep any attribute-existence + # checks happy. + layer.w13_weight_scale_inv = Parameter( + torch.empty(0, device=w13_bf16.device), requires_grad=False + ) + layer.w2_weight_scale_inv = Parameter( + torch.empty(0, device=w2_bf16.device), requires_grad=False + ) + torch.cuda.empty_cache() + return + + # Pre-interleave MXFP4 weights and scales (runs once at load time). + # Shapes after interleave: + # weights: same as input (byte-permutation only). + # scales: [E, rows, K/32] -> [E, K/(32*4), rows*4] uint8. 
+ w13_u8 = w13.contiguous().view(torch.uint8) + w2_u8 = w2.contiguous().view(torch.uint8) + w13_scale_u8 = w13_scale.contiguous().view(torch.uint8) + w2_scale_u8 = w2_scale.contiguous().view(torch.uint8) + + w13_il = interleave_moe_weights_for_sm90_mixed_gemm(w13_u8, "fp4") + w2_il = interleave_moe_weights_for_sm90_mixed_gemm(w2_u8, "fp4") + w13_scale_il = interleave_moe_scales_for_sm90_mixed_gemm( + w13_scale_u8, group_size=32 + ) + w2_scale_il = interleave_moe_scales_for_sm90_mixed_gemm( + w2_scale_u8, group_size=32 + ) + + layer.w13_weight = Parameter(w13_il, requires_grad=False) + layer.w2_weight = Parameter(w2_il, requires_grad=False) + # Keep interleaved scales as uint8 — .view(torch.int32) at apply-time. + layer.w13_weight_scale_inv = Parameter(w13_scale_il, requires_grad=False) + layer.w2_weight_scale_inv = Parameter(w2_scale_il, requires_grad=False) + + torch.cuda.empty_cache() + + def apply( + self, + layer: Module, + dispatch_output: DispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + from sglang.srt.layers.moe.topk import TopKOutputChecker + + hidden_states = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + # --- Step A: Prepare weights and sizes --- + w13 = layer.w13_weight + w2 = layer.w2_weight + use_bf16_api = envs.SGLANG_HACK_DEBUG_W4A16_USE_BF16_API.get() + use_torch_ref = envs.SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF.get() + if use_bf16_api or use_torch_ref: + # bf16 weights path: weights already dequanted to bf16 in + # process_weights_after_loading; no scale tensors to pass. 
+ quant_scales_arg = None + else: + quant_scales_arg = [ + layer.w13_weight_scale_inv.view(torch.int32), + layer.w2_weight_scale_inv.view(torch.int32), + ] + + # w13/w2 are pre-interleaved uint8 (W4A16) or plain bf16 (bf16-API); + # logical shapes come from the layer-configured sizes rather than + # tensor dims (interleave preserves numel but the 3D view no longer + # maps 1:1 to [E, 2*I, H/2]). + hidden_size = layer.hidden_size + + # --- Step B: Determine routing --- + if TopKOutputChecker.format_is_standard(topk_output): + topk_ids = topk_output.topk_ids + topk_weights = topk_output.topk_weights + else: + raise ValueError( + f"Unsupported topk output format for W4A16 MoE: {topk_output.format}" + ) + + # Undo StandardDispatcher's global->local+sentinel mapping so the + # flashinfer kernel (which expects global expert ids plus ep_rank/ep_size + # for local filtering) gets what it wants. Mirror the mxfp4_deepseek + # logic gated on SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING. + if not envs.SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING.get(): + local_expert_offset = layer.moe_ep_rank * layer.num_local_experts + topk_ids = torch.where( + topk_ids >= 0, + topk_ids + local_expert_offset, + topk_ids, + ) + + # --- Step C: Activations --- + # W4A16 path: bf16 activations, no quantization needed. 
+ assert hidden_states.dtype == torch.bfloat16, ( + f"W4A16 expects bf16 activations, got {hidden_states.dtype}" + ) + x = hidden_states + origin_dim = x.shape[-1] + if hidden_size != origin_dim: + x = torch.nn.functional.pad( + x, (0, hidden_size - origin_dim), mode="constant", value=0.0 + ) + + # --- Step D: Allocate output with symmetric memory for TP all-reduce --- + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = x.shape[0] + symm_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=x.device + ) + + # --- Step E: Call kernel --- + # DSv4 260415 ships a per-MoE-layer sanity counter that deepseek_v4.py + # asserts is bumped exactly once per forward (see deepseek_v4.py:2014). + # Mirror the mxfp4_deepseek bump so the checker is satisfied. + if envs.SGLANG_DSV4_2604_SUBMODE.get() == "260415" and ( + self._swiglu_limit_tensor is not None + ): + sunrise_moe_code_path_checker.observed += 1 + + swiglu_limit_arg = ( + None + if envs.SGLANG_HACK_DEBUG_W4A16_REMOVE_SWIGLU_LIMIT.get() + else self._swiglu_limit_tensor + ) + + _moe_fn = cutlass_fused_moe + if use_torch_ref: + from sglang.srt.debug_utils.w4a16_moe_ref_related import ( + torch_ref_cutlass_fused_moe as _moe_fn, + ) + + _moe_fn( + input=x, + token_selected_experts=topk_ids.to(torch.int32).contiguous(), + token_final_scales=topk_weights.to(torch.float32).contiguous(), + fc1_expert_weights=w13, + fc2_expert_weights=w2, + output_dtype=torch.bfloat16, + quant_scales=quant_scales_arg, + swiglu_limit=swiglu_limit_arg, + ep_size=layer.moe_ep_size, + ep_rank=layer.moe_ep_rank, + tp_size=1, + tp_rank=0, + use_w4_group_scaling=not use_bf16_api, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=symm_output, + ) + output = symm_output + + # Apply routed_scaling_factor (DSv4 = 1.5). cutlass_fused_moe has no + # routed_scaling_factor parameter, so unless we hand this to the fused + # shared-add fast path, we multiply post-hoc. 
See mxfp4_deepseek for + # the same rationale. + if not envs.SGLANG_OPT_MXFP4_FUSE_RSF_SHARED_ADD.get(): + rsf = layer.moe_runner_config.routed_scaling_factor + if rsf is not None and rsf != 1.0: + output.mul_(rsf) + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 28a3c92e1b55..77b833568077 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -179,6 +179,7 @@ "flashinfer_trtllm", "flashinfer_cutlass", "flashinfer_mxfp4", + "flashinfer_w4a16", "flashinfer_cutedsl", "cutlass", ] diff --git a/sunrise/aime25_q6/__init__.py b/sunrise/aime25_q6/__init__.py new file mode 100644 index 000000000000..b8d7ce3fb16e --- /dev/null +++ b/sunrise/aime25_q6/__init__.py @@ -0,0 +1,3 @@ +DATASET_GROUP = "math" +METRICS_TYPE = "math" +GENERATION_ARGS = "++prompt_config=generic/math ++eval_type=math" diff --git a/sunrise/aime25_q6/generate.sh b/sunrise/aime25_q6/generate.sh new file mode 100755 index 000000000000..5bf2a84a2fdd --- /dev/null +++ b/sunrise/aime25_q6/generate.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# Regenerate sunrise/aime25_q6/test.jsonl from the canonical nemo_skills aime25 dataset. +# +# Run from inside an rcli container that has the nemo_skills venv set up (see +# journal 2026-04-21-024 §Step 6a), or adjust --source-jsonl to point at any +# local copy of nemo_skills/dataset/aime25/test.jsonl. +# +# Must be run after `ns prepare_data aime25` so the source jsonl exists. +set -euo pipefail +HERE="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "$HERE/../.." 
&& pwd)" + +SOURCE="${SOURCE:-/workspace/nemo_skills-venv/lib/python3.12/site-packages/nemo_skills/dataset/aime25/test.jsonl}" + +cd "$REPO_ROOT" +python3 sunrise/filter_nemo_skills_questions.py \ + --question-ids aime25-6 \ + --source-jsonl "$SOURCE" \ + --output-dir sunrise +echo "Wrote $HERE/test.jsonl" diff --git a/sunrise/aime25_q6/test.jsonl b/sunrise/aime25_q6/test.jsonl new file mode 100644 index 000000000000..f81781ebc7c1 --- /dev/null +++ b/sunrise/aime25_q6/test.jsonl @@ -0,0 +1 @@ +{"id": "aime25-6", "problem": "The twelve letters $A$ , $B$ , $C$ , $D$ , $E$ , $F$ , $G$ , $H$ , $I$ , $J$ , $K$ , and $L$ are randomly grouped into six pairs of letters. The two letters in each pair are placed next to each other in alphabetical order to form six two-letter words, and then those six words are listed alphabetically. For example, a possible result is $AB$ , $CJ$ , $DG$ , $EK$ , $FL$ , $HI$ . The probability that the last word listed contains $G$ is $\\frac mn$ , where $m$ and $n$ are relatively prime positive integers. Find $m+n$ .", "expected_answer": "821", "reference_solution": "Splitting up into $2$ cases: $G$ is the first letter or the second letter of the last word. Case $1:$ $G$ in first letter Notice that $A$ must take the first letter of first word, one of the letters $B$ - $F$ needs to be the second letter of a word and the rest being the first letter of a word. \nThe combinations will be $1 + 2 + 3 + 4 + 5 = 15.$ After the first $7$ letters has been decided then the last $5$ will just fill by $5!.$ This case will have $15 \\cdot 5!$ outcomes.
\nCase $2:$ $G$ in last letter Notice that $A$ - $G$ has been arranged by $A? B? C? D? E? FG,$ where the $?$ is undecided. We have another $5!$ to fill out the possible outcomes. In total, there are $16 \\cdot 5!.$ The total case will be $11 \\cdot 9 \\cdot 7 \\cdot 5 \\cdot 3 \\cdot 1$ (Consider A must be in the first letter of first word, then you have $11$ choices, then you must take the next letter in alphabetical order as mandatory, then you have a free choice of $9$ and so on). Answer:\n \\[= \\frac{16 \\cdot 5 \\cdot 4 \\cdot 3 \\cdot 2 \\cdot 1}{ 11 \\cdot 9 \\cdot 7 \\cdot 5 \\cdot 3 \\cdot 1}\\] \n \\[= \\frac{16 \\cdot 4 \\cdot 2}{11 \\cdot 9 \\cdot 7}\\] \n \\[= \\frac{128}{ 693}\\] \nTherefore it gives us the answer of ${128 + 693 = \\boxed{821}.}$ ~Mitsuihisashi14"} diff --git a/sunrise/verify_dequant_mxfp4.py b/sunrise/verify_dequant_mxfp4.py new file mode 100644 index 000000000000..2e4f410a7346 --- /dev/null +++ b/sunrise/verify_dequant_mxfp4.py @@ -0,0 +1,173 @@ +"""Bitwise equivalence check between flashinfer's _dequant_mxfp4_on_device and +sglang's locally-copied _dequant_mxfp4 (in w4a16_deepseek.py). + +Neither file is imported as a module (the flashinfer test file has heavy deps, +the sglang file imports CUDA-specific bits). We textually extract the two +symbols from each source file and exec them into isolated namespaces. 
# Paths are resolved relative to this script so it can be run from any cwd.
_SCRIPT = Path(__file__).resolve()
_SGLANG_REPO_ROOT = _SCRIPT.parent.parent  # sunrise/ -> repo root
OUR_FILE = (
    _SGLANG_REPO_ROOT / "python/sglang/srt/layers/quantization/w4a16_deepseek.py"
)
# Default assumes the standard NDA workspace layout where flashinfer-sunrise
# sits as a sibling of the sglang worktrees directory
# (ws_nda/flashinfer-sunrise and ws_nda/worktrees/).
_DEFAULT_REF = (
    _SGLANG_REPO_ROOT.parent.parent
    / "flashinfer-sunrise/tests/moe/test_trtllm_cutlass_fused_moe.py"
)
REF_FILE = Path(os.environ.get("FLASHINFER_SUNRISE_TEST_FILE", str(_DEFAULT_REF)))


def _extract_symbols(path: Path, lut_name: str, fn_name: str) -> dict:
    """Load one module-level assignment and one function from `path`.

    The file is parsed with ``ast`` and only the two requested top-level
    statements are compiled and executed, so the rest of the file (including
    its imports) never runs.

    Args:
        path: Python source file to read.
        lut_name: Name bound by the top-level ``Assign`` to extract.
        fn_name: Name of the top-level ``FunctionDef`` to extract.

    Returns:
        The namespace the two statements were executed in. It is pre-seeded
        with ``torch`` and contains ``lut_name`` and ``fn_name`` afterwards.

    Raises:
        RuntimeError: If `path` does not contain exactly one matching
            assignment and exactly one matching function definition.
    """
    source = path.read_text(encoding="utf-8")
    tree = ast.parse(source, filename=str(path))

    # Count the two kinds separately: a combined "== 2" check would accept
    # two assignments and no function, and only fail later with a KeyError.
    lut_nodes: list[ast.stmt] = []
    fn_nodes: list[ast.stmt] = []
    for node in tree.body:
        if isinstance(node, ast.Assign):
            if any(
                isinstance(target, ast.Name) and target.id == lut_name
                for target in node.targets
            ):
                lut_nodes.append(node)
        elif isinstance(node, ast.FunctionDef) and node.name == fn_name:
            fn_nodes.append(node)

    if len(lut_nodes) != 1 or len(fn_nodes) != 1:
        raise RuntimeError(
            f"Expected to find 1 Assign for {lut_name!r} and 1 FunctionDef for {fn_name!r} in {path}; "
            f"got {len(lut_nodes)} Assign and {len(fn_nodes)} FunctionDef node(s)."
        )

    # Keep the statements in source order, as they appeared in the file.
    wanted_nodes = sorted(lut_nodes + fn_nodes, key=lambda n: n.lineno)
    module = ast.Module(body=wanted_nodes, type_ignores=[])
    code = compile(module, filename=str(path), mode="exec")
    ns: dict = {"torch": torch}
    # NOTE: exec runs code taken from `path`; only point this at trusted,
    # local source files.
    exec(code, ns)
    return ns


def main() -> int:
    """Compare the reference and local MXFP4 dequant functions bit-for-bit.

    Both functions and their ``_MXFP4_LUT`` tables are extracted from
    ``REF_FILE`` and ``OUR_FILE``, then run on random uint8 inputs over
    several shapes, followed by one case that includes scale byte 255.

    Returns:
        0 if every comparison agreed, 1 if any comparison disagreed, 2 if
        either source file is missing.

    Raises:
        AssertionError: If the two lookup tables differ, or if one of the
            harness's own sanity checks fails (K not a multiple of 32, or a
            NaN in the capped-scale loop).
        RuntimeError: Propagated from ``_extract_symbols``.
    """
    print(f"REF: {REF_FILE}")
    print(f"OUR: {OUR_FILE}")
    if not REF_FILE.exists():
        print(f"ERROR: reference file not found: {REF_FILE}", file=sys.stderr)
        return 2
    if not OUR_FILE.exists():
        print(f"ERROR: local file not found: {OUR_FILE}", file=sys.stderr)
        return 2

    ref_ns = _extract_symbols(REF_FILE, "_MXFP4_LUT", "_dequant_mxfp4_on_device")
    our_ns = _extract_symbols(OUR_FILE, "_MXFP4_LUT", "_dequant_mxfp4")

    ref_lut = ref_ns["_MXFP4_LUT"]
    our_lut = our_ns["_MXFP4_LUT"]
    # Explicit raise instead of `assert`: the check must survive `python -O`.
    if not (ref_lut == our_lut):
        raise AssertionError(f"LUT mismatch: ref={ref_lut} our={our_lut}")
    print(f"LUT equal: {ref_lut == our_lut} (len={len(ref_lut)})", flush=True)

    ref_fn: Callable = ref_ns["_dequant_mxfp4_on_device"]
    our_fn: Callable = our_ns["_dequant_mxfp4"]

    torch.manual_seed(0)

    # Shapes: last dim = K/2 (so K = 2*last_dim); K must be divisible by 32 for
    # the UE8M0 scale derivation. The (4, 1024, 2048) shape is a DSv4-realistic
    # slice (K=4096 = DSv4 hidden dim); full (256, 4096, 2048) OOMs a laptop
    # during fp32 intermediate allocation.
    shapes = [
        (2, 2, 16),  # minimal: K=32
        (4, 8, 32),  # K=64
        (8, 256, 64),  # K=128
        (256, 64, 256),  # full e=256, K=512
        (4, 1024, 2048),  # DSv4-shaped slice: K=4096
    ]

    all_ok = True
    mismatches: list[tuple] = []
    for shape in shapes:
        K_half = shape[-1]
        K = 2 * K_half
        if K % 32 != 0:
            raise AssertionError(f"K={K} not divisible by 32 for shape {shape}")
        scale_shape = shape[:-1] + (K // 32,)

        w_fp4 = torch.randint(0, 256, shape, dtype=torch.uint8)
        # Scale UE8M0 byte 255 produces exp2(128)=inf, and inf*0 (FP4 zero
        # nibbles) → NaN, which torch.equal treats as unequal-to-itself. Cap
        # to <255 in this loop; NaN-position agreement is verified separately
        # below.
        w_scale = torch.randint(0, 255, scale_shape, dtype=torch.uint8)

        ref_out = ref_fn(w_fp4, w_scale)
        our_out = our_fn(w_fp4, w_scale_ue8m0_u8=w_scale)

        ok = torch.equal(ref_out, our_out)
        if ref_out.isnan().any():
            raise AssertionError("ref_out has NaN — scale cap failed")
        if our_out.isnan().any():
            raise AssertionError("our_out has NaN — scale cap failed")
        if not ok:
            if ref_out.shape == our_out.shape:
                diff = (ref_out.float() - our_out.float()).abs()
                num_mismatch = int((ref_out != our_out).sum().item())
                max_abs_diff = float(diff.max().item())
            else:
                # An element-wise diff is undefined for different shapes;
                # skip it so the shape diagnostic below still gets printed.
                num_mismatch = None
                max_abs_diff = None
            mismatches.append((shape, num_mismatch, max_abs_diff))
            all_ok = False
            print(
                f"MISMATCH shape={shape} numel={ref_out.numel()} "
                f"num_mismatch={num_mismatch} max_abs_diff={max_abs_diff} "
                f"ref_dtype={ref_out.dtype} our_dtype={our_out.dtype} "
                f"ref_shape={tuple(ref_out.shape)} our_shape={tuple(our_out.shape)}"
            )
        else:
            print(
                f"OK shape={shape} numel={ref_out.numel()} "
                f"dtype={ref_out.dtype} out_shape={tuple(ref_out.shape)}",
                flush=True,
            )
        # Drop the large tensors before the next shape is allocated.
        del ref_out, our_out, w_fp4, w_scale

    # NaN-edge: scale=255 path (inf*0 → NaN). Both fns should produce NaN at
    # identical positions and identical finite values elsewhere.
    print("--- NaN-edge shape (scale includes 255) ---", flush=True)
    shape_nan = (4, 8, 32)
    scale_shape_nan = shape_nan[:-1] + (shape_nan[-1] * 2 // 32,)
    w_fp4 = torch.randint(0, 256, shape_nan, dtype=torch.uint8)
    w_scale = torch.randint(0, 256, scale_shape_nan, dtype=torch.uint8)
    # Only 64 scale bytes are drawn here, so roughly three draws in four
    # contain no 255 at all. Force the first scale of every row to 255 so the
    # edge case is always exercised.
    w_scale[..., 0] = 255
    ref_out = ref_fn(w_fp4, w_scale)
    our_out = our_fn(w_fp4, w_scale_ue8m0_u8=w_scale)
    if ref_out.shape == our_out.shape:
        ref_nan = ref_out.isnan()
        our_nan = our_out.isnan()
        eq_or_both_nan = (ref_out == our_out) | (ref_nan & our_nan)
        nan_ok = bool(eq_or_both_nan.all().item()) and torch.equal(ref_nan, our_nan)
    else:
        # Different shapes cannot agree; avoid the element-wise comparison.
        nan_ok = False
    print(
        f"NaN-edge agree (incl NaN positions): {nan_ok} "
        f"(ref_nans={int(ref_out.isnan().sum())}, our_nans={int(our_out.isnan().sum())})",
        flush=True,
    )
    if not nan_ok:
        all_ok = False

    if all_ok:
        print("ALL SHAPES BITWISE EQUAL")
        return 0
    print(f"FAIL: {len(mismatches)} shape(s) mismatched: {mismatches}")
    if not nan_ok:
        # The NaN-edge case is not part of `mismatches`; report it on its own
        # so a NaN-only failure is not summarized as "0 shape(s) mismatched".
        print("FAIL: NaN-edge check disagreed")
    return 1


if __name__ == "__main__":
    sys.exit(main())
# (batch_size, hidden_size, num_experts, top_k, intermediate_size, swiglu_limit)
# Shapes borrowed from flashinfer's W4A16_CORRECTNESS_CONFIGS so we stay inside
# the kernel's supported envelope. Swiglu limit is the DSv4 260415 value.
SHAPES = [
    (4, 128, 4, 2, 128, 10.0),
    (4, 768, 8, 2, 512, 10.0),
    (4, 2048, 8, 4, 1024, 10.0),
    (4, 4096, 8, 4, 2048, 10.0),
]


def _compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pick the top-k experts per row and renormalize their weights.

    Softmax is taken in float32 over the last dim, the `top_k` largest
    probabilities are kept and rescaled to sum to 1 per row.

    Returns:
        ``(topk_weights, topk_ids)``: weights cast back to the dtype of
        `router_logits`, ids as int64.
    """
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights.to(router_logits.dtype), topk_ids.to(torch.int64)


def _check_close(
    ref_output: torch.Tensor, flash_output: torch.Tensor, label: str
) -> None:
    """Print closeness stats for one config and fail if under 99% are close.

    An element counts as close when ``|ref - flash| <= 0.1 + 0.1 * |ref|``
    (computed in float32).

    Raises:
        AssertionError: If fewer than 99% of the elements are close. Raised
            explicitly rather than via ``assert`` so the gate is not stripped
            under ``python -O``.
    """
    diff = (ref_output.float() - flash_output.float()).abs()
    tol = 0.1 + 1e-1 * ref_output.float().abs()
    close_pct = (diff <= tol).float().mean().item()
    max_abs = diff.max().item()
    print(f"{label} close%={close_pct:.4%} max_abs={max_abs:.4f}")
    if not (close_pct >= 0.99):
        raise AssertionError(
            f"torch-ref vs kernel mismatch: only {close_pct:.4%} within tol; "
            f"max_abs={max_abs:.4f}"
        )


def _run_one(
    batch_size: int,
    hidden_size: int,
    num_experts: int,
    top_k: int,
    intermediate_size: int,
    swiglu_limit_val: float,
) -> None:
    """Run the flashinfer kernel and the torch ref on one random config.

    Random packed-FP4 weights and UE8M0 scales are generated on CUDA with a
    fixed seed, fed to ``cutlass_fused_moe`` (interleaved layout) and to
    ``torch_ref_cutlass_fused_moe`` (dequantized via ``_dequant_mxfp4``), and
    the two outputs are compared by ``_check_close``.

    Raises:
        ValueError: If `hidden_size` or `intermediate_size` is not a multiple
            of 32.
        AssertionError: If the outputs are not close enough.
    """
    m, k, e, n = batch_size, hidden_size, num_experts, intermediate_size
    # Scales are allocated with k // 32 and n // 32 columns; a non-multiple
    # would be silently truncated and the tensors would not line up.
    if k % 32 != 0 or n % 32 != 0:
        raise ValueError(
            f"hidden_size={k} and intermediate_size={n} must be multiples of 32"
        )

    torch.manual_seed(42)
    device = torch.device("cuda")

    x = torch.randn(m, k, dtype=torch.bfloat16, device=device)

    w13_fp4 = torch.randint(
        0, 256, (e, 2 * n, k // 2), device=device, dtype=torch.uint8
    )
    w2_fp4 = torch.randint(
        0, 256, (e, k, n // 2), device=device, dtype=torch.uint8
    )
    w13_scale = torch.randint(
        118, 123, (e, 2 * n, k // 32), device=device, dtype=torch.uint8
    )
    w2_scale = torch.randint(
        118, 123, (e, k, n // 32), device=device, dtype=torch.uint8
    )

    router_logits = torch.randn(m, e, dtype=torch.bfloat16, device=device)
    topk_weights, topk_ids = _compute_routing(router_logits, top_k)

    swiglu_limit_tensor = torch.full(
        (e,), swiglu_limit_val, dtype=torch.float32, device=device
    )

    # --- Flashinfer kernel path ---
    w13_il = interleave_moe_weights_for_sm90_mixed_gemm(w13_fp4, "fp4")
    w2_il = interleave_moe_weights_for_sm90_mixed_gemm(w2_fp4, "fp4")
    w13_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(w13_scale, group_size=32)
    w2_scale_il = interleave_moe_scales_for_sm90_mixed_gemm(w2_scale, group_size=32)

    flash_output = torch.zeros_like(x)
    cutlass_fused_moe(
        input=x,
        token_selected_experts=topk_ids.to(torch.int32).contiguous(),
        token_final_scales=topk_weights.to(torch.float32).contiguous(),
        fc1_expert_weights=w13_il,
        fc2_expert_weights=w2_il,
        output_dtype=torch.bfloat16,
        quant_scales=[w13_scale_il.view(torch.int32), w2_scale_il.view(torch.int32)],
        swiglu_limit=swiglu_limit_tensor,
        ep_size=1,
        ep_rank=0,
        tp_size=1,
        tp_rank=0,
        use_w4_group_scaling=True,
        output=flash_output,
    )

    # --- Torch ref path ---
    # Kernel and ref see the same raw FP4 tensor; ref's chunk(dim=0) gives
    # (w3, w1) which matches flashinfer's own reference convention in
    # _run_w4a16_moe_hopper. No explicit reorder needed on either side.
    w13_bf16 = _dequant_mxfp4(w13_fp4, w13_scale)
    w2_bf16 = _dequant_mxfp4(w2_fp4, w2_scale)

    ref_output = torch.zeros_like(x)
    torch_ref_cutlass_fused_moe(
        input=x,
        token_selected_experts=topk_ids.to(torch.int32).contiguous(),
        token_final_scales=topk_weights.to(torch.float32).contiguous(),
        fc1_expert_weights=w13_bf16,
        fc2_expert_weights=w2_bf16,
        output_dtype=torch.bfloat16,
        swiglu_limit=swiglu_limit_tensor,
        ep_size=1,
        ep_rank=0,
        output=ref_output,
    )

    # Compare
    _check_close(
        ref_output,
        flash_output,
        f"m={m} k={k} e={e} top_k={top_k} n={n} limit={swiglu_limit_val}",
    )


def main() -> None:
    """Run every config in ``SHAPES``; the first failing config aborts the run."""
    for cfg in SHAPES:
        _run_one(*cfg)
    print("ALL PASS")


if __name__ == "__main__":
    main()