
fix glm4 moe #8883

Merged

zhyncs merged 1 commit into main from cheng/fix/glm4 on Aug 7, 2025

Conversation

@ch-wan (Collaborator) commented Aug 7, 2025

Motivation

GLM4 inherits from DeepSeek, but the update from #8539 was not applied to GLM4, leaving some logits broken.

Modifications

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @ch-wan, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request aims to resolve an issue where GLM4 Mixture of Experts (MoE) models exhibited broken logits due to a missing update from a previous pull request (#8539). The core of the fix involves refining the communication strategy within the MoE layers by introducing and enabling a 'reduce-scatter' mechanism, which should correctly handle tensor parallel operations and restore proper model behavior.

Highlights

  • MoE Communication Logic Adjustment: Modified the forward_normal_dual_stream and forward_normal methods in glm4_moe.py to introduce a new boolean parameter, use_reduce_scatter. This parameter is now incorporated into the conditional logic that determines whether tensor_model_parallel_all_reduce is performed, allowing for more flexible or optimized communication strategies within the MoE layer.
  • Enable Reduce-Scatter for MoE Layer: Updated the __init__ method to pass allow_reduce_scatter=True when initializing the MoELayer. This explicitly enables a reduce-scatter communication pattern for the GLM4 MoE implementation, which is crucial for addressing the reported issue with broken logits and aligning with previous updates. (See the sketch below.)
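
The gating described in these highlights can be illustrated in isolation. The following is a minimal, self-contained sketch, not sglang's implementation: fake_all_reduce and moe_forward are hypothetical stand-ins (the real call is tensor_model_parallel_all_reduce), and the rank-local expert computation is faked with the input tensor.

import torch


def fake_all_reduce(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_reduce: in real tensor
    # parallelism this sums partial expert outputs across TP ranks.
    return x


def moe_forward(
    hidden_states: torch.Tensor,
    tp_size: int,
    can_fuse_mlp_allreduce: bool = False,
    use_reduce_scatter: bool = False,
) -> torch.Tensor:
    # Pretend this is the rank-local partial sum over local experts.
    partial = hidden_states
    # All-reduce only when nothing downstream combines the partials:
    # neither a fused MLP all-reduce nor a reduce-scatter issued later
    # by the layer communicator.
    if tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
        partial = fake_all_reduce(partial)
    return partial


# With use_reduce_scatter=True, the all-reduce above is skipped.
out = moe_forward(torch.randn(2, 4), tp_size=2, use_reduce_scatter=True)
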
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature               Command              Description
Code Review           /gemini review       Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary      Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist  Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                  /gemini help         Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request fixes an issue in GLM-4 MoE models by incorporating a reduce_scatter-related fix from DeepSeek V2. The changes add the use_reduce_scatter flag to avoid redundant all_reduce operations when reduce_scatter is used.

However, the implementation is incomplete. The use_reduce_scatter flag is not passed to the MLP/MoE layers, rendering the changes ineffective. I've provided comments on these issues and a suggestion to refactor duplicated code for better maintainability. Addressing these points will ensure the fix is complete.

Comment on lines +578 to 580
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:

Severity: high

The use_reduce_scatter parameter is added to the signature of forward_normal, but it's not used in the method's body. This prevents the reduce_scatter optimization from being applied.

Update the conditions for calling tensor_model_parallel_all_reduce to include the clause "and not use_reduce_scatter".

layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,

Severity: high

While enabling allow_reduce_scatter here is correct, the forward method of Glm4MoeDecoderLayer does not pass the use_reduce_scatter flag to self.mlp. This makes the changes ineffective, as use_reduce_scatter will always be its default value (False) inside the MLP/MoE layers.

To fix this, Glm4MoeDecoderLayer.forward needs to:

  1. Calculate use_reduce_scatter by calling self.layer_communicator.should_use_reduce_scatter(forward_batch).
  2. Pass this flag to self.mlp(...).

Additionally, Glm4MoeMLP.forward needs to be updated to accept and use the use_reduce_scatter parameter.
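
As a hedged sketch of this wiring, with simplified, hypothetical scaffolding (the real Glm4MoeDecoderLayer.forward takes more arguments; should_use_reduce_scatter follows the name used in this review):

import torch
import torch.nn as nn


class DecoderLayerSketch(nn.Module):
    # Hypothetical, reduced stand-in for Glm4MoeDecoderLayer showing
    # only the two steps listed above.
    def __init__(self, mlp: nn.Module, layer_communicator) -> None:
        super().__init__()
        self.mlp = mlp
        self.layer_communicator = layer_communicator

    def forward(self, hidden_states: torch.Tensor, forward_batch) -> torch.Tensor:
        # Step 1: ask the communicator whether it will reduce-scatter
        # the MLP output itself for this batch.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        # Step 2: pass the flag through so the MLP/MoE block can skip
        # its internal all-reduce when the partials will be combined
        # by the communicator's reduce-scatter.
        return self.mlp(hidden_states, use_reduce_scatter=use_reduce_scatter)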

Comment on lines +554 to +557
if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not use_reduce_scatter

Severity: medium

To improve readability and maintainability, extract the condition into a local variable. This makes the logic clearer and avoids repeating the same complex condition.

            should_all_reduce = (
                self.tp_size > 1
                and not can_fuse_mlp_allreduce
                and not use_reduce_scatter
            )
            if should_all_reduce:

@zhyncs zhyncs merged commit 5b6acc1 into main Aug 7, 2025
7 of 56 checks passed
@zhyncs zhyncs deleted the cheng/fix/glm4 branch August 7, 2025 01:02
narutolhy pushed a commit to narutolhy/sglang that referenced this pull request Aug 17, 2025
MahmoudAshraf97 pushed a commit to MahmoudAshraf97/sglang that referenced this pull request Sep 8, 2025