diff --git a/paddleformers/transformers/qwen2_moe/modeling.py b/paddleformers/transformers/qwen2_moe/modeling.py
index 7a45fa6f9ee..2eca14f613d 100644
--- a/paddleformers/transformers/qwen2_moe/modeling.py
+++ b/paddleformers/transformers/qwen2_moe/modeling.py
@@ -15,6 +15,7 @@
 """Paddle Qwen2Moe model."""
 from __future__ import annotations
 
+import copy
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
@@ -22,7 +23,7 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed.fleet.recompute.recompute import recompute
-from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
@@ -228,6 +229,10 @@ def __init__(self, config):
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel and config.tensor_parallel_degree > 1:
+            config = copy.deepcopy(config)
+            config.sequence_parallel = False
 
         # gating
         self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -240,14 +245,12 @@ def __init__(self, config):
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
-        if self.config.sequence_parallel:
-            max_sequence_length = self.config.max_sequence_length
-            batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
-            sequence_length = max_sequence_length
-            hidden_dim = hidden_states.shape[1]
-        else:
-            batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        if self.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+
+        hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
@@ -258,7 +261,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
-        final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
+        final_hidden_states = paddle.zeros(
+            (hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
+        )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
@@ -273,19 +278,29 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             if tokens_per_expert[expert_idx] <= 0.1:
-                continue
-            current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
-            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
-            final_hidden_states.index_add_(
-                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
-            )
+                if self.training and paddle.is_grad_enabled():
+                    fake_top_x = paddle.zeros(1, dtype=paddle.int64)
+                    fake_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
+                    fake_state = expert_layer(fake_current_state * 0)
+                    final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
+                else:
+                    continue
+            else:
+                current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
+                current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
+                final_hidden_states.index_add_(
+                    index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+                )
+
+        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
 
-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        shared_expert_output = self.shared_expert(residuals)
+        shared_expert_output = F.sigmoid(self.shared_expert_gate(residuals)) * shared_expert_output
 
         final_hidden_states = final_hidden_states + shared_expert_output
 
-        final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
+        if self.sequence_parallel:
+            final_hidden_states = ScatterOp.apply(final_hidden_states)
+
         return final_hidden_states, router_logits
diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py
index 01b4ad3a9cb..a49c15affdf 100644
--- a/paddleformers/transformers/qwen3_moe/modeling.py
+++ b/paddleformers/transformers/qwen3_moe/modeling.py
@@ -15,6 +15,7 @@
 """Paddle Qwen3Moe model."""
 from __future__ import annotations
 
+import copy
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
@@ -22,7 +23,7 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed.fleet.utils import recompute
-from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
@@ -239,6 +240,10 @@ def __init__(self, config):
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel and config.tensor_parallel_degree > 1:
+            config = copy.deepcopy(config)
+            config.sequence_parallel = False
 
         # gating
         self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -248,15 +253,11 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
+        if self.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+        orig_shape = hidden_states.shape
 
-        if self.config.sequence_parallel:
-            max_sequence_length = self.config.max_sequence_length
-            batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
-            sequence_length = max_sequence_length
-            hidden_dim = hidden_states.shape[1]
-        else:
-            batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view([-1, hidden_dim])
+        hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
 
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
@@ -268,7 +269,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
-        final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
+        final_hidden_states = paddle.zeros(
+            (hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
+        )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
@@ -283,14 +286,25 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            if tokens_per_expert[expert_idx] <= 0.1:
-                continue
-            current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
-            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
-            final_hidden_states.index_add_(
-                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
-            )
+                if self.training and paddle.is_grad_enabled():
+                    fake_top_x = paddle.zeros(1, dtype=paddle.int64)
+                    fake_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
+                    fake_state = expert_layer(fake_current_state * 0)
+                    final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
+                else:
+                    continue
+            else:
+                current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
+                current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
+                final_hidden_states.index_add_(
+                    index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+                )
+
+        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
+
+        if self.sequence_parallel:
+            final_hidden_states = ScatterOp.apply(final_hidden_states)
 
-        final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
         return final_hidden_states, router_logits
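Both files apply the same two changes to the MoE block. First, under sequence parallelism the activations arrive split along the sequence dimension, so the block now gathers them with `GatherOp` before routing, runs the experts on the full gathered sequence, and scatters the result back with `ScatterOp` afterwards; the expert sublayers are built from a `copy.deepcopy` of the config with `sequence_parallel` turned off, since they already see the gathered tensor. Second, an expert that receives no tokens in a training step is still run once on a zeroed dummy token, so its parameters take part in the backward pass (with zero gradients) instead of being skipped entirely. The snippet below is a minimal, single-device sketch of that dispatch loop; `ToyMoE`, its bias-free `nn.Linear` experts, and all sizes are illustrative stand-ins rather than the PaddleFormers modules, and it omits the shared expert, auxiliary loss, and all parallelism.

```python
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class ToyMoE(nn.Layer):
    """Top-k MoE dispatch with a dummy-token fallback for idle experts."""

    def __init__(self, hidden_size=8, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias_attr=False)
        # Bias-free experts: a zeroed input yields an exactly-zero output,
        # so the dummy pass cannot change the numerical result.
        self.experts = nn.LayerList(
            [nn.Linear(hidden_size, hidden_size, bias_attr=False) for _ in range(num_experts)]
        )

    def forward(self, hidden_states):
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.reshape([-1, orig_shape[-1]])

        # Route every token to its top-k experts.
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, axis=-1)
        routing_weights, selected_experts = paddle.topk(routing_weights, self.top_k, axis=-1)

        # expert_mask[e, k, t] == 1 iff token t picked expert e in top-k slot k.
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0])
        tokens_per_expert = expert_mask.sum(axis=[1, 2])

        final_hidden_states = paddle.zeros_like(hidden_states)
        for expert_idx, expert_layer in enumerate(self.experts):
            if tokens_per_expert[expert_idx] <= 0.1:
                if self.training and paddle.is_grad_enabled():
                    # Idle expert: push one zeroed token through so its weights
                    # receive (zero) gradients and stay in the autograd graph.
                    fake_idx = paddle.zeros([1], dtype="int64")
                    fake_out = expert_layer(hidden_states[fake_idx] * 0)
                    final_hidden_states.index_add_(index=fake_idx, axis=0, value=fake_out)
                continue

            # (slot, token) coordinates of the assignments handled by this expert.
            coords = paddle.nonzero(expert_mask[expert_idx])
            slot_idx, token_idx = coords[:, 0], coords[:, 1]
            expert_out = expert_layer(hidden_states[token_idx])
            expert_out = expert_out * routing_weights[token_idx, slot_idx].unsqueeze(-1)
            final_hidden_states.index_add_(index=token_idx, axis=0, value=expert_out)

        return final_hidden_states.reshape(orig_shape), router_logits


# Toy usage: [batch, seq, hidden] in -> same shape out, plus per-token router logits.
paddle.seed(0)
moe = ToyMoE()
out, logits = moe(paddle.randn([2, 5, 8]))
print(out.shape, logits.shape)  # [2, 5, 8] and [10, 4]
```

As in the diff, the dummy pass is gated on `self.training and paddle.is_grad_enabled()`: at inference there is no backward pass to keep consistent, so idle experts are simply skipped.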