
Fix Qwen3 MoE: also guard EP all-reduce with not use_reduce_scatter (follow-up to #23731)#23734

Merged
ByronHsu merged 1 commit into sgl-project:main from
ByronHsu:fix/qwen3-moe-ep-use-reduce-scatter-guard
Apr 26, 2026

Conversation


@ByronHsu ByronHsu commented Apr 26, 2026

Motivation

Follow-up to #23731 (which fixed double-reduce when DP attention + EP + reduce_scatterv).

The TP branch in Qwen3MoeSparseMoeBlock.forward_normal already guards on not use_reduce_scatter, the older LayerCommunicator-driven reduce-scatter path (distinct from the should_use_dp_reduce_scatterv() path added by #22642). The EP branch was missing this guard, so when LayerCommunicator's post-attention scatter performs a reduce-scatter, the EP all-reduce reduces a second time and corrupts the logits.

This is the same root cause as #23729 / #23731, just on the older use_reduce_scatter codepath instead of the newer should_use_dp_reduce_scatterv one. Both guards belong on the EP branch (and the TP branch already has both).

Modifications

Add not use_reduce_scatter to the EP all-reduce guard, mirroring the TP branch.

if (
    self.ep_size > 1
    and not should_allreduce_fusion
    and not use_reduce_scatter   # added
    and not should_use_dp_reduce_scatterv()
):
    final_hidden_states = moe_expert_parallel_all_reduce(final_hidden_states)
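To see why the missing guard corrupts outputs, here is a minimal sketch of the double-reduce failure mode, in plain Python with hypothetical names (no sglang or distributed runtime involved). A cross-rank sum must be applied exactly once; if the reduce-scatter path has already summed partial results, running an all-reduce on top sums them again and inflates every value.

```python
def all_reduce(partials):
    """Toy all-reduce: sum the per-rank partial vectors elementwise;
    every rank ends up holding the full sum."""
    total = [sum(vals) for vals in zip(*partials)]
    return [total[:] for _ in partials]

# Two ranks, each holding a partial contribution to the same hidden state.
rank_partials = [[1.0, 2.0], [3.0, 4.0]]

once = all_reduce(rank_partials)   # correct: each rank holds [4.0, 6.0]
twice = all_reduce(once)           # bug: reduced again, values doubled

assert once[0] == [4.0, 6.0]
assert twice[0] == [8.0, 12.0]     # 2x inflation -> corrupted logits downstream
```

This is the shape of the bug on both codepaths: the fix is not changing the reduction itself but ensuring only one communication path performs it.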

Checklist

  • Format your code according to the Code Formatting with Pre-Commit
  • Add unit tests as outlined in the Running Unit Tests
  • Update documentation / docstrings / example tutorials as needed

🤖 Generated with Claude Code

Follow-up to sgl-project#23731. The TP branch in Qwen3MoeSparseMoeBlock.forward_normal
already guards on `not use_reduce_scatter` (LayerCommunicator's older
reduce-scatter path, distinct from the should_use_dp_reduce_scatterv
path added by sgl-project#22642). The EP branch was missing this guard, so when
LayerCommunicator's post-attention scatter does reduce-scatter the EP
all-reduce double-reduces and corrupts logits.

This is the same root cause that affected the periodiclabs internal
sglang fork; on a Qwen3-30B-A3B-Instruct-2507 run with
tp=2 dp=2 ep=2 + dp_attention, max|Δlp| against tp2_only drops from
1.68 to 0.279 once both branches are guarded.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request updates the conditional logic in the forward_normal method of the Qwen3 MoE model to include use_reduce_scatter when determining whether to perform an expert parallel all-reduce. The review feedback suggests further extending this guard condition to include a check for should_use_flashinfer_cutlass_moe_fp4_allgather, which would align the expert parallel logic with the tensor parallel branch and prevent potential double-reduction issues when specialized communication paths are active.

Comment on lines 337 to 339
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_dp_reduce_scatterv()

Severity: medium

To fully align the EP all-reduce guard with the TP all-reduce guard below (lines 344-349), you should also include not should_use_flashinfer_cutlass_moe_fp4_allgather(). This ensures that if the FlashInfer specialized communication path is active, the EP all-reduce is correctly skipped to avoid double-reduction, maintaining consistency between the TP and EP branches.

            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
            and not should_use_dp_reduce_scatterv()
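One way to keep the two branches from drifting again is to factor all of the "another path already reduces" conditions into a single predicate. The sketch below is hypothetical (it is not the actual sglang code, and the parameter names are illustrative); it only shows the combined guard logic that both the TP and EP branches would share.

```python
def should_skip_cross_rank_all_reduce(
    world_size: int,
    allreduce_fusion: bool,
    use_reduce_scatter: bool,
    flashinfer_fp4_allgather: bool,
    dp_reduce_scatterv: bool,
) -> bool:
    """Return True when the extra all-reduce must be skipped: either
    there is nothing to reduce (single rank), or another communication
    path performs the reduction, so reducing again would double-count."""
    if world_size <= 1:
        return True
    return (
        allreduce_fusion
        or use_reduce_scatter
        or flashinfer_fp4_allgather
        or dp_reduce_scatterv
    )

# No special path active and more than one rank: the all-reduce runs.
assert not should_skip_cross_rank_all_reduce(2, False, False, False, False)
# The case fixed by this PR: use_reduce_scatter active means skip it.
assert should_skip_cross_rank_all_reduce(2, False, True, False, False)
```

A shared helper like this makes "add a new reduction path" a one-line change instead of a guard that must be mirrored by hand in every branch.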

@ByronHsu ByronHsu merged commit 71029ab into sgl-project:main Apr 26, 2026
57 of 65 checks passed
ByronHsu added a commit that referenced this pull request Apr 26, 2026
…follow-up to #23731) (#23734)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
vguduruTT pushed a commit to vguduruTT/sglang that referenced this pull request May 2, 2026
…follow-up to sgl-project#23731) (sgl-project#23734)

Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
