diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 967fb222b2e3..3b2e4421ba23 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -830,11 +830,11 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
             if self.q_lora_rank is None:
                 self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
             else:
-                self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
+                self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
                 self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False)
                 self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
 
-            self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
+            self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
             self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False)
             self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=False)
 
@@ -846,17 +846,17 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
         else:
             # for without tensor parallel
             if self.q_lora_rank is None:
-                self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
+                self.q_proj = linear_utils.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
             else:
-                self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
+                self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
                 self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank)
-                self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)
+                self.q_b_proj = linear_utils.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)
 
-            self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
+            self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
             self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)
-            self.kv_b_proj = nn.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
+            self.kv_b_proj = linear_utils.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
 
-        self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+        self.o_proj = linear_utils.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
         # fmt: on
 
         self._init_rope()
@@ -1293,6 +1293,7 @@ def _init_weights(self, layer):
                 mpu.VocabParallelEmbedding,
                 mpu.RowParallelLinear,
                 mpu.ColumnParallelLinear,
+                linear_utils.Linear,
                 linear_utils.RowSequenceParallelLinear,
                 linear_utils.ColumnSequenceParallelLinear,
             ),
@@ -1888,7 +1889,7 @@ def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = DeepseekV2Model(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias_attr=False)
+        self.score = linear_utils.Linear(config.hidden_size, self.num_labels, bias_attr=False)
 
         # Initialize weights and apply final processing
         self.post_init()
diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py
index 90d4feae6c72..09bed4f29756 100644
--- a/paddlenlp/transformers/moe_layer.py
+++ b/paddlenlp/transformers/moe_layer.py
@@ -23,6 +23,8 @@
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.group import Group
 
+from paddlenlp.utils.tools import get_env_device
+
 from .moe_gate import PretrainedMoEGate
 
 
@@ -219,11 +221,7 @@ def expert_forward(self, dispatched_input):
         expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]
         return expert_output
 
-    def forward(
-        self,
-        hidden_state: paddle.Tensor,
-        used_token: paddle.Tensor = None,
-    ):
+    def forward(self, hidden_state: paddle.Tensor, used_token: paddle.Tensor = None, is_train=False):
         """_summary_
 
         Args:
@@ -247,7 +245,21 @@ def forward(
         # combine_weights : sec
         # dispatch_mask : sec
         # self.exp_counts :
-        dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
+
+        if get_env_device() == "xpu":
+            dispatch_mask = paddle.cast(dispatch_mask, hidden_state.dtype)
+            from paddle_xpu.layers.nn import xpu_matmul
+
+            dispatched_input = xpu_matmul()(
+                dispatch_mask.reshape([dispatch_mask.shape[0], -1]),
+                reshaped_input,
+                transpose_x=True,
+                training=is_train,
+            )
+        else:
+            dispatched_input = paddle.einsum(
+                "sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input
+            )
 
         if self.expert_parallel_degree > 1:
             dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
@@ -266,7 +278,16 @@ def forward(
             expert_output = _AllToAll.apply(expert_output, self.moe_group)
 
         # combine withe expert weights
-        combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
+        if get_env_device() == "xpu":
+            from paddle_xpu.layers.nn import xpu_matmul
+
+            combined_output = xpu_matmul()(
+                combine_weights.reshape([combine_weights.shape[0], -1]).cast(hidden_state[0].dtype),
+                expert_output.reshape([-1, expert_output.shape[-1]]),
+                training=is_train,
+            )
+        else:
+            combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
 
         a = combined_output.reshape(hidden_state.shape)
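
Reviewer note (not part of the patch): the XPU branch replaces the two einsum contractions with reshaped matmuls. Below is a minimal equivalence sketch, using plain paddle.matmul as a stand-in for the XPU-only xpu_matmul and made-up shapes; note the patched xpu_matmul path returns a 2-D [e*c, m] dispatch tensor rather than [e, c, m], so the reshape in the check is only there to compare against the einsum result.

# Standalone sanity check with hypothetical shapes; paddle.matmul stands in for xpu_matmul.
import paddle

s, e, c, m = 8, 4, 2, 16  # tokens, experts, capacity, hidden dim
dispatch_mask = paddle.rand([s, e, c])
combine_weights = paddle.rand([s, e, c])
reshaped_input = paddle.rand([s, m])
expert_output = paddle.rand([e, c, m])

# dispatch: "sec,sm->ecm" == mask reshaped to [s, e*c], transposed matmul, viewed as [e, c, m]
ref = paddle.einsum("sec,sm->ecm", dispatch_mask, reshaped_input)
alt = paddle.matmul(dispatch_mask.reshape([s, e * c]), reshaped_input, transpose_x=True)
assert paddle.allclose(ref, alt.reshape([e, c, m]), atol=1e-5).item()

# combine: "sec,ecm->sm" == [s, e*c] @ [e*c, m]
ref = paddle.einsum("sec,ecm->sm", combine_weights, expert_output)
alt = paddle.matmul(combine_weights.reshape([s, e * c]), expert_output.reshape([e * c, m]))
assert paddle.allclose(ref, alt, atol=1e-5).item()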