diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index 4744c0c31859..67ef6ca79d12 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -527,7 +527,10 @@ def __init__(
         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -548,21 +551,32 @@ def forward_normal_dual_stream(
         current_stream.wait_stream(self.alt_stream)
 
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
             final_hidden_states += shared_output
         else:
             final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
         return final_hidden_states
 
     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -681,6 +695,7 @@ def __init__(
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
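
The new use_reduce_scatter flag lets a caller suppress the tensor-parallel all-reduce when the layer communicator will instead perform a reduce-scatter during postprocessing, so each TP rank only materializes its shard of the tokens; the allow_reduce_scatter=True passed to the LayerCommunicator in the last hunk is what opts this layer into that path. Below is a minimal sketch of the expected call site, modeled on how other MoE models in sglang thread the same flag; the should_use_reduce_scatter helper and the surrounding names are assumptions for illustration, not part of this diff:

    # Hypothetical call-site sketch (not part of this diff). It assumes the
    # LayerCommunicator exposes a should_use_reduce_scatter(forward_batch)
    # helper, mirroring the pattern used by sglang's other MoE models.
    import torch


    def run_mlp(layer, hidden_states: torch.Tensor, forward_batch) -> torch.Tensor:
        # Ask the communicator whether postprocessing will reduce-scatter the
        # MLP output across the TP group for this batch.
        use_reduce_scatter = layer.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        # When True, forward_normal / forward_normal_dual_stream skip the
        # tensor_model_parallel_all_reduce: a full all-reduce would be
        # redundant, since each rank only needs its scattered shard afterwards.
        return layer.mlp(
            hidden_states,
            use_reduce_scatter=use_reduce_scatter,
        )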