diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py index adaac751155..89188d62580 100644 --- a/paddleformers/transformers/glm4_moe/modeling.py +++ b/paddleformers/transformers/glm4_moe/modeling.py @@ -440,17 +440,17 @@ def moe(self, hidden_states: paddle.Tensor, topk_indices: paddle.Tensor, topk_we return final_hidden_states.cast(hidden_states.dtype) def forward(self, hidden_states): - residuals = hidden_states - orig_shape = hidden_states.shape if self.sequence_parallel: hidden_states = GatherOp.apply(hidden_states) + residuals = hidden_states + orig_shape = hidden_states.shape topk_indices, topk_weights = self.gate(hidden_states) hidden_states = hidden_states.reshape((-1, hidden_states.shape[-1])) hidden_states = self.moe(hidden_states, topk_indices, topk_weights) - if self.sequence_parallel: - hidden_states = ScatterOp.apply(hidden_states) hidden_states = paddle.reshape(hidden_states, orig_shape) hidden_states = hidden_states + self.shared_experts(residuals) + if self.sequence_parallel: + hidden_states = ScatterOp.apply(hidden_states) return hidden_states