Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,10 @@ def __init__(
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()

def forward_normal_dual_stream(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:

current_stream = torch.cuda.current_stream()
Expand All @@ -548,21 +551,32 @@ def forward_normal_dual_stream(
current_stream.wait_stream(self.alt_stream)

if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not use_reduce_scatter
Comment on lines +554 to +557
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve readability and maintainability, extract the condition into a local variable. This makes the logic clearer and avoids repeating the same complex condition.

            should_all_reduce = (
                self.tp_size > 1
                and not can_fuse_mlp_allreduce
                and not use_reduce_scatter
            )
            if should_all_reduce:

):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
final_hidden_states += shared_output
else:
final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
return final_hidden_states

def forward_normal(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
Comment on lines +578 to 580
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use_reduce_scatter parameter is added to the signature of forward_normal, but it's not used in the method's body. This prevents the reduce_scatter optimization from being applied.

Update the conditions guarding the tensor_model_parallel_all_reduce calls so they also check `not use_reduce_scatter`, mirroring forward_normal_dual_stream.

if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
Expand Down Expand Up @@ -681,6 +695,7 @@ def __init__(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While enabling allow_reduce_scatter here is correct, the forward method of Glm4MoeDecoderLayer does not pass the use_reduce_scatter flag to self.mlp. This makes the changes ineffective, as use_reduce_scatter will always be its default value (False) inside the MLP/MoE layers.

To fix this, Glm4MoeDecoderLayer.forward needs to:

  1. Calculate use_reduce_scatter by calling self.layer_communicator.should_use_reduce_scatter(forward_batch).
  2. Pass this flag to self.mlp(...).

Additionally, Glm4MoeMLP.forward needs to be updated to accept and use the use_reduce_scatter parameter.

)

def forward(
Expand Down
Loading