Port MXFP4 Marlin MoE support to JIT kernel path #24490
… kernel path

Kernel-level changes (marlin_template.h):
- Add is_8bit_scale generalization replacing hardcoded kFE2M1f checks
- Optimize read_moe_block_data with async cp_async4_pred + warp reduce
- Fix scale stride/fetch logic for 8-bit scale types
- Enable dequant_fp8_scales with E4M3/E8M0 compile guards
- Unify scale/zp pipe fetch with div_ceil pattern

Python-level changes:
- Add MXFP4 detection and fp16->bf16 upcast in Marlin runner
- Add swiglu_limit_func and clamp_limit support in fused_marlin_moe
- Add get_scalar_type support for float8_e8m0fnu -> float4_e2m1f
- Create marlin_utils_fp4.py for MXFP4 weight preparation
- Create mxfp4_marlin_moe.py with Mxfp4MarlinMoEMethod class
- Add Marlin MXFP4 routing in fp8.py get_quant_method
Code Review
This pull request introduces MXFP4 (E8M0 scales) support for Mixture of Experts (MoE) layers using the Marlin backend, specifically optimized for DeepSeek-V3/V4. The changes include significant updates to the Marlin CUDA template to support 8-bit scales, asynchronous data loading, and warp-level reductions, alongside new Python utilities for weight repacking and scale processing. Review feedback highlighted a high-severity issue where the routed_scaling_factor was bypassed in the MXFP4 reduction path, potentially leading to incorrect results. Additionally, improvements were suggested regarding backend-specific guard flags and the inclusion of expert_map in MarlinMoeQuantInfo to ensure compatibility with Expert Parallelism.
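For readers less familiar with the format: MXFP4, as defined in the OCP Microscaling specification, stores 4-bit E2M1 values that share one 8-bit E8M0 power-of-two scale per block of 32 elements. The following is a minimal pure-Python sketch of dequantizing packed values with one scale. It is not the Marlin kernel's code path; the helper names and the low-nibble-first packing order are assumptions made for illustration.

```python
# Magnitudes representable by FP4 E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale: a pure power of two with bias 127; 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_block(packed: bytes, scale_byte: int) -> list:
    """Dequantize packed FP4 pairs that share one E8M0 scale."""
    scale = decode_e8m0(scale_byte)
    out = []
    for b in packed:
        out.append(decode_e2m1(b & 0xF) * scale)  # low nibble first (assumption)
        out.append(decode_e2m1(b >> 4) * scale)   # then the high nibble
    return out
```

Because the scale is a pure power of two, applying it only shifts the exponent of the dequantized value, which is why 8-bit scale types can share a generalized kernel path.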
if is_mxfp4_marlin:
    return torch.sum(intermediate_cache3, dim=1, out=output)
The routed_scaling_factor is ignored when is_mxfp4_marlin is true. This will lead to incorrect results for models that rely on this scaling factor (like DeepSeek-V3/V4). You should use moe_sum_reduce which correctly handles the scaling factor and is more efficient than torch.sum.
if is_mxfp4_marlin:
    if routed_scaling_factor is None:
        routed_scaling_factor = 1.0
    moe_sum_reduce(
        intermediate_cache3,
        output,
        routed_scaling_factor,
    )
    return output

if getattr(layer, "_mega_moe_weights_built", False):
    return
The guard _mega_moe_weights_built is specific to the DeepGEMM backend. For Marlin, you should use a backend-specific flag or set this flag after processing to avoid redundant repacking if process_weights_after_loading is called multiple times.
- if getattr(layer, "_mega_moe_weights_built", False):
-     return
+ if getattr(layer, "_dsv4_mxfp4_backend", None) == "marlin":
+     return
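The other option mentioned in the comment, setting a flag after processing, could look roughly like the sketch below. The flag name `_marlin_mxfp4_weights_built` and the `repack_fn` callback are hypothetical and do not appear in the PR; the point is only that the guard is checked before the repack and set after it succeeds.

```python
from types import SimpleNamespace

# Hypothetical Marlin-specific flag name, kept separate from the DeepGEMM one.
_MARLIN_BUILT_FLAG = "_marlin_mxfp4_weights_built"


def process_weights_once(layer, repack_fn) -> bool:
    """Run repack_fn(layer) at most once per layer; return True if it ran."""
    if getattr(layer, _MARLIN_BUILT_FLAG, False):
        return False  # already repacked, nothing to do
    repack_fn(layer)
    setattr(layer, _MARLIN_BUILT_FLAG, True)  # mark only after a successful repack
    return True


# Calling the hook twice repacks the stand-in layer only once.
layer = SimpleNamespace()
calls = []
first = process_weights_once(layer, calls.append)
second = process_weights_once(layer, calls.append)
```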
quant_info = MarlinMoeQuantInfo(
    w13_qweight=layer.w13_weight,
    w2_qweight=layer.w2_weight,
    w13_scales=layer.w13_weight_scale_inv,
    w2_scales=layer.w2_weight_scale_inv,
    w13_g_idx_sort_indices=None,
    w2_g_idx_sort_indices=None,
    weight_bits=4,
    is_k_full=True,
)
The expert_map is missing from MarlinMoeQuantInfo. This will cause issues when running with Expert Parallelism (EP), as the Marlin kernel relies on this map to identify local experts.
- quant_info = MarlinMoeQuantInfo(
-     w13_qweight=layer.w13_weight,
-     w2_qweight=layer.w2_weight,
-     w13_scales=layer.w13_weight_scale_inv,
-     w2_scales=layer.w2_weight_scale_inv,
-     w13_g_idx_sort_indices=None,
-     w2_g_idx_sort_indices=None,
-     weight_bits=4,
-     is_k_full=True,
- )
+ quant_info = MarlinMoeQuantInfo(
+     w13_qweight=layer.w13_weight,
+     w2_qweight=layer.w2_weight,
+     w13_scales=layer.w13_weight_scale_inv,
+     w2_scales=layer.w2_weight_scale_inv,
+     w13_g_idx_sort_indices=None,
+     w2_g_idx_sort_indices=None,
+     weight_bits=4,
+     is_k_full=True,
+     expert_map=getattr(layer, "expert_map", None),
+ )
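To illustrate why the map matters under EP: each rank holds only a slice of the experts, while routing produces global expert ids. The sketch below assumes the common convention that expert_map is indexed by global expert id and holds the local index, or -1 for experts owned by another rank. The helper names are hypothetical and the actual Marlin kernel interface may differ.

```python
def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> list:
    """Map global expert id -> local id on this rank, or -1 if not local.

    Assumes num_experts is divisible by ep_size and experts are split contiguously.
    """
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return [
        g - start if start <= g < start + per_rank else -1
        for g in range(num_experts)
    ]


def local_expert_ids(topk_ids, expert_map):
    """Translate routed global ids; with no map, ids are used unchanged."""
    if expert_map is None:
        return list(topk_ids)
    return [expert_map[g] for g in topk_ids]


# Rank 1 of 2 owns global experts 4..7 out of 8.
rank1_map = build_expert_map(num_experts=8, ep_size=2, ep_rank=1)
routed = local_expert_ids([1, 5, 7], rank1_map)
```

Without the map, a global id such as 5 would be used directly as an index into a weight tensor that holds only four local experts on this rank.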
Rename intermediate_size to intermediate_size_per_partition to match the call site in layer.py, which passes this as a keyword argument.
The MXFP4 path used torch.sum without applying routed_scaling_factor, while the non-MXFP4 path applied it via moe_sum_reduce. As a result, models with a large scaling factor (e.g. Pro, rsf=2.5) produced degraded output because their MoE contributions were left under-scaled.
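The size of the discrepancy can be seen in a small pure-Python stand-in for the reduction. This is only a sketch of the arithmetic, not the moe_sum_reduce kernel; the function name moe_reduce and its signature are assumptions.

```python
def moe_reduce(expert_outputs, routed_scaling_factor=None):
    """Sum the per-expert contributions for one token, then apply the scale."""
    rsf = 1.0 if routed_scaling_factor is None else routed_scaling_factor
    hidden = len(expert_outputs[0])
    return [rsf * sum(e[i] for e in expert_outputs) for i in range(hidden)]


# Two selected experts, hidden size 2.
top_k_outputs = [[1.0, 2.0], [3.0, 4.0]]
unscaled = moe_reduce(top_k_outputs)      # equivalent to a plain sum over experts
scaled = moe_reduce(top_k_outputs, 2.5)   # with rsf=2.5 applied
```

With rsf=2.5 the plain sum yields routed-expert output that is 2.5 times too small relative to the rest of the layer.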
…SE_RSF_SHARED_ADD

Replaces the previous approach of applying routed_scaling_factor directly in fused_marlin_moe.py. Now mirrors the Blackwell Mxfp4FlashinferTrtllmMoEMethod pattern exactly:
- ENV ON (default): rsf applied in maybe_fuse_routed_scale_and_shared_add
- ENV OFF: rsf applied in Mxfp4MarlinMoEMethod.apply() directly

Verified: DeepSeek-V4-Pro TP8 GSM8K 98% (49/50).
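The invariant behind the two modes is that rsf is applied exactly once, whichever site applies it. The sketch below models that with plain Python lists. The function names, the fuse_enabled parameter standing in for the environment flag, and the assumption that the shared-expert add happens in the later step in both modes are all illustrative rather than taken from the code.

```python
def method_apply(routed_out, rsf, fuse_enabled):
    """Stand-in for the MoE method's apply(): scales only when fusion is off."""
    if fuse_enabled:
        return list(routed_out)  # leave scaling to the fused step
    return [rsf * x for x in routed_out]


def fuse_scale_and_shared_add(routed_out, shared_out, rsf, fuse_enabled):
    """Stand-in for the later step: applies rsf only when fusion is on."""
    scale = rsf if fuse_enabled else 1.0
    return [scale * r + s for r, s in zip(routed_out, shared_out)]


def forward(routed_out, shared_out, rsf, fuse_enabled):
    routed = method_apply(routed_out, rsf, fuse_enabled)
    return fuse_scale_and_shared_add(routed, shared_out, rsf, fuse_enabled)


out_on = forward([1.0, 2.0], [0.5, 0.5], 2.5, fuse_enabled=True)
out_off = forward([1.0, 2.0], [0.5, 0.5], 2.5, fuse_enabled=False)
```

Both settings produce the same result; applying rsf at both sites, or at neither, is the failure mode to guard against.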
Remove 7 SGLANG_MXFP4_MARLIN_* debug environment variables that were used during development to test weight/scale transformation steps. All code paths are now hardcoded to their validated defaults.
Summary
dsv4-rebase
--moe-runner-backend marlin on Hopper GPUs

Changes
CUDA kernel (marlin_template.h)
- Add is_8bit_scale generalization (replaces hardcoded kFE2M1f checks)
- Optimize read_moe_block_data with cp_async4_pred + warp-level __reduce_add_sync
- Enable dequant_fp8_scales with compile-time guards
- Unify scale/zp pipe fetch with div_ceil pattern

Python
- fused_marlin_moe.py: MXFP4 detection (float8_e8m0fnu scales), swiglu_limit_func, clamp_limit parameter, conditional moe_sum_reduce vs torch.sum
- marlin.py (runner): FP16→BF16 upcast for MXFP4, clamp_limit forwarding
- marlin_utils_fp4.py (new): Weight repack + scale permutation for MXFP4 Marlin
- mxfp4_marlin_moe.py (new): Mxfp4MarlinMoEMethod class (weight loading, runner creation, inference)
- fp8.py: Marlin MXFP4 routing in get_quant_method() + early return in FP4 weight processing

Correctness Verification
AIME25 (DeepSeek-V4-Pro, MXFP4 Marlin, H200 × 8)
Server config:
--tp 8 --moe-runner-backend marlin --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --mem-fraction-static 0.85

Eval config:
sgl-eval run aime25 --n-repeats 16 --temperature 1.0 --top-p 1.0 --max-tokens 400000 --num-threads 8

All 30 questions answered correctly by majority vote. Weakest: Q13, Q14 (13/16 = 81%), Q19 (12/16 = 75%); all comfortably pass.
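For reference, the majority-vote criterion over 16 repeats can be sketched as follows. How sgl-eval actually aggregates samples or breaks ties is not stated here, so this assumes plain plurality voting over the final answer strings; the function names and sample data are made up for illustration.

```python
from collections import Counter


def majority_answer(samples):
    """Return the most frequent answer among the sampled completions."""
    return Counter(samples).most_common(1)[0][0]


def question_passes(samples, reference):
    """A question counts as correct if the plurality answer matches the reference."""
    return majority_answer(samples) == reference


# Hypothetical question with 12 of 16 samples correct, like the weakest case above.
samples = ["42"] * 12 + ["17"] * 3 + ["9"]
passed = question_passes(samples, "42")
per_sample_accuracy = samples.count("42") / len(samples)
```

Under this criterion a question with 75% per-sample accuracy still passes, since the correct answer remains the most common one.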
Test plan
--moe-runner-backend marlin + EAGLE speculative decoding