Fix Qwen3 MoE: also guard EP all-reduce with not use_reduce_scatter (follow-up to #23731) #23734
Follow-up to sgl-project#23731. The TP branch in `Qwen3MoeSparseMoeBlock.forward_normal` already guards on `not use_reduce_scatter` (`LayerCommunicator`'s older reduce-scatter path, distinct from the `should_use_dp_reduce_scatterv` path added by sgl-project#22642). The EP branch was missing this guard, so when `LayerCommunicator`'s post-attention scatter does reduce-scatter, the EP all-reduce double-reduces and corrupts logits. This is the same root cause that affected the periodiclabs internal sglang fork; on a Qwen3-30B-A3B-Instruct-2507 run with tp=2 dp=2 ep=2 + dp_attention, max|Δlp| against tp2_only drops from 1.68 to 0.279 once both branches are guarded.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
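To make the double-reduce failure mode concrete, here is a toy, pure-Python model of the two collectives. It is only a sketch: `all_reduce` and `reduce_scatter` below are hypothetical stand-ins that operate on plain lists, not sglang or `torch.distributed` calls, and the order in which the real code path runs the two collectives is assumed for illustration.

```python
from typing import List

Vector = List[float]


def all_reduce(per_rank: List[Vector]) -> List[Vector]:
    """Toy all-reduce: every rank ends up with the element-wise sum."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def reduce_scatter(per_rank: List[Vector]) -> List[Vector]:
    """Toy reduce-scatter: sum over ranks, then rank r keeps the r-th chunk."""
    world = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


# Two ranks, each holding a partial MoE output for the same four elements.
partials = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]

# Guarded path: only the reduce-scatter runs, so each element is summed once.
correct = reduce_scatter(partials)

# Unguarded path: an all-reduce runs first, then the reduce-scatter sums the
# already-reduced copies again, scaling every element by the world size.
double = reduce_scatter(all_reduce(partials))

print(correct)  # [[11.0, 22.0], [33.0, 44.0]]
print(double)   # [[22.0, 44.0], [66.0, 88.0]]
```

In this toy setting the unguarded result is exactly the correct one multiplied by the number of ranks, which is the kind of scaling error that would show up downstream as shifted logits.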
Code Review
This pull request updates the conditional logic in the forward_normal method of the Qwen3 MoE model to include use_reduce_scatter when determining whether to perform an expert parallel all-reduce. The review feedback suggests further extending this guard condition to include a check for should_use_flashinfer_cutlass_moe_fp4_allgather, which would align the expert parallel logic with the tensor parallel branch and prevent potential double-reduction issues when specialized communication paths are active.
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_dp_reduce_scatterv()
To fully align the EP all-reduce guard with the TP all-reduce guard below (lines 344-349), you should also include not should_use_flashinfer_cutlass_moe_fp4_allgather(). This ensures that if the FlashInfer specialized communication path is active, the EP all-reduce is correctly skipped to avoid double-reduction, maintaining consistency between the TP and EP branches.
and not should_allreduce_fusion
and not use_reduce_scatter
and not should_use_flashinfer_cutlass_moe_fp4_allgather()
and not should_use_dp_reduce_scatterv()

…follow-up to sgl-project#23731) (sgl-project#23734)
Co-authored-by: Byron Hsu <byron@periodiclabs.ai>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
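The guard discussed in the review can be sketched as a standalone predicate. This is a hedged illustration, not the code in `forward_normal`: the function name `should_run_ep_all_reduce` is hypothetical, the `ep_size > 1` precondition is an assumption, and the two helper calls from the source (`should_use_flashinfer_cutlass_moe_fp4_allgather()` and `should_use_dp_reduce_scatterv()`) are replaced by plain boolean parameters so the block is self-contained. The FP4 all-gather flag reflects the reviewer's suggestion; the PR description itself only states that `not use_reduce_scatter` was added.

```python
def should_run_ep_all_reduce(
    ep_size: int,
    should_allreduce_fusion: bool,
    use_reduce_scatter: bool,
    fp4_allgather_active: bool,       # stand-in for should_use_flashinfer_cutlass_moe_fp4_allgather()
    dp_reduce_scatterv_active: bool,  # stand-in for should_use_dp_reduce_scatterv()
) -> bool:
    """Return True only when no other path will already reduce the MoE output."""
    return (
        ep_size > 1  # assumed precondition for the EP branch
        and not should_allreduce_fusion
        and not use_reduce_scatter
        and not fp4_allgather_active
        and not dp_reduce_scatterv_active
    )


# With no alternative reduction path active, the EP all-reduce runs.
plain_ep = should_run_ep_all_reduce(2, False, False, False, False)

# With the older reduce-scatter path active, it is skipped to avoid double-reducing.
with_reduce_scatter = should_run_ep_all_reduce(2, False, True, False, False)

print(plain_ep, with_reduce_scatter)  # True False
```

Writing the condition as one conjunction makes it easy to compare against the TP branch term by term, which is the consistency the review comment asks for.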
Motivation
Follow-up to #23731 (which fixed double-reduce when DP attention + EP + reduce_scatterv).
The TP branch in `Qwen3MoeSparseMoeBlock.forward_normal` already guards on `not use_reduce_scatter` (the older `LayerCommunicator`-driven reduce-scatter path, which is distinct from the `should_use_dp_reduce_scatterv()` path added by #22642). The EP branch was missing this guard, so when `LayerCommunicator`'s post-attention scatter does reduce-scatter, the EP all-reduce double-reduces and corrupts logits.

This is the same root cause as #23729 / #23731, just on the older `use_reduce_scatter` codepath instead of the newer `should_use_dp_reduce_scatterv` one. Both guards belong on the EP branch (and the TP branch already has both).

Modifications

Add `not use_reduce_scatter` to the EP all-reduce guard, mirroring the TP branch.

Checklist
🤖 Generated with Claude Code