6 changes: 3 additions & 3 deletions paddleformers/transformers/glm4_moe/modeling.py
@@ -231,6 +231,9 @@ def forward(
         else:
             mix_layer = self.qkv_proj(hidden_states)
             if self.sequence_parallel:
+                max_sequence_length = self.config.max_sequence_length
+                bsz = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
+                q_len = max_sequence_length
                 target_shape = [
                     bsz,
                     q_len,
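Note on the shape arithmetic added above: under sequence parallelism, hidden_states arrives flattened as [bsz * seq_len // tensor_parallel_degree, hidden_size], so the batch size can no longer be read off the tensor itself. The patch recovers it from config.max_sequence_length, which implicitly assumes every sequence is padded to that length. A minimal sanity check of the integer arithmetic, with hypothetical values (bsz=2, max_sequence_length=4096, tp_degree=8):

# Hypothetical sizes; only illustrates the arithmetic in the hunk above.
bsz, max_sequence_length, tp_degree = 2, 4096, 8
local_rows = bsz * max_sequence_length // tp_degree            # rows held by one rank: 1024
recovered_bsz = local_rows * tp_degree // max_sequence_length  # -> 2
q_len = max_sequence_length                                    # full length after all-gather
assert (recovered_bsz, q_len) == (2, 4096)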
@@ -311,13 +314,10 @@ def forward(self, hidden_states):
         Args:
             hidden_states (_type_): [batch_size * seq_len, hidden_size]
         """
-
         # compute gating score
         with paddle.amp.auto_cast(False):
             hidden_states = hidden_states.cast(self.weight.dtype)
-
             logits = F.linear(hidden_states.cast("float32"), self.weight.cast("float32").t())
-
             scores = self.gate_score_func(logits=logits)
             scores = scores.cast(paddle.float32)
 
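For context, the surviving lines compute the routing logits in float32 under paddle.amp.auto_cast(False): expert assignment is decided by small logit differences that bf16 would round away. A standalone sketch of the same pattern, with hypothetical sizes (16 tokens, hidden_size 128, 64 experts; gate_score_func stood in by softmax):

import paddle
import paddle.nn.functional as F

hidden = paddle.randn([16, 128]).cast("bfloat16")  # [tokens, hidden_size], model dtype
weight = paddle.randn([64, 128])                   # gate weight, [num_experts, hidden_size]
with paddle.amp.auto_cast(False):                  # keep gating out of AMP
    logits = F.linear(hidden.cast("float32"), weight.t())  # [tokens, num_experts]
scores = F.softmax(logits, axis=-1)                # stand-in for self.gate_score_func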
38 changes: 30 additions & 8 deletions paddleformers/transformers/moe_gate.py
@@ -139,12 +139,32 @@ def _cal_seq_aux_loss(self, gates, top_k, topk_idx) -> paddle.Tensor:
         Returns:
             paddle.Tensor: The value of sequence auxiliary loss.
         """
-        batch_size, seq_len, _ = gates.shape
-        ce = paddle.zeros([batch_size, self.num_experts])
-        topk_idx = topk_idx.reshape([batch_size, -1])
-        ce.put_along_axis_(indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add")
-        ce = ce / (seq_len * top_k / self.num_experts)
-        aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
+        if self.config.sequence_parallel:
+            # gates: [bs * seq_len, dim]
+            # TODO: temporary shim to stay compatible with SP input dimensions.
+            # This branch affects loss_aux, but the glm4moe model does not
+            # actually use the result; correctness is unvalidated and still
+            # needs to be verified.
+            max_sequence_length = self.config.max_sequence_length
+            local_total_tokens, local_num_experts = gates.shape
+            batch_size = local_total_tokens * self.config.tensor_parallel_degree // max_sequence_length
+            seq_len = max_sequence_length
+            ce = paddle.zeros([local_total_tokens, local_num_experts])
+            ce.put_along_axis_(
+                indices=topk_idx, values=paddle.ones_like(topk_idx, dtype=ce.dtype), axis=1, reduce="add"
+            )
+            ce = ce / (top_k / local_num_experts)
+            gates_mean = paddle.mean(gates, axis=tuple(range(len(gates.shape) - 1))).unsqueeze(0)
+            aux_loss = (ce * gates_mean).sum(axis=1).mean()
+        else:
+            # gates: [bs, seq_len, dim]
+            batch_size, seq_len, num_experts = gates.shape
+            ce = paddle.zeros([batch_size, self.num_experts])
+            topk_idx = topk_idx.reshape([batch_size, -1])
+            ce.put_along_axis_(
+                indices=topk_idx, values=paddle.ones([batch_size, seq_len * top_k]), axis=1, reduce="add"
+            )
+            ce = ce / (seq_len * top_k / self.num_experts)
+            aux_loss = (ce * paddle.mean(gates, axis=1)).sum(axis=1).mean()
         return aux_loss
 
     def _cal_z_loss(self, logits) -> paddle.Tensor:
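A toy walk-through of the new sequence-parallel branch, with hypothetical sizes (8 local tokens, 4 experts, top_k=2). Here ce counts how often each expert is selected per token, scaled by num_experts / top_k, and the per-sequence mean of gates is replaced by a mean over all local tokens; as the TODO in the hunk notes, this is a compatibility shim whose correctness is not yet validated:

import paddle
import paddle.nn.functional as F

top_k, tokens, experts = 2, 8, 4
gates = F.softmax(paddle.randn([tokens, experts]), axis=-1)  # [tokens, experts]
topk_idx = paddle.topk(gates, k=top_k, axis=1)[1]            # [tokens, top_k]
ce = paddle.zeros([tokens, experts])
ce.put_along_axis_(indices=topk_idx, values=paddle.ones_like(topk_idx, dtype=ce.dtype), axis=1, reduce="add")
ce = ce / (top_k / experts)                           # normalize selection counts
gates_mean = paddle.mean(gates, axis=0).unsqueeze(0)  # [1, experts], mean over local tokens
aux_loss = (ce * gates_mean).sum(axis=1).mean()       # scalar auxiliary loss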
@@ -473,7 +493,8 @@ def topkgating(
         gates: paddle.Tensor,
     ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]:
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        # batch_size, seq_len, d_model = gates.shape
+        d_model = gates.shape[-1]
         gates_ori = gates
         gates = gates.reshape([-1, d_model])
 
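Reading only the last dimension makes the gating entry points shape-agnostic: the same code now accepts both the dense [batch_size, seq_len, d_model] layout and the sequence-parallel [tokens, d_model] layout, since the reshape flattens both to the same shape. The identical change is applied to topkgating_nodrop in the next hunk. A quick sketch with hypothetical sizes:

import paddle

for shape in ([2, 8, 4], [16, 4]):        # [bsz, seq_len, d_model] vs. SP [tokens, d_model]
    gates = paddle.randn(shape)
    d_model = gates.shape[-1]
    flat = gates.reshape([-1, d_model])   # [16, 4] either way
    assert flat.shape == [16, 4]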
@@ -553,7 +574,8 @@ def topkgating(

     def topkgating_nodrop(self, gates: paddle.Tensor):
         """Implements TopKGating on logits."""
-        batch_size, seq_len, d_model = gates.shape
+        # batch_size, seq_len, d_model = gates.shape
+        d_model = gates.shape[-1]
         gates_ori = gates
         gates = gates.reshape([-1, d_model])
