Merged
Changes from 1 commit
52 changes: 33 additions & 19 deletions paddleformers/transformers/qwen2_moe/modeling.py
@@ -15,14 +15,15 @@
"""Paddle Qwen2Moe model."""
from __future__ import annotations

import copy
from functools import partial
from typing import List, Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed.fleet.recompute.recompute import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp

from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
from ...nn.criterion.interface import CriterionLayer
@@ -228,6 +229,10 @@ def __init__(self, config):
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and config.tensor_parallel_degree > 1:
config = copy.deepcopy(config)
config.sequence_parallel = False

# gating
self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -240,14 +245,12 @@

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
""" """
if self.config.sequence_parallel:
max_sequence_length = self.config.max_sequence_length
batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
sequence_length = max_sequence_length
hidden_dim = hidden_states.shape[1]
else:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
residuals = hidden_states
orig_shape = hidden_states.shape
if self.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)

hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

@@ -258,7 +261,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
final_hidden_states = paddle.zeros(
(hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
@@ -273,19 +278,28 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
if tokens_per_expert[expert_idx] <= 0.1:
continue
current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
final_hidden_states.index_add_(
index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
)
if self.training and paddle.is_grad_enabled():
fake_top_x = paddle.zeros(1, dtype=paddle.int64)
fake_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
fake_state = expert_layer(fake_current_state * 0)
final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
else:
continue
else:
current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
final_hidden_states.index_add_(
index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
)
if self.sequence_parallel:
final_hidden_states = ScatterOp.apply(final_hidden_states)
final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)

shared_expert_output = self.shared_expert(hidden_states)
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
shared_expert_output = self.shared_expert(residuals)
Collaborator:

The shared expert is not a sequence-parallel layer. If its input is not gathered, each card effectively runs on only part of the data, whereas a TP layer's input should be the complete sequence. The shared expert's input should therefore also be gathered.

Collaborator (Author):

done

shared_expert_output = F.sigmoid(self.shared_expert_gate(residuals)) * shared_expert_output

final_hidden_states = final_hidden_states + shared_expert_output

final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
return final_hidden_states, router_logits


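For readers skimming the diff: with sequence parallelism enabled, each tensor-parallel rank holds only a slice of the sequence, so the reworked MoE block now gathers the full sequence first, runs the router, the routed experts, and (per the review thread above) the shared expert on the full-length input, and only scatters back to the local shard at the end. The snippet below is a condensed sketch of that flow, not the exact PR code; `moe_block` is a hypothetical stand-in for the gate + experts + shared-expert computation, and an initialized tensor-parallel group is assumed.

```python
# Condensed sketch of the gather -> MoE -> scatter pattern this PR introduces.
# Assumptions: an initialized tensor-parallel group; `moe_block` is a hypothetical
# callable standing in for gating + routed experts + shared expert.
import paddle
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp


def sp_moe_forward(hidden_states, moe_block, sequence_parallel):
    orig_shape = hidden_states.shape  # with SP on, this is the local sequence shard
    if sequence_parallel:
        # Collect the shards so every rank sees the complete sequence.
        hidden_states = GatherOp.apply(hidden_states)

    hidden_states = hidden_states.reshape([-1, hidden_states.shape[-1]])
    output = moe_block(hidden_states)  # router, experts, shared expert on full input

    if sequence_parallel:
        # Return to per-rank shards along the sequence dimension.
        output = ScatterOp.apply(output)
    return paddle.reshape(output, orig_shape)
```

The constructor change (`copy.deepcopy(config)` with `config.sequence_parallel = False`) follows the same idea: the expert and shared-expert sublayers are built as ordinary tensor-parallel layers, since the gather/scatter is now handled once at the MoE boundary.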
46 changes: 29 additions & 17 deletions paddleformers/transformers/qwen3_moe/modeling.py
@@ -15,14 +15,15 @@
"""Paddle Qwen3Moe model."""
from __future__ import annotations

import copy
from functools import partial
from typing import List, Optional, Tuple, Union

import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn
from paddle.distributed.fleet.utils import recompute
from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp

from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
from ...nn.criterion.interface import CriterionLayer
@@ -239,6 +240,10 @@ def __init__(self, config):
self.num_experts = config.num_experts
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and config.tensor_parallel_degree > 1:
config = copy.deepcopy(config)
config.sequence_parallel = False

# gating
self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -248,15 +253,11 @@ def __init__(self, config):

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
""" """
orig_shape = hidden_states.shape
if self.sequence_parallel:
hidden_states = GatherOp.apply(hidden_states)

if self.config.sequence_parallel:
max_sequence_length = self.config.max_sequence_length
batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
sequence_length = max_sequence_length
hidden_dim = hidden_states.shape[1]
else:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view([-1, hidden_dim])
hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

@@ -268,7 +269,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
final_hidden_states = paddle.zeros(
(hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
@@ -283,14 +286,23 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
if tokens_per_expert[expert_idx] <= 0.1:
continue
current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
final_hidden_states.index_add_(
index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
)
if self.training and paddle.is_grad_enabled():
fake_top_x = paddle.zeros(1, dtype=paddle.int64)
fake_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
fake_state = expert_layer(fake_current_state * 0)
final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
else:
continue
else:
current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
final_hidden_states.index_add_(
index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
)

final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
if self.sequence_parallel:
final_hidden_states = ScatterOp.apply(final_hidden_states)
final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
return final_hidden_states, router_logits


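Both files also add the same handling for an expert that receives no routed tokens in a step: at inference it is simply skipped, but during training it is still run once on a zeroed token, presumably so its parameters stay in the autograd graph (e.g., for gradient synchronization or recompute under distributed training). A minimal, self-contained sketch of that branch follows; `ExpertMLP` and `run_expert` are illustrative names, not the repository's classes, and the bias-free MLP means the zeroed input contributes exactly zero to the output buffer.

```python
# Minimal sketch of the "dummy forward" branch for experts with zero routed tokens.
# `ExpertMLP` is an illustrative bias-free expert, not the repository's expert class.
import paddle
from paddle import nn


class ExpertMLP(nn.Layer):
    def __init__(self, hidden_size=8, intermediate_size=16):
        super().__init__()
        self.up = nn.Linear(hidden_size, intermediate_size, bias_attr=False)
        self.down = nn.Linear(intermediate_size, hidden_size, bias_attr=False)

    def forward(self, x):
        return self.down(paddle.nn.functional.silu(self.up(x)))


def run_expert(expert_layer, hidden_states, final_hidden_states, idx, top_x,
               routing_weights, n_tokens, training):
    if n_tokens == 0:
        if training and paddle.is_grad_enabled():
            # Touch the expert with a zeroed token so its weights stay in the graph.
            fake_idx = paddle.zeros([1], dtype=paddle.int64)
            fake_out = expert_layer(hidden_states[fake_idx] * 0)
            final_hidden_states.index_add_(
                index=fake_idx, axis=0, value=fake_out.to(hidden_states.dtype)
            )
        return final_hidden_states
    # Normal path: run the expert on its routed tokens and scale by routing weights.
    current = hidden_states[idx].reshape([-1, hidden_states.shape[-1]])
    out = expert_layer(current) * routing_weights[idx, top_x].unsqueeze(-1)
    final_hidden_states.index_add_(
        index=idx.reshape([-1]), axis=0, value=out.to(hidden_states.dtype)
    )
    return final_hidden_states
```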