From 99010beebe310de8bce31fb0402254640bcb7710 Mon Sep 17 00:00:00 2001
From: Ace-To-HYB
Date: Thu, 16 Oct 2025 16:16:44 +0800
Subject: [PATCH 1/2] Fix SP issue in qwenmoe

---
 .../transformers/qwen2_moe/modeling.py        | 52 ++++++++++++-------
 .../transformers/qwen3_moe/modeling.py        | 46 ++++++++++------
 2 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/paddleformers/transformers/qwen2_moe/modeling.py b/paddleformers/transformers/qwen2_moe/modeling.py
index 7a45fa6f9ee..40dc0684b49 100644
--- a/paddleformers/transformers/qwen2_moe/modeling.py
+++ b/paddleformers/transformers/qwen2_moe/modeling.py
@@ -15,6 +15,7 @@
 """Paddle Qwen2Moe model."""
 from __future__ import annotations
 
+import copy
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
@@ -22,7 +23,7 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed.fleet.recompute.recompute import recompute
-from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
@@ -228,6 +229,10 @@ def __init__(self, config):
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel and config.tensor_parallel_degree > 1:
+            config = copy.deepcopy(config)
+            config.sequence_parallel = False
 
         # gating
         self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -240,14 +245,12 @@ def __init__(self, config):
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
-        if self.config.sequence_parallel:
-            max_sequence_length = self.config.max_sequence_length
-            batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
-            sequence_length = max_sequence_length
-            hidden_dim = hidden_states.shape[1]
-        else:
-            batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
+        if self.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
+
+        hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
 
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
@@ -258,7 +261,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
-        final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
+        final_hidden_states = paddle.zeros(
+            (hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
+        )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
@@ -273,19 +278,28 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             if tokens_per_expert[expert_idx] <= 0.1:
-                continue
-            current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
-            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
-            final_hidden_states.index_add_(
-                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
-            )
+                if self.training and paddle.is_grad_enabled():
+                    fake_top_x = paddle.zeros(1, dtype=paddle.int64)
+                    fakse_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
+                    fake_state = expert_layer(fakse_current_state * 0)
+                    final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
+                else:
+                    continue
+            else:
+                current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
+                current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
+                final_hidden_states.index_add_(
+                    index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+                )
+        if self.sequence_parallel:
+            final_hidden_states = ScatterOp.apply(final_hidden_states)
+        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
 
-        shared_expert_output = self.shared_expert(hidden_states)
-        shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        shared_expert_output = self.shared_expert(residuals)
+        shared_expert_output = F.sigmoid(self.shared_expert_gate(residuals)) * shared_expert_output
 
         final_hidden_states = final_hidden_states + shared_expert_output
 
-        final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
-
         return final_hidden_states, router_logits
 
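Note on the expert loop above: during training, an expert that receives no tokens is no longer simply skipped with `continue`; the patch pushes a single zeroed token through it so that every expert's parameters stay connected to the autograd graph, and the value added back is zero as long as the expert MLP has no bias terms. A minimal sketch of that pattern follows; it is illustrative only, `dispatch_one_expert` is a hypothetical helper, and the diff's `tokens_per_expert[expert_idx] <= 0.1` test is replaced here by an explicit emptiness check.

```python
# Sketch (not part of the patch) of the per-expert dispatch used above.
# `expert_layer` is assumed to be a bias-free MLP mapping [n, hidden] -> [n, hidden],
# so feeding it a zeroed token contributes exactly zero to `final_hidden_states`.
import paddle


def dispatch_one_expert(expert_layer, hidden_states, idx, top_x, routing_weights, final_hidden_states, training):
    if idx.shape[0] == 0:  # no tokens routed to this expert
        if training and paddle.is_grad_enabled():
            # Run the expert on one zeroed token so its parameters remain part of
            # the backward graph even when the router assigns it nothing.
            fake_top_x = paddle.zeros([1], dtype="int64")
            fake_state = expert_layer(hidden_states[fake_top_x] * 0)
            final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
        return
    # Normal path: run the expert on its tokens and scale by the routing weights.
    current_state = hidden_states[idx].reshape([-1, hidden_states.shape[-1]])
    current_out = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
    final_hidden_states.index_add_(index=idx.reshape([-1]), axis=0, value=current_out.to(hidden_states.dtype))
```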
diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py
index 01b4ad3a9cb..d7ffbafec31 100644
--- a/paddleformers/transformers/qwen3_moe/modeling.py
+++ b/paddleformers/transformers/qwen3_moe/modeling.py
@@ -15,6 +15,7 @@
 """Paddle Qwen3Moe model."""
 from __future__ import annotations
 
+import copy
 from functools import partial
 from typing import List, Optional, Tuple, Union
 
@@ -22,7 +23,7 @@
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed.fleet.utils import recompute
-from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp
+from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp
 
 from ...nn.attention.interface import ALL_ATTENTION_FUNCTIONS
 from ...nn.criterion.interface import CriterionLayer
@@ -239,6 +240,10 @@ def __init__(self, config):
         self.num_experts = config.num_experts
         self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel and config.tensor_parallel_degree > 1:
+            config = copy.deepcopy(config)
+            config.sequence_parallel = False
 
         # gating
         self.gate = GeneralLinear.create(config.hidden_size, config.num_experts, has_bias=False, linear_type="default")
@@ -248,15 +253,11 @@ def __init__(self, config):
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
+        orig_shape = hidden_states.shape
+        if self.sequence_parallel:
+            hidden_states = GatherOp.apply(hidden_states)
 
-        if self.config.sequence_parallel:
-            max_sequence_length = self.config.max_sequence_length
-            batch_size = hidden_states.shape[0] * self.config.tensor_parallel_degree // max_sequence_length
-            sequence_length = max_sequence_length
-            hidden_dim = hidden_states.shape[1]
-        else:
-            batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view([-1, hidden_dim])
+        hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
 
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
@@ -268,7 +269,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
 
-        final_hidden_states = paddle.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype)
+        final_hidden_states = paddle.zeros(
+            (hidden_states.shape[-2], hidden_states.shape[-1]), dtype=hidden_states.dtype
+        )
 
         # One hot encode the selected experts to create an expert mask
         # this will be used to easily index which expert is going to be sollicitated
@@ -283,14 +286,23 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             if tokens_per_expert[expert_idx] <= 0.1:
-                continue
-            current_state = hidden_states[idx, None].reshape([-1, hidden_dim])
-            current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
-            final_hidden_states.index_add_(
-                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
-            )
+                if self.training and paddle.is_grad_enabled():
+                    fake_top_x = paddle.zeros(1, dtype=paddle.int64)
+                    fakse_current_state = hidden_states[fake_top_x, None].reshape([-1, hidden_states.shape[-1]])
+                    fake_state = expert_layer(fakse_current_state * 0)
+                    final_hidden_states.index_add_(index=fake_top_x, axis=0, value=fake_state.to(hidden_states.dtype))
+                else:
+                    continue
+            else:
+                current_state = hidden_states[idx, None].reshape([-1, hidden_states.shape[-1]])
+                current_hidden_states = expert_layer(current_state) * routing_weights[idx, top_x].unsqueeze(-1)
+                final_hidden_states.index_add_(
+                    index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+                )
 
-        final_hidden_states = final_hidden_states.reshape([batch_size, sequence_length, hidden_dim])
+        if self.sequence_parallel:
+            final_hidden_states = ScatterOp.apply(final_hidden_states)
+        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
         return final_hidden_states, router_logits
 
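Both models get the same treatment in this first patch. In `__init__`, the MoE block records `config.sequence_parallel` and hands its sub-layers a deep copy with the flag cleared, presumably so the gate, experts, and shared expert are built as ordinary tensor-parallel layers while the MoE block performs a single gather/scatter itself. In `forward`, the old shape arithmetic, which apparently assumes every batch is padded to `config.max_sequence_length`, is dropped in favour of gathering the full token dimension before routing and scattering the combined expert output afterwards. The sketch below shows only the shape contract this relies on; it is not library code, `moe_core` is a hypothetical stand-in for the gate plus expert loop, and it assumes `GatherOp`/`ScatterOp` concatenate and split along the first (token) axis across the tensor-parallel group, as in Paddle's `sequence_parallel_utils`.

```python
# Illustrative shape contract only; `moe_core` is a hypothetical placeholder.
import paddle
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp


def moe_with_sequence_parallel(moe_core, hidden_states, sequence_parallel):
    # Under sequence parallelism each rank holds a token slice, roughly
    # [seq_len // tp_degree * batch, hidden].
    if sequence_parallel:
        hidden_states = GatherOp.apply(hidden_states)  # local token slice -> all tokens
    out = moe_core(hidden_states.view([-1, hidden_states.shape[-1]]))  # route over every token
    if sequence_parallel:
        out = ScatterOp.apply(out)  # all tokens -> local token slice
    return out
```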
From 02f748a8abfe79ced5ce6358fc1d4d3acc0aa74f Mon Sep 17 00:00:00 2001
From: Ace-To-HYB
Date: Mon, 20 Oct 2025 17:28:45 +0800
Subject: [PATCH 2/2] fix sp bug

---
 paddleformers/transformers/qwen2_moe/modeling.py | 9 +++++----
 paddleformers/transformers/qwen3_moe/modeling.py | 6 ++++--
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/paddleformers/transformers/qwen2_moe/modeling.py b/paddleformers/transformers/qwen2_moe/modeling.py
index 40dc0684b49..2eca14f613d 100644
--- a/paddleformers/transformers/qwen2_moe/modeling.py
+++ b/paddleformers/transformers/qwen2_moe/modeling.py
@@ -245,10 +245,10 @@ def __init__(self, config):
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
-        residuals = hidden_states
-        orig_shape = hidden_states.shape
         if self.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
+        residuals = hidden_states
+        orig_shape = hidden_states.shape
 
         hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
 
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
@@ -291,8 +291,6 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
                 final_hidden_states.index_add_(
                     index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
                 )
-        if self.sequence_parallel:
-            final_hidden_states = ScatterOp.apply(final_hidden_states)
         final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
 
         shared_expert_output = self.shared_expert(residuals)
@@ -300,6 +298,9 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
 
         final_hidden_states = final_hidden_states + shared_expert_output
 
+        if self.sequence_parallel:
+            final_hidden_states = ScatterOp.apply(final_hidden_states)
+
         return final_hidden_states, router_logits
 
diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py
index d7ffbafec31..a49c15affdf 100644
--- a/paddleformers/transformers/qwen3_moe/modeling.py
+++ b/paddleformers/transformers/qwen3_moe/modeling.py
@@ -253,9 +253,9 @@ def __init__(self, config):
 
     def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
         """ """
-        orig_shape = hidden_states.shape
         if self.sequence_parallel:
             hidden_states = GatherOp.apply(hidden_states)
+        orig_shape = hidden_states.shape
 
         hidden_states = hidden_states.view([-1, hidden_states.shape[-1]])
 
         # router_logits: (batch * sequence_length, n_experts)
@@ -300,9 +300,11 @@ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
                     index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
                 )
 
+        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
+
         if self.sequence_parallel:
             final_hidden_states = ScatterOp.apply(final_hidden_states)
-        final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)
+
         return final_hidden_states, router_logits
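The follow-up commit only reorders the sequence-parallel plumbing introduced in the first patch: the gather now happens before `residuals` and `orig_shape` are captured, so the shared expert and the final reshape operate on the gathered tokens, and the scatter moves to the very end, after the routed output has been combined with the shared-expert output (Qwen2Moe) or reshaped (Qwen3Moe). A condensed sketch of the resulting Qwen2Moe flow follows; `run_experts` is a hypothetical placeholder for the gate and expert loop shown in the diffs, and the function is illustrative rather than the literal class method.

```python
import paddle
import paddle.nn.functional as F
from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp


def qwen2_moe_forward_after_fix(block, hidden_states, run_experts):
    # run_experts(block, tokens) -> (combined expert output, router logits); placeholder only.
    if block.sequence_parallel:
        hidden_states = GatherOp.apply(hidden_states)  # gather first ...
    residuals = hidden_states                          # ... so the shared expert sees all tokens
    orig_shape = hidden_states.shape

    tokens = hidden_states.view([-1, hidden_states.shape[-1]])
    final_hidden_states, router_logits = run_experts(block, tokens)
    final_hidden_states = paddle.reshape(final_hidden_states, orig_shape)

    shared = block.shared_expert(residuals)
    shared = F.sigmoid(block.shared_expert_gate(residuals)) * shared
    final_hidden_states = final_hidden_states + shared

    if block.sequence_parallel:
        final_hidden_states = ScatterOp.apply(final_hidden_states)  # scatter last
    return final_hidden_states, router_logits
```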