diff --git a/paddleformers/transformers/glm4_moe/modeling.py b/paddleformers/transformers/glm4_moe/modeling.py
index 1b8d763f86a..561efe6c005 100644
--- a/paddleformers/transformers/glm4_moe/modeling.py
+++ b/paddleformers/transformers/glm4_moe/modeling.py
@@ -231,6 +231,9 @@ def forward(
         else:
             mix_layer = self.qkv_proj(hidden_states)
             if self.sequence_parallel:
+                max_sequence_length = self.config.max_sequence_length
+                bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
+                q_len = max_sequence_length
                 target_shape = [
                     bsz,
                     q_len,
@@ -311,13 +314,10 @@ def forward(self, hidden_states):
         Args:
             hidden_states (_type_): [batch_size * seq_len, hidden_size]
         """
-        # compute gating score
         with paddle.amp.auto_cast(False):
             hidden_states = hidden_states.cast(self.weight.dtype)
-            logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32").t())
-            scores = self.gate_score_func(logits=logits)
             scores = scores.cast(paddle.float32)
diff --git a/paddleformers/transformers/moe_gate.py b/paddleformers/transformers/moe_gate.py
index d666515c44d..208ffbc66bc 100644
--- a/paddleformers/transformers/moe_gate.py
+++ b/paddleformers/transformers/moe_gate.py
@@ -139,12 +139,32 @@ def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor:
         Returns:
             paddle.Tensor: The value of sequence auxiliary loss.
         """
-        batch_size, seq_len, _ = gates.shape
-        ce = paddle.zeros([batch_size, self.num_experts])
-        topk_idx = topk_idx.reshape([batch_size, -1])
-        ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add")
-        ce = ce / (seq_len * top_k / self.num_experts)
-        aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
+        if self.config.sequence_parallel:
+            # [bs * seq_len, dim]
+            # Todo: Temporary measure to be compatible with SP input dimensions:
+            # this function affects loss_aux, but the glm4moe model does not actually use this result.
+            # Correctness unvalidated; to be verified later.
+            max_sequence_length = self.config.max_sequence_length
+            local_total_tokens, local_num_experts = gates.shape
+            batch_size = local_total_tokens * self.config.tensor_parallel_degree // max_sequence_length
+            seq_len = max_sequence_length
+            ce = paddle.zeros([local_total_tokens, local_num_experts])
+            ce.put_along_axis_(
+                indices=topk_idx, values=paddle.ones_like(topk_idx, dtype=ce.dtype), axis=1, reduce="add"
+            )
+            ce = ce / (top_k / local_num_experts)
+            gates_mean = paddle.mean(gates, axis=tuple(range(len(gates.shape) - 1))).unsqueeze(0)
+            aux_loss = (ce * gates_mean).sum(axis=1).mean()
+        else:
+            # [bs, seq_len, dim]
+            batch_size, seq_len, num_experts = gates.shape
+            ce = paddle.zeros([batch_size, self.num_experts])
+            topk_idx = topk_idx.reshape([batch_size, -1])
+            ce.put_along_axis_(
+                indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add"
+            )
+            ce = ce / (seq_len * top_k / self.num_experts)
+            aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
         return aux_loss

     def _cal_z_loss(self, logits) -> paddle.Tensor:
@@ -473,7 +493,8 @@ def topkgating(
         gates: paddle.Tensor,
     ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        # batch_size, seq_len, d_model = gates.shape
+        d_model = gates.shape[-1]
         gates_ori = gates
         gates = gates.reshape([-1, d_model])
@@ -553,7 +574,8 @@ def topkgating_nodrop(self, gates: paddle.Tensor):
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        # batch_size, seq_len, d_model = gates.shape
+        d_model = gates.shape[-1]
         gates_ori = gates
         gates = gates.reshape([-1, d_model])
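Note on the shape bookkeeping in the hunks above (not part of the patch): under sequence parallelism each rank only holds `bs * seq_len // tensor_parallel_degree` rows of the flattened activations, so both the attention hunk and the `_cal_seq_aux_loss` hunk recover the global batch size from the local row count, the TP degree, and `max_sequence_length`. A minimal standalone sketch of that arithmetic, where `recover_batch_and_seq`, `tp_degree`, and `max_seq_len` are hypothetical names used purely for illustration:

```python
def recover_batch_and_seq(local_rows: int, tp_degree: int, max_seq_len: int):
    """Recover (bsz, q_len) from the per-rank row count of a sequence-parallel
    [bs * seq_len // tp_degree, hidden] activation tensor."""
    bsz = local_rows * tp_degree // max_seq_len
    q_len = max_seq_len  # every rank assumes the configured max sequence length
    return bsz, q_len


# Example: bs=2, seq_len=4096, tp_degree=8 -> each rank holds 1024 rows.
assert recover_batch_and_seq(1024, 8, 4096) == (2, 4096)
```

This only holds when inputs are padded to `max_sequence_length`, which is the assumption the patch itself makes.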
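A second standalone sketch (again outside the patch) of how the dense, non-SP branch of `_cal_seq_aux_loss` turns `topk_idx` into per-expert selection counts with a scatter-add and then weights them by the mean gate probabilities; the toy sizes below are made up for illustration:

```python
import paddle
import paddle.nn.functional as F

# Toy sizes, chosen only for illustration.
batch_size, seq_len, num_experts, top_k = 2, 4, 8, 2

logits = paddle.randn([batch_size, seq_len, num_experts])
gates = F.softmax(logits, axis=-1)                   # routing probabilities
_, topk_idx = paddle.topk(gates, k=top_k, axis=-1)   # [bs, seq_len, top_k]

# Scatter-add a 1 into ce[b, e] every time expert e is picked in sample b,
# then normalize so perfectly uniform routing gives ce == 1 everywhere.
ce = paddle.zeros([batch_size, num_experts])
flat_idx = topk_idx.reshape([batch_size, -1])        # [bs, seq_len * top_k]
ce.put_along_axis_(
    indices=flat_idx,
    values=paddle.ones([batch_size, seq_len * top_k]),
    axis=1,
    reduce="add",
)
ce = ce / (seq_len * top_k / num_experts)

# Sequence-level auxiliary loss: usage counts weighted by mean gate probability.
aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
print(float(aux_loss))
```

The sequence-parallel branch follows the same pattern but works on the local `[tokens, num_experts]` slice, which is why the patch flags its correctness as still unvalidated.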