
[Qwen3.5] Qwen3.5-27B inference repeat bug fix#19411

Merged
ispobock merged 1 commit into sgl-project:main from
AlfredYyong:qwen3_5-tp-allreduce-fusion-last-layer
Feb 26, 2026

Conversation

@AlfredYyong
Contributor

@AlfredYyong AlfredYyong commented Feb 26, 2026

Motivation

fix #19393
fix #19322

When deploying the Qwen3.5-27B model with tp=2, the model produces repetitive (degenerate) outputs, while tp=1 works correctly.

Root Cause

Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer do not pass is_last_layer to LayerCommunicator (defaults to False for all layers, including the last one).

When tp >= 2 on sm90+ GPUs with flashinfer available, should_fuse_mlp_allreduce_with_next_layer() returns True for every layer including the last one, because not self.is_last_layer is always True. This causes the last layer's MLP output to skip all-reduce and postprocess_layer, but there is no subsequent layer to perform the deferred all-reduce. The final self.norm(hidden_states, residual) then adds un-reduced partial MLP output to the already-reduced residual, producing incorrect hidden states per TP rank, which leads to wrong logits and repetitive text generation.

With tp=1, should_fuse_mlp_allreduce_with_next_layer() always returns False (requires tp_size > 1), so the issue never triggers.
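The gating described above can be sketched as follows (a simplified, hypothetical model of the predicate — the real implementation lives in sglang's layer communicator and checks additional conditions such as flashinfer workspace limits):

```python
def should_fuse_mlp_allreduce_with_next_layer(
    tp_size: int,
    is_last_layer: bool,
    flashinfer_available: bool,
    is_sm90_plus: bool,
) -> bool:
    # Fusion defers the MLP all-reduce into the *next* layer's input
    # layernorm, so it must never fire on the last layer: there is no
    # next layer left to perform the deferred all-reduce.
    return (
        tp_size > 1
        and flashinfer_available
        and is_sm90_plus
        and not is_last_layer
    )

# Bug: is_last_layer defaulted to False for every layer, so the last
# layer also fused and its MLP output was never reduced.
buggy_last_layer = should_fuse_mlp_allreduce_with_next_layer(2, False, True, True)
fixed_last_layer = should_fuse_mlp_allreduce_with_next_layer(2, True, True, True)
tp1_any_layer = should_fuse_mlp_allreduce_with_next_layer(1, False, True, True)
```

With tp=1 the predicate is always False, matching the observation that single-GPU inference was unaffected.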

All other models using allreduce fusion (DeepSeek V2, Qwen3 MoE, GLM4 MoE, SDAR MoE, etc.) correctly set is_last_layer.

Modifications

Added is_last_layer=(layer_id == config.num_hidden_layers - 1) to LayerCommunicator initialization in both:

  • Qwen3_5LinearDecoderLayer (line 355)
  • Qwen3_5AttentionDecoderLayer (line 546)

This is consistent with how other models handle it, e.g.:

  • qwen3_moe.py: is_last_layer=(self.layer_id == self.config.num_hidden_layers - 1)
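A minimal sketch of the fix's effect (illustrative only — the actual change is the one-line keyword argument in python/sglang/srt/models/qwen3_5.py described above):

```python
NUM_HIDDEN_LAYERS = 48  # illustrative value; really comes from config.num_hidden_layers

def is_last_layer(layer_id: int, num_hidden_layers: int = NUM_HIDDEN_LAYERS) -> bool:
    # Exactly one decoder layer (the final one) gets True, which disables
    # allreduce fusion there so the MLP output is reduced before the final norm.
    return layer_id == num_hidden_layers - 1

flags = [is_last_layer(i) for i in range(NUM_HIDDEN_LAYERS)]
```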

Accuracy Tests

  • tp=1: No change (allreduce fusion was never triggered).
  • tp=2: Previously produced repetitive/degenerate output; after fix, should produce correct output matching tp=1.

deploy cmd:

nohup python -m sglang.launch_server \
  --model Qwen/Qwen3.5-27B \
  --tp 2 \
  --reasoning-parser qwen3 \
  --tool-call-parser qwen3_coder \
  --speculative-algo NEXTN \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --host 0.0.0.0 \
  --port 8123 > sgl_server.log 2>&1 &

Before the fix:

curl --location --request POST 'http://localhost:8123/v1/chat/completions' \
--header 'Content-Type: application/json' \
--data-raw '{
  "model": "qwen35-27b",
  "messages": [
    {
      "role": "system",
      "content": "you are a useful assistant."
    },
    {
      "role": "user",
      "content": "why are electric vehicle batteries more durable than mobile phone batteries?"
    }
  ],
  "max_tokens": 1000,
  "stream": false
}'

output: [screenshot in original PR: repetitive/degenerate text]

=================

After the fix:

[screenshot in original PR: correct output]

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

@gemini-code-assist
Contributor

Summary of Changes

Hello @AlfredYyong, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a significant issue affecting the Qwen3.5-27B model's inference stability and output quality when deployed with tensor parallelism. By correctly configuring the is_last_layer parameter within the LayerCommunicator, the fix ensures that all-reduce operations are properly handled across layers, preventing the generation of repetitive and incorrect text, thereby improving model reliability for distributed inference.

Highlights

  • Qwen3.5-27B Inference Bug Fix: Resolved a critical bug in the Qwen3.5-27B model that caused repetitive (degenerate) output when using tensor parallelism (tp=2) on sm90+ GPUs with FlashInfer, while tp=1 worked correctly.
  • Root Cause Identification: Identified that Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer incorrectly defaulted is_last_layer to False in LayerCommunicator, leading to skipped all-reduce operations for the last layer's MLP output.
  • is_last_layer Parameter Addition: Added is_last_layer=(layer_id == config.num_hidden_layers - 1) to the LayerCommunicator initialization in both Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer to ensure correct all-reduce fusion.


Changelog
  • python/sglang/srt/models/qwen3_5.py
    • Added the is_last_layer parameter to the LayerCommunicator initialization in Qwen3_5LinearDecoderLayer.
    • Added the is_last_layer parameter to the LayerCommunicator initialization in Qwen3_5AttentionDecoderLayer.
Activity
  • Fixed issues [Bug] Qwen3.5-27B None Response #19393 and [Bug] The model Qwen3.5-35B-A3B outputs None during inference on h20. #19322, which reported repetitive output for Qwen3.5-27B with tp=2.
  • Provided a detailed root cause analysis explaining the interaction between is_last_layer, tp >= 2, and should_fuse_mlp_allreduce_with_next_layer().
  • Described the specific code modifications made to Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer.
  • Included visual evidence (images) demonstrating the improvement in output quality after the fix for tp=2.
  • Updated the PR checklist, marking code formatting and accuracy/speed benchmark items as complete.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly fixes a bug causing repetitive outputs in the Qwen3.5-27B model when using tensor parallelism. The issue was due to the is_last_layer flag not being passed to the LayerCommunicator, which led to an incorrect all-reduce fusion on the final layer. The fix applies this flag in both Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer, which is the correct approach. I've also left a medium-severity comment suggesting a refactoring to address code duplication between these two decoder layer classes, which would improve long-term maintainability.

input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
is_last_layer=(layer_id == config.num_hidden_layers - 1),
Contributor


medium

While adding is_last_layer is correct, it highlights a broader code duplication issue. The entire LayerCommunicator initialization block (lines 350-356) is identical to the one in Qwen3_5AttentionDecoderLayer (lines 541-547). Furthermore, the MLP forward logic in both forward methods is also duplicated (e.g., lines 377-403 and 625-650).

To improve maintainability and reduce redundancy, consider refactoring this common logic into a shared base class or a helper function. This would centralize the logic and ensure future changes are applied consistently.
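One possible shape of the suggested refactor (a hypothetical sketch — DummyCommunicator stands in for sglang's LayerCommunicator, and the class and method names are illustrative, not sglang API):

```python
from dataclasses import dataclass

@dataclass
class DummyCommunicator:
    # Stand-in for sglang's LayerCommunicator; only the field relevant
    # to this PR is modeled here.
    is_last_layer: bool

class DecoderLayerBase:
    """Shared base centralizing the duplicated communicator setup."""

    @staticmethod
    def build_communicator(layer_id: int, num_hidden_layers: int) -> DummyCommunicator:
        # Both Qwen3_5LinearDecoderLayer and Qwen3_5AttentionDecoderLayer
        # would call this, so the is_last_layer computation lives in one place.
        return DummyCommunicator(is_last_layer=(layer_id == num_hidden_layers - 1))

comm = DecoderLayerBase.build_communicator(47, 48)
```

Centralizing the computation would keep the two decoder-layer classes from drifting apart on future changes like this one.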

@JustinTong0323
Collaborator

/tag-and-rerun-ci

@yizhang2077
Collaborator

yizhang2077 commented Feb 26, 2026

I think this PR is reasonable, but the allreduce + GemmaRMSNorm fusion is not implemented, so I think this bug will not happen in that case?

@AlfredYyong
Contributor Author

I think this PR is reasonable, but the allreduce + GemmaRMSNorm fusion is not implemented, so I think this bug will not happen in that case?

Thanks for the review! The allreduce + GemmaRMSNorm fusion kernel is indeed not implemented, but I think the bug is caused by the allreduce deferral mechanism, not by the fusion kernel itself.

should_fuse_mlp_allreduce_with_next_layer() does not check whether the layernorm supports forward_with_allreduce_fusion; it only checks flashinfer availability, is_last_layer, and tp > 1.

So when it returns True for the last layer, the MLP down_proj skips its all-reduce and postprocess_layer() is skipped as well.

For intermediate layers, the next layer's prepare_attn handles the deferred all-reduce via the fallback path (a plain tensor_model_parallel_all_reduce followed by the layernorm). But for the last layer there is no next layer to execute either path, which triggers the bug.
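The failure mode can be reproduced with a toy two-rank simulation (pure Python, no torch — the ranks, shapes, and values are made up for illustration):

```python
# Each rank holds a partial MLP output; an all-reduce is an elementwise sum.
rank_partials = [[1.0, 2.0], [3.0, 4.0]]   # rank 0, rank 1
residual = [10.0, 10.0]                    # already reduced, identical on both ranks

def all_reduce(partials):
    # Elementwise sum across ranks, the semantics of tensor-parallel all-reduce.
    return [sum(vals) for vals in zip(*partials)]

# Intermediate layer: the next layer's prepare_attn performs the deferred
# all-reduce, so every rank sees the same hidden state.
reduced = all_reduce(rank_partials)
correct = [r + h for r, h in zip(residual, reduced)]

# Last layer with the bug: there is no next layer, so each rank adds its
# own *partial* output to the residual and the hidden states diverge per rank.
buggy_per_rank = [[r + h for r, h in zip(residual, p)] for p in rank_partials]
```

The diverging per-rank hidden states then feed the final norm and lm_head, which is consistent with the wrong logits and repetitive text reported in the linked issues.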

@yizhang2077
Collaborator

Makes sense.

@ispobock ispobock merged commit bdc1e46 into sgl-project:main Feb 26, 2026
147 of 169 checks passed
magicYang1573 pushed a commit to magicYang1573/sglang that referenced this pull request Mar 9, 2026
lawrence-harmonic added a commit to lawrence-harmonic/sglang that referenced this pull request Mar 10, 2026
zhuzilin added a commit that referenced this pull request Mar 11, 2026
cherry pick: #19411
and use old qwen-vl image process

* [slime] fix qwen3.5 and qwen-vl

Co-authored-by: Copilot <copilot@github.com>
@0xd8b

0xd8b commented Mar 12, 2026

sglang 0.5.9, H20 GPU, tp=2, Qwen3.5-27B: I applied the modified file, but I still hit the bug; it does not work.

@0xd8b

0xd8b commented Mar 12, 2026

The reply looks like: The user has not provided an input with any content or context. I will respond with a greeting that is friendly and friendly.

Hello, hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello! Hello!


Projects

None yet

Development

Successfully merging this pull request may close these issues.

  • [Bug] Qwen3.5-27B None Response
  • [Bug] The model Qwen3.5-35B-A3B outputs None during inference on h20.

5 participants