Skip to content

Fix reduce_scatterv producer contract for SUM_LEN#24785

Merged
ch-wan merged 1 commit into
sgl-project:mainfrom
YAMY1234:fix/dp-reduce-scatterv-sum-len
May 10, 2026
Merged

Fix reduce_scatterv producer contract for SUM_LEN#24785
ch-wan merged 1 commit into
sgl-project:mainfrom
YAMY1234:fix/dp-reduce-scatterv-sum-len

Conversation

@YAMY1234
Copy link
Copy Markdown
Collaborator

@YAMY1234 YAMY1234 commented May 9, 2026

Motivation

Related to #23554.

_scatter_hidden_states() uses DP reduce_scatterv whenever should_use_dp_reduce_scatterv() is true. However, LayerCommunicator.should_use_reduce_scatter() only reported true for _scatter_hidden_states in MAX_LEN mode. In SUM_LEN prefill, producers such as dense MLP layers could still perform their TP all-reduce before the communicator performed reduce_scatterv, which double-reduced the hidden states and caused long-prompt accuracy regressions.

Kimi K2.6 exposes this because it has an early dense MLP layer (first_k_dense_replace=1), so the dense producer path must follow the same reduce-scatter contract as the communicator postprocess path.

Modifications

Make should_use_reduce_scatter() return true for _scatter_hidden_states whenever should_use_dp_reduce_scatterv() is true. This aligns the producer contract with the communicator path: if the communicator will run DP reduce_scatterv, the producer skips its internal TP all-reduce.

Accuracy Tests

GSM8K on Kimi K2.6, GB300, TP8/DP8/EP8, DP attention, num_examples=1319, num_threads=1319, num_shots=20, max_tokens=16384.

Bad baseline before this fix:

  • score 0.5642802155504234, invalid extracted answers 422, empty responses 410.

This PR:

  • score 0.9753656658968437, correct 1267, wrong 32, invalid extracted answers 0, empty responses 0.

Full final GSM8K log:

