
refactor(moe): centralize post-experts all-reduce skip predicate #23748

Merged

Kangyan-Zhou merged 2 commits into sgl-project:main from ByronHsu:byron/refactor-post-experts-allreduce-helper on Apr 27, 2026

Conversation

ByronHsu (Collaborator) commented on Apr 26, 2026

Motivation

The post-experts EP and TP all-reduce paths in MoE models gate on the same growing list of "downstream will absorb the all-reduce" predicates:

  • should_allreduce_fusion — LayerCommunicator will fuse the reduction into the next layer's residual all-reduce
  • use_reduce_scatter — LayerCommunicator's post-attention scatter performs a reduce-scatter instead
  • should_use_dp_reduce_scatterv() — DP reduce-scatterv combine path (#22642: Replace all-reduce + dp_scatter with reduce_scatterv for DP attention)
  • should_use_flashinfer_cutlass_moe_fp4_allgather() — flashinfer kernel absorbs the TP all-reduce (TP path only)

Each new skip path has historically been added by sweeping every model file by hand. EP-vs-TP drift on this single predicate has produced two recent correctness bugs (#23729 and the follow-up fix #23734).

This PR centralizes the predicate so:

  1. Adding a new skip reason is a one-line change in one place instead of a sweep.
  2. EP and TP can't drift apart by accident — the only difference is the boolean is_tp_path, which selects the TP-only flashinfer guard.

Modifications

Commit 1 — Add helper. Add should_skip_post_experts_all_reduce to python/sglang/srt/layers/moe/utils.py and export it from python/sglang/srt/layers/moe/__init__.py. Refactor Qwen3MoeSparseMoeBlock.forward_normal to use it.

def should_skip_post_experts_all_reduce(
    *,
    is_tp_path: bool,
    use_reduce_scatter: bool = False,
    should_allreduce_fusion: bool = False,
) -> bool:
    # LayerCommunicator will absorb the reduction downstream (fusion with the
    # next layer's residual all-reduce, or a post-attention reduce-scatter).
    if should_allreduce_fusion or use_reduce_scatter:
        return True
    # The DP reduce-scatterv combine path replaces the all-reduce entirely.
    if should_use_dp_reduce_scatterv():
        return True
    # TP-only guard: the flashinfer cutlass FP4 allgather kernel absorbs the
    # TP all-reduce, so a second reduction would double-count.
    if is_tp_path and should_use_flashinfer_cutlass_moe_fp4_allgather():
        return True
    return False
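
For orientation, a hedged usage sketch of the two call paths (the flag values are illustrative, not copied from the diff):

# EP path: the TP-only flashinfer guard is never consulted.
skip = should_skip_post_experts_all_reduce(is_tp_path=False)

# TP path: forward the per-layer flags and enable the TP-only guard.
skip = should_skip_post_experts_all_reduce(
    is_tp_path=True,
    use_reduce_scatter=use_reduce_scatter,
    should_allreduce_fusion=should_allreduce_fusion,
)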

Commit 2 — Migrate every other MoE model. Sweep across all models with a post-experts all-reduce gated on these predicates: bailing_moe, bailing_moe_linear, deepseek_v2, exaone_moe, glm4_moe, hunyuan_v3, llada2, llama4, mimo_v2_flash, minimax_m2, qwen2_moe, sarvam_moe, sdar_moe, step3p5.
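
For illustration, a hedged sketch of a typical call-site migration; the exact guard shape and names such as final_hidden_states vary per model and are not copied from the diff:

# Before: each model hand-rolls the skip predicate around its TP all-reduce.
if not (
    should_allreduce_fusion
    or use_reduce_scatter
    or should_use_dp_reduce_scatterv()
    or should_use_flashinfer_cutlass_moe_fp4_allgather()
):
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

# After: the skip decision is delegated to the shared helper.
if not should_skip_post_experts_all_reduce(
    is_tp_path=True,
    use_reduce_scatter=use_reduce_scatter,
    should_allreduce_fusion=should_allreduce_fusion,
):
    final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)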

Validation

Every migration was verified by enumerating all 16 truth-table combinations of the four input flags; a sketch of such a harness follows the results below. Categorization:

A. Byte-identical refactor — the original predicate already covered all four flags:

bailing_moe, deepseek_v2 (×2 sites), glm4_moe (×2), minimax_m2, mimo_v2_flash, qwen2_moe, sarvam_moe (×2), sdar_moe, step3p5, plus qwen3_moe (commit 1).

B. Adds flashinfer guard (latent fix) — the original predicate was a strict subset of the helper's TP path, which now also skips when should_use_flashinfer_cutlass_moe_fp4_allgather() is True. That predicate gates on the global FP4 + flashinfer cutlass runner config, so it returns False outside that config (a no-op) and correctly skips when the config is active (avoiding a double reduce):

bailing_moe_linear, exaone_moe, llada2, llama4, hunyuan_v3 (TP path).

C. hunyuan_v3 EP path — byte-identical (the helper omits the flashinfer check on EP).

D. Models without should_allreduce_fusion / use_reduce_scatter in their forward signature — they get the helper defaults, which are no-ops.

Cat A unexpected diffs: 0
Cat B unexpected diffs: 0   (only diff: flashinfer=True alone, the latent-fix path)
Cat C unexpected diffs: 0
Cat D EP unexpected diffs: 0
Cat D TP unexpected diffs: 0   (only diff: flashinfer=True alone, the latent-fix path)
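
For illustration, a hedged sketch of the kind of enumeration harness described above. The old_cat_* predicates are hypothetical reconstructions, and the two global predicates are replaced by plain booleans so that all 16 rows can be forced:

from itertools import product

def new_tp(fusion, rs, dp_rsv, fi):
    # Mirrors should_skip_post_experts_all_reduce on the TP path.
    return fusion or rs or dp_rsv or fi

def old_cat_a(fusion, rs, dp_rsv, fi):
    # Cat A: the hand-rolled predicate already covered all four flags.
    return fusion or rs or dp_rsv or fi

def old_cat_b(fusion, rs, dp_rsv, fi):
    # Cat B: the hand-rolled predicate never consulted the flashinfer guard.
    return fusion or rs or dp_rsv

# Cat A: byte-identical on every row.
for combo in product([False, True], repeat=4):
    assert old_cat_a(*combo) == new_tp(*combo)

# Cat B: exactly one diff, flashinfer=True with all other flags False
# (the latent-fix path reported above).
diffs = [c for c in product([False, True], repeat=4) if old_cat_b(*c) != new_tp(*c)]
assert diffs == [(False, False, False, True)]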

Checklist

  • Format your code according to the Code Formatting with Pre-Commit guide
  • Add unit tests as outlined in the Running Unit Tests guide
  • Update documentation / docstrings / example tutorials as needed

🤖 Generated with Claude Code

The post-experts EP and TP all-reduce paths in qwen3_moe both gate on
the same set of "downstream will absorb the all-reduce" predicates:

  - should_allreduce_fusion (LayerCommunicator fuses with next layer)
  - use_reduce_scatter (LayerCommunicator's post-attention reduce-scatter)
  - should_use_dp_reduce_scatterv() (DP reduce-scatterv combine path)
  - should_use_flashinfer_cutlass_moe_fp4_allgather() (TP path only)

Each new skip path has been added by sweeping every model file by hand,
and EP-vs-TP drift has caused two recent correctness bugs (sgl-project#23729 and
the follow-up fixed by sgl-project#23734). Centralize the predicate so adding a
new skip reason is a one-line change in one place, and EP and TP can
no longer drift apart by accident.

The helper is byte-identical to the existing qwen3_moe predicates --
verified by enumerating all 16 truth-table combinations of the four
inputs. hunyuan_v3 (the only other model with a separate EP-then-TP
post-experts all-reduce pattern) is intentionally not migrated here:
its current predicate is a strict subset, and switching it to the
helper would silently add a flashinfer guard, which is a behavioral
change that belongs in a separate PR.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

…_reduce

Migrate every MoE model with a post-experts all-reduce gated on the
"downstream will absorb the all-reduce" predicates to the helper:

  bailing_moe, bailing_moe_linear, deepseek_v2, exaone_moe, glm4_moe,
  hunyuan_v3, llada2, llama4, mimo_v2_flash, minimax_m2, qwen2_moe,
  sarvam_moe, sdar_moe, step3p5

All migrations were verified by enumerating the 16 truth-table combos
of the four input flags. Categorization:

A. Byte-identical refactor -- the original predicate already covered all
   four flags (should_allreduce_fusion, use_reduce_scatter, dp_reduce_
   scatterv, flashinfer_cutlass_moe_fp4_allgather):
     bailing_moe, deepseek_v2 (x2 sites), glm4_moe (x2), minimax_m2,
     mimo_v2_flash, qwen2_moe, sarvam_moe (x2), sdar_moe, step3p5.

B. Adds flashinfer guard (latent fix): the original predicate was a
   strict subset of the helper's TP path. Helper now also skips when
   should_use_flashinfer_cutlass_moe_fp4_allgather() is True; since
   that predicate is gated on global FP4 + flashinfer cutlass runner,
   it returns False outside that config (no-op there) and correctly
   skips when active (avoiding double-reduce):
     bailing_moe_linear, exaone_moe, llada2, llama4, hunyuan_v3 (TP).

In Cat C, hunyuan_v3 EP path is byte-identical (helper omits the
flashinfer check on EP). Models without `should_allreduce_fusion` /
`use_reduce_scatter` in their forward signature get the helper
defaults, which are no-ops.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
ByronHsu requested a review from fzyzcjy as a code owner on April 26, 2026 04:46
ByronHsu (Collaborator, Author) commented:

/tag-run-ci-label

ch-wan (Collaborator) commented on Apr 26, 2026

/rerun-failed-ci

1 similar comment

Kangyan-Zhou merged commit 85376a6 into sgl-project:main on Apr 27, 2026
252 of 273 checks passed
hnyls2002 mentioned this pull request on Apr 29, 2026
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request on May 2, 2026
…-project#23748)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>