Skip to content

Commit

Permalink
opt xpu perf for deepseek
Browse files Browse the repository at this point in the history
  • Loading branch information
QingshuChen committed Feb 20, 2025
1 parent 8f4e0f0 commit e53d6b0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
19 changes: 10 additions & 9 deletions paddlenlp/transformers/deepseek_v2/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,11 +830,11 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
if self.q_lora_rank is None:
self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
else:
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False)
self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)

self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False)
self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=False)

Expand All @@ -846,17 +846,17 @@ def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False):
else:
# for without tensor parallel
if self.q_lora_rank is None:
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
self.q_proj = linear_utils.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
else:
self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank)
self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)
self.q_b_proj = linear_utils.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)

self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)
self.kv_b_proj = nn.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
self.kv_b_proj = linear_utils.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)

self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
self.o_proj = linear_utils.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
# fmt: on

self._init_rope()
Expand Down Expand Up @@ -1293,6 +1293,7 @@ def _init_weights(self, layer):
mpu.VocabParallelEmbedding,
mpu.RowParallelLinear,
mpu.ColumnParallelLinear,
linear_utils.Linear,
linear_utils.RowSequenceParallelLinear,
linear_utils.ColumnSequenceParallelLinear,
),
Expand Down Expand Up @@ -1888,7 +1889,7 @@ def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = DeepseekV2Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias_attr=False)
self.score = linear_utils.Linear(config.hidden_size, self.num_labels, bias_attr=False)

# Initialize weights and apply final processing
self.post_init()
Expand Down
20 changes: 18 additions & 2 deletions paddlenlp/transformers/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from paddle.distributed.communication import stream
from paddle.distributed.communication.group import Group

from paddlenlp.utils.tools import get_env_device
from .moe_gate import PretrainedMoEGate


Expand Down Expand Up @@ -174,6 +175,11 @@ def __init__(
self.all_to_all_dropout = all_to_all_dropout
self.enable_recompute = False

if get_env_device() == "xpu":
from paddle_xpu.layers.nn import xpu_matmul
self.xpu_matmul1 = xpu_matmul()
self.xpu_matmul2 = xpu_matmul()

self.experts = nn.LayerList([])
for i in range(self.moe_num_experts):
if i // self.moe_num_experts_per_device == self.moe_rank:
Expand Down Expand Up @@ -223,6 +229,7 @@ def forward(
self,
hidden_state: paddle.Tensor,
used_token: paddle.Tensor = None,
is_train=False
):
"""_summary_
Expand All @@ -247,7 +254,13 @@ def forward(
# combine_weights : sec
# dispatch_mask : sec
# self.exp_counts :
dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)

if get_env_device() == "xpu":
dispatch_mask = paddle.cast(dispatch_mask, hidden_state.dtype)
dispatched_input = self.xpu_matmul1(dispatch_mask.reshape([dispatch_mask.shape[0], -1]), reshaped_input, transpose_x=True,
training=is_train)
else:
dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)

if self.expert_parallel_degree > 1:
dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
Expand All @@ -266,7 +279,10 @@ def forward(
expert_output = _AllToAll.apply(expert_output, self.moe_group)

# combine withe expert weights
combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
if get_env_device() == "xpu":
combined_output = self.xpu_matmul2(combine_weights.reshape([combine_weights.shape[0], -1]).cast(hidden_state[0].dtype), expert_output.reshape([-1, expert_output.shape[-1]]), training=is_train)
else:
combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)

a = combined_output.reshape(hidden_state.shape)

Expand Down

0 comments on commit e53d6b0

Please sign in to comment.