GSM8K Config: endpoint=http://localhost:30580; num_examples=1319; max_tokens=16384; num_threads=1319; num_shots=20; temperature=default; top_p=default; top_k=default
Running GSM8K evaluation...
/usr/local/lib/python3.12/dist-packages/torchao/quantization/quant_api.py:1731: SyntaxWarning: invalid escape sequence '\.'
  """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs).
Downloading from https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl to /tmp/test.jsonl
/tmp/test.jsonl:   0%|          | 0.00/237k [00:00<?, ?B/s]
/tmp/test.jsonl: 732kB [00:00, 38.7MB/s]
ChatCompletionSampler initialized with self.system_message=None self.temperature=0.0 self.max_tokens=16384 self.reasoning_effort=None self.extra_body=None
  0%|          | 0/1299 [00:00<?, ?it/s]
  0%|          | 1/1299 [01:39<35:58:03, 99.76s/it]
  1%|▏         | 18/1299 [01:46<1:31:51,  4.30s/it]
  2%|▏         | 20/1299 [02:23<2:13:48,  6.28s/it]
  5%|▌         | 68/1299 [02:41<28:40,  1.40s/it]
  6%|▌         | 74/1299 [04:22<1:09:25,  3.40s/it]
  8%|▊         | 100/1299 [05:42<1:05:07,  3.26s/it]
100%|██████████| 1299/1299 [05:42<00:00,  3.79it/s]
Total latency: 343.077 s
Score: 0.975
Output throughput: 3616.264 token/s
[METRIC] gsm8k_score=0.9753656658968437 labels={"model": "moonshotai/Kimi-K2.6", "eval": "gsm8k"}
[METRIC] gsm8k_latency=343.0773218280001 labels={"model": "moonshotai/Kimi-K2.6", "eval": "gsm8k"}
Writing report to /tmp/gsm8k_moonshotai_Kimi-K2.6.html
{'score:std': np.float64(0.15500801168472017), 'score': np.float64(0.9753656658968437), 'latency': 343.0773218280001, 'output_throughput': 3616.263509897623}
Writing results to /tmp/gsm8k_moonshotai_Kimi-K2.6.json
Results saved to: /logs/accuracy/gsm8k_moonshotai_Kimi-K2.6.json
GSM8K evaluation complete

Speed Tests and Profiling

This is a correctness fix for the producer/communicator reduce-scatter contract, not a performance-oriented change. The GSM8K validation run above reported total latency 343.077 s and output throughput 3616.264 token/s.

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the should_use_reduce_scatter method in communicator.py to permit reduce-scatter operations when should_use_dp_reduce_scatterv() is true, expanding the conditions beyond the previous requirement of maximum length padding. I have no feedback to provide.

@YAMY1234 YAMY1234 marked this pull request as ready for review May 9, 2026 16:08
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@YAMY1234
Copy link
Copy Markdown
Collaborator Author

YAMY1234 commented May 9, 2026

/tag-and-rerun-ci

Copy link
Copy Markdown
Contributor

@mmangkad mmangkad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good

ChatCompletionSampler initialized with self.system_message=None self.temperature=1.0 self.max_tokens=512 self.reasoning_effort=None self.extra_body={'chat_template_kwargs': {'thinking': False}}
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1299/1299 [01:08<00:00, 19.00it/s]
Total latency: 69.474 s
Score: 0.965
Output throughput: 2176.574 token/s
[METRIC] gsm8k_score=0.9645881447267128 labels={"model": "moonshotai/Kimi-K2.6", "eval": "gsm8k"}
[METRIC] gsm8k_latency=69.47386905899998 labels={"model": "moonshotai/Kimi-K2.6", "eval": "gsm8k"}
Writing report to /tmp/gsm8k_moonshotai_Kimi-K2.6.html
{'score:std': np.float64(0.18481844004154702), 'score': np.float64(0.9645881447267128), 'latency': 69.47386905899998, 'output_throughput': 2176.573754249705}
Writing results to /tmp/gsm8k_moonshotai_Kimi-K2.6.json

@b8zhong
Copy link
Copy Markdown
Collaborator

b8zhong commented May 10, 2026

I can also confirm this fix

python3 benchmark/gsm8k/bench_sglang.py --num-shots 20 --num-questions 1209 --parallel 1209 --platinum
Loading GSM8K Platinum dataset from HuggingFace...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1209/1209 [00:50<00:00, 23.89it/s]
Accuracy: 0.979
Invalid: 0.000
Latency: 50.623 s
Output throughput: 2378.002 token/s

@b8zhong b8zhong enabled auto-merge (squash) May 10, 2026 14:16
@ch-wan ch-wan disabled auto-merge May 10, 2026 23:50
@ch-wan ch-wan merged commit b202778 into sgl-project:main May 10, 2026
239 of 267 checks passed
ltcs11 added a commit to ltcs11/sglang that referenced this pull request May 11, 2026
* main: (87 commits)
  [Fix] Disable FlashInfer allreduce fusion under deterministic inference (sgl-project#24629)
  fix: STANDALONE spec-decode hidden-size mismatch crash (sgl-project#24217)
  Followup fix for Custom AR V2 in non NVL scenarios (sgl-project#24742)
  Fix reduce_scatterv producer contract for SUM_LEN (sgl-project#24785)
  [NPU]Documentation update for communications quantization feature (sgl-project#24668)
  [Session R3] Add routed_experts_start_len for absolute routing slice control (sgl-project#24851)
  [Model] Add MiniCPM-V 4.6 support (sgl-project#24855)
  Support Intern-S2-Preview (sgl-project#24875)
  [PD] Unify dsv4 dispatch with swa (sgl-project#24888)
  Optimize MHC pipeline: DeepGemm, fused norm, fused hc_head (sgl-project#24775)
  Fix PD bootstrap failure handling (sgl-project#24772)
  [Spec] Cleanup idle stub and shape-check patterns (sgl-project#24881)
  [Bug] Add dsv4 state_type branch to mooncake disaggregation (sgl-project#24878)
  [Spec V1] Split draft-extend phase from `EagleDraftInput` into new `EagleDraftExtendInput` (sgl-project#24859)
  [Gemma4] Optimize Gemm4 with fused Q/K/V RMSNorm + per-expert FP8 ckpt loader (sgl-project#24696)
  [spec decoding] support kimi-k2.5-eagle3-mla (sgl-project#24826)
  [SPEC V2] fix: skip stale state updates in spec-v2 overlap (sgl-project#23456)
  [RL] Call torch.cuda.empty_cache() for `in-place` pause mode to avoid OOM (sgl-project#24854)
  [diffusion] CI: add cache-dit CI tests (sgl-project#19213)
  [Utils] Make request dump robust to unpicklable server_args and large meta_info (sgl-project#24767)
  ...

# Conflicts:
#	python/sglang/srt/utils/common.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants