Port MXFP4 Marlin MoE support to JIT kernel path #24490

Merged
Fridge003 merged 6 commits into sgl-project:dsv4-rebase from yhyang201:dsv4-rebase_marlin on May 7, 2026

yhyang201 (Collaborator) commented May 6, 2026

Summary

  • Port the kernel-level and Python-level MXFP4 (E8M0) Marlin MoE changes from "Deepseek_v4 support w4(mxfp4)a16 on hopper" (#23686, AOT sgl-kernel) to the JIT kernel path on dsv4-rebase
  • Enables DeepSeek-V4 MXFP4 quantized inference via --moe-runner-backend marlin on Hopper GPUs

Changes

CUDA kernel (marlin_template.h)

  • Add is_8bit_scale generalization (replaces hardcoded kFE2M1f checks)
  • Optimize read_moe_block_data with cp_async4_pred + warp-level __reduce_add_sync
  • Fix scale stride/fetch logic for 8-bit scale types (E4M3, E8M0)
  • Enable dequant_fp8_scales with compile-time guards
  • Unify scale/zp pipe fetch with div_ceil pattern
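
The div_ceil pattern in the last bullet is ordinary ceiling division. A minimal Python sketch, assuming a hypothetical group size and tile depth (neither value is taken from the kernel), shows why it matters when a tile is smaller than one scale group:

```python
def div_ceil(a: int, b: int) -> int:
    """Ceiling division for non-negative a and positive b."""
    return (a + b - 1) // b


def scale_rows_per_tile(tile_k: int, group_size: int) -> int:
    """Number of scale rows an aligned K-tile touches (hypothetical helper).

    Floor division would return 0 when tile_k < group_size, so nothing
    would be fetched; ceiling division always fetches at least one row.
    """
    return div_ceil(tile_k, group_size)


rows_small_tile = scale_rows_per_tile(tile_k=16, group_size=32)
rows_large_tile = scale_rows_per_tile(tile_k=64, group_size=32)
```

With these illustrative numbers the small tile still fetches one scale row, where `16 // 32` would have fetched none.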

Python

  • fused_marlin_moe.py: MXFP4 detection (float8_e8m0fnu scales), swiglu_limit_func, clamp_limit parameter, conditional moe_sum_reduce vs torch.sum
  • marlin.py (runner): FP16→BF16 upcast for MXFP4, clamp_limit forwarding
  • marlin_utils_fp4.py (new): Weight repack + scale permutation for MXFP4 Marlin
  • mxfp4_marlin_moe.py (new): Mxfp4MarlinMoEMethod class (weight loading, runner creation, inference)
  • fp8.py: Marlin MXFP4 routing in get_quant_method() + early return in FP4 weight processing
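
For readers unfamiliar with the format, the sketch below decodes one MXFP4 block in plain Python, following the OCP Microscaling definition: FP4 E2M1 elements share one E8M0 scale, which encodes a power of two with bias 127. The function names are hypothetical, and this is not the code in marlin_utils_fp4.py, which additionally repacks weights and permutes scales for the Marlin layout.

```python
import math

# Magnitudes representable by FP4 E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 index the magnitude."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale: 2**(byte - 127); 0xFF is reserved for NaN."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)


def dequant_block(nibbles: list[int], scale_byte: int) -> list[float]:
    """Dequantize FP4 codes that share a single E8M0 scale."""
    scale = decode_e8m0(scale_byte)
    return [decode_e2m1(n) * scale for n in nibbles]


# scale_byte=128 decodes to a scale of 2.0
values = dequant_block([0x1, 0x7, 0x9, 0xF], scale_byte=128)
largest_scale = decode_e8m0(254)  # 2**127
```

The largest E8M0 scale is far above the FP16 maximum of 65504, which is one plausible reason for running this path in BF16; the PR itself does not state the motivation.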

Correctness Verification

AIME25 (DeepSeek-V4-Pro, MXFP4 Marlin, H200 × 8)

Server config: --tp 8 --moe-runner-backend marlin --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --mem-fraction-static 0.85

Eval config: sgl-eval run aime25 --n-repeats 16 --temperature 1.0 --top-p 1.0 --max-tokens 400000 --num-threads 8

| Metric | Score | Threshold |
| --- | --- | --- |
| pass@1 | 96.25% (462/480) | >= 95% |
| cons@16 | 100% (30/30) | |

All 30 questions were answered correctly by majority vote. The weakest were Q13 and Q14 (13/16 = 81%) and Q19 (12/16 = 75%); all three still pass comfortably.
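
The reported numbers can be cross-checked with a few lines of arithmetic. The sketch assumes 30 questions with 16 samples each (as implied by 462/480 and 30/30) and treats a question as passing cons@16 when more than half of its samples are correct. That condition is sufficient for the majority answer to be right, but it is a simplification of real majority voting.

```python
N_QUESTIONS = 30
N_REPEATS = 16

total_samples = N_QUESTIONS * N_REPEATS  # 480 samples overall
pass_at_1 = 462 / total_samples          # fraction of correct samples

# Correct-sample counts for the three weakest questions, taken from the text.
weakest = {"Q13": 13, "Q14": 13, "Q19": 12}


def majority_correct(n_correct: int, n_repeats: int = N_REPEATS) -> bool:
    """True when the correct samples form a strict majority."""
    return n_correct > n_repeats // 2


all_weakest_pass = all(majority_correct(c) for c in weakest.values())
```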

Test plan

  • Test with DeepSeek-V4-Pro MXFP4 checkpoint using --moe-runner-backend marlin + EAGLE speculative decoding
  • AIME25 pass@1 = 96.25% (threshold: >= 95%)
  • Verify no regression on existing GPTQ/AWQ Marlin models
  • Compare Marlin MXFP4 outputs with FlashInfer MXFP4 path

… kernel path

Kernel-level changes (marlin_template.h):
- Add is_8bit_scale generalization replacing hardcoded kFE2M1f checks
- Optimize read_moe_block_data with async cp_async4_pred + warp reduce
- Fix scale stride/fetch logic for 8-bit scale types
- Enable dequant_fp8_scales with E4M3/E8M0 compile guards
- Unify scale/zp pipe fetch with div_ceil pattern

Python-level changes:
- Add MXFP4 detection and fp16->bf16 upcast in Marlin runner
- Add swiglu_limit_func and clamp_limit support in fused_marlin_moe
- Add get_scalar_type support for float8_e8m0fnu -> float4_e2m1f
- Create marlin_utils_fp4.py for MXFP4 weight preparation
- Create mxfp4_marlin_moe.py with Mxfp4MarlinMoEMethod class
- Add Marlin MXFP4 routing in fp8.py get_quant_method
gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request introduces MXFP4 (E8M0 scales) support for Mixture of Experts (MoE) layers using the Marlin backend, specifically optimized for DeepSeek-V3/V4. The changes include significant updates to the Marlin CUDA template to support 8-bit scales, asynchronous data loading, and warp-level reductions, alongside new Python utilities for weight repacking and scale processing. Review feedback highlighted a high-severity issue where the routed_scaling_factor was bypassed in the MXFP4 reduction path, potentially leading to incorrect results. Additionally, improvements were suggested regarding backend-specific guard flags and the inclusion of expert_map in MarlinMoeQuantInfo to ensure compatibility with Expert Parallelism.

Comment on lines +262 to +263
    if is_mxfp4_marlin:
        return torch.sum(intermediate_cache3, dim=1, out=output)

high

The routed_scaling_factor is ignored when is_mxfp4_marlin is true. This will lead to incorrect results for models that rely on this scaling factor (like DeepSeek-V3/V4). You should use moe_sum_reduce which correctly handles the scaling factor and is more efficient than torch.sum.

    if is_mxfp4_marlin:
        if routed_scaling_factor is None:
            routed_scaling_factor = 1.0
        moe_sum_reduce(
            intermediate_cache3,
            output,
            routed_scaling_factor,
        )
        return output
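
To make the discrepancy concrete, here is a plain-Python reference for the two reductions. It assumes, as the comment implies, that moe_sum_reduce sums over the top-k axis and multiplies by routed_scaling_factor. The helper names are hypothetical, and the nested lists stand in for a [tokens, top_k, hidden] tensor.

```python
def plain_sum(cache3):
    """Sum expert outputs over the top-k axis (what torch.sum(dim=1) computes)."""
    return [[sum(col) for col in zip(*token)] for token in cache3]


def scaled_sum(cache3, routed_scaling_factor=1.0):
    """Same reduction, then multiply by the routed scaling factor."""
    return [[routed_scaling_factor * v for v in row] for row in plain_sum(cache3)]


# One token, top_k = 2, hidden = 3.
cache3 = [[[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]]
unscaled = plain_sum(cache3)
scaled = scaled_sum(cache3, routed_scaling_factor=2.5)
```

With a factor of 2.5, the unscaled path returns values 2.5 times too small, which matches the degraded output described later in this thread.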

Comment on lines +63 to +64
    if getattr(layer, "_mega_moe_weights_built", False):
        return

medium

The guard _mega_moe_weights_built is specific to the DeepGEMM backend. For Marlin, you should use a backend-specific flag or set this flag after processing to avoid redundant repacking if process_weights_after_loading is called multiple times.

Suggested change
-   if getattr(layer, "_mega_moe_weights_built", False):
-       return
+   if getattr(layer, "_dsv4_mxfp4_backend", None) == "marlin":
+       return
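
The underlying concern is idempotence: the repack must run once even if the hook is called again. A minimal sketch of a backend-specific guard, using a stand-in layer object and a hypothetical flag name (_marlin_mxfp4_repacked is not an attribute from this PR):

```python
from types import SimpleNamespace


def process_weights_after_loading(layer) -> bool:
    """Repack once; return True if work was done, False if it was skipped."""
    if getattr(layer, "_marlin_mxfp4_repacked", False):  # hypothetical flag
        return False
    layer.repack_calls += 1  # stands in for the real repack + scale permutation
    layer._marlin_mxfp4_repacked = True  # set only after processing completes
    return True


layer = SimpleNamespace(repack_calls=0)
first = process_weights_after_loading(layer)
second = process_weights_after_loading(layer)
```

Setting the flag after the work, rather than before, means a failed repack is retried on the next call instead of being silently skipped.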

Comment on lines +102 to +111
    quant_info = MarlinMoeQuantInfo(
        w13_qweight=layer.w13_weight,
        w2_qweight=layer.w2_weight,
        w13_scales=layer.w13_weight_scale_inv,
        w2_scales=layer.w2_weight_scale_inv,
        w13_g_idx_sort_indices=None,
        w2_g_idx_sort_indices=None,
        weight_bits=4,
        is_k_full=True,
    )

medium

The expert_map is missing from MarlinMoeQuantInfo. This will cause issues when running with Expert Parallelism (EP), as the Marlin kernel relies on this map to identify local experts.

Suggested change
    quant_info = MarlinMoeQuantInfo(
        w13_qweight=layer.w13_weight,
        w2_qweight=layer.w2_weight,
        w13_scales=layer.w13_weight_scale_inv,
        w2_scales=layer.w2_weight_scale_inv,
        w13_g_idx_sort_indices=None,
        w2_g_idx_sort_indices=None,
        weight_bits=4,
        is_k_full=True,
+       expert_map=getattr(layer, "expert_map", None),
    )
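
As an illustration of what such a map encodes, the sketch below builds a global-to-local expert index table for one EP rank, with -1 marking experts hosted elsewhere. The contiguous partitioning and the helper name are assumptions for illustration; the layout the Marlin kernel actually expects is not shown in this PR.

```python
def build_expert_map(num_experts: int, ep_size: int, ep_rank: int) -> list[int]:
    """Map global expert id -> local id on this rank, or -1 if not local.

    Assumes experts are split into equal contiguous chunks across ranks.
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return [
        g - start if start <= g < start + per_rank else -1
        for g in range(num_experts)
    ]


expert_map = build_expert_map(num_experts=8, ep_size=4, ep_rank=1)
```

Without a table like this, a kernel running on rank 1 has no way to tell that only global experts 2 and 3 live on this rank.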

yhyang201 added 5 commits May 6, 2026 07:42
Rename intermediate_size to intermediate_size_per_partition to match
the call site in layer.py, which passes this as a keyword argument.

The MXFP4 path used torch.sum without applying routed_scaling_factor,
while the non-MXFP4 path applied it via moe_sum_reduce. This caused
models with large scaling factors (e.g. Pro rsf=2.5) to produce
degraded output as MoE contributions were scaled down.

…SE_RSF_SHARED_ADD

Replaces the previous approach of applying routed_scaling_factor
directly in fused_marlin_moe.py. Now mirrors the Blackwell
Mxfp4FlashinferTrtllmMoEMethod pattern exactly:

- ENV ON (default): rsf applied in maybe_fuse_routed_scale_and_shared_add
- ENV OFF: rsf applied in Mxfp4MarlinMoEMethod.apply() directly
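
The invariant behind this commit is that routed_scaling_factor is applied exactly once, whichever side of the switch is active. A toy sketch with hypothetical function names and a boolean standing in for the environment variable:

```python
def moe_apply(routed_out: float, rsf: float, fuse_rsf: bool) -> float:
    """Stand-in for the MoE method: scale here only when fusion is off."""
    return routed_out if fuse_rsf else routed_out * rsf


def fuse_scale_and_shared_add(
    routed_out: float, shared_out: float, rsf: float, fuse_rsf: bool
) -> float:
    """Stand-in for the later step: scale here only when fusion is on."""
    scaled = routed_out * rsf if fuse_rsf else routed_out
    return scaled + shared_out


def forward(routed_out: float, shared_out: float, rsf: float, fuse_rsf: bool) -> float:
    """Run both stages; rsf must be applied in exactly one of them."""
    y = moe_apply(routed_out, rsf, fuse_rsf)
    return fuse_scale_and_shared_add(y, shared_out, rsf, fuse_rsf)


on = forward(2.0, 1.0, rsf=2.5, fuse_rsf=True)
off = forward(2.0, 1.0, rsf=2.5, fuse_rsf=False)
```

Both settings give the same result; applying the factor in neither place (the earlier bug) or in both would not.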

Verified: DeepSeek-V4-Pro TP8 GSM8K 98% (49/50).

Remove 7 SGLANG_MXFP4_MARLIN_* debug environment variables that were
used during development to test weight/scale transformation steps.
All code paths are now hardcoded to their validated defaults.
@Fridge003 Fridge003 merged commit 6efeee8 into sgl-project:dsv4-rebase May 7, 2026
1 check passed