opt xpu perf for deepseek #9916

Open · wants to merge 3 commits into develop
19 changes: 10 additions & 9 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -830,11 +830,11 @@
             if self.q_lora_rank is None:
                 self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)
             else:
-                self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
+                self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
                 self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False)
                 self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=False)

-            self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
+            self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
             self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False)
             self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=False)

@@ -846,17 +846,17 @@
         else:
             # for without tensor parallel
             if self.q_lora_rank is None:
-                self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
+                self.q_proj = linear_utils.Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False)
             else:
-                self.q_a_proj = nn.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
+                self.q_a_proj = linear_utils.Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias)
                 self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank)
-                self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)
+                self.q_b_proj = linear_utils.Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False)

-            self.kv_a_proj_with_mqa = nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
+            self.kv_a_proj_with_mqa = linear_utils.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias)
             self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)
-            self.kv_b_proj = nn.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
+            self.kv_b_proj = linear_utils.Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)

-            self.o_proj = nn.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
+            self.o_proj = linear_utils.Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
         # fmt: on

         self._init_rope()
@@ -1293,6 +1293,7 @@
                 mpu.VocabParallelEmbedding,
                 mpu.RowParallelLinear,
                 mpu.ColumnParallelLinear,
+                linear_utils.Linear,
                 linear_utils.RowSequenceParallelLinear,
                 linear_utils.ColumnSequenceParallelLinear,
             ),
@@ -1888,7 +1889,7 @@
         super().__init__(config)
         self.num_labels = config.num_labels
         self.model = DeepseekV2Model(config)
-        self.score = nn.Linear(config.hidden_size, self.num_labels, bias_attr=False)
+        self.score = linear_utils.Linear(config.hidden_size, self.num_labels, bias_attr=False)

         # Initialize weights and apply final processing
         self.post_init()
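All of the modeling.py changes follow one pattern: the attention projections and the classification head switch from nn.Linear to linear_utils.Linear, and linear_utils.Linear is added to the tuple of layer classes handled by the tensor-parallel mapping code. The sketch below shows the kind of device dispatch such an alias is assumed to provide; the paddle_xpu Linear import is an illustrative assumption, not a claim about the actual contents of linear_utils.

# Minimal sketch of a device-dispatching Linear alias (assumed behavior).
# get_env_device and the "xpu" check come from this PR; the paddle_xpu Linear
# import is hypothetical and guarded in case the package is absent.
import paddle.nn as nn
from paddlenlp.utils.tools import get_env_device

if get_env_device() == "xpu":
    try:
        from paddle_xpu.layers.nn import Linear  # assumed XPU-optimized Linear
    except ImportError:
        Linear = nn.Linear  # fall back to the stock layer without paddle_xpu
else:
    Linear = nn.Linear  # non-XPU devices keep the standard implementation

# Model code can then build projections through the alias, e.g.
# self.q_a_proj = Linear(hidden_size, q_lora_rank, bias_attr=False)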
35 changes: 28 additions & 7 deletions paddlenlp/transformers/moe_layer.py
@@ -23,6 +23,8 @@
 from paddle.distributed.communication import stream
 from paddle.distributed.communication.group import Group

+from paddlenlp.utils.tools import get_env_device
+
 from .moe_gate import PretrainedMoEGate

@@ -219,11 +221,7 @@
         expert_output = paddle.stack(expert_outputs, axis=1)  # [ecm]
         return expert_output

-    def forward(
-        self,
-        hidden_state: paddle.Tensor,
-        used_token: paddle.Tensor = None,
-    ):
+    def forward(self, hidden_state: paddle.Tensor, used_token: paddle.Tensor = None, is_train=False):
         """_summary_

         Args:
@@ -247,7 +245,21 @@
         # combine_weights : sec
         # dispatch_mask : sec
         # self.exp_counts :
-        dispatched_input = paddle.einsum("sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input)
+
+        if get_env_device() == "xpu":
+            dispatch_mask = paddle.cast(dispatch_mask, hidden_state.dtype)
+            from paddle_xpu.layers.nn import xpu_matmul
+
+            dispatched_input = xpu_matmul()(
+                dispatch_mask.reshape([dispatch_mask.shape[0], -1]),
+                reshaped_input,
+                transpose_x=True,
+                training=is_train,
+            )
+        else:
+            dispatched_input = paddle.einsum(
+                "sec,sm->ecm", paddle.cast(dispatch_mask, hidden_state.dtype), reshaped_input
+            )

         if self.expert_parallel_degree > 1:
             dispatched_input = _AllToAll.apply(dispatched_input, self.moe_group)
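The XPU branch above replaces the dispatch einsum with a reshape plus a transposed matmul through xpu_matmul. Below is a self-contained check of the equivalence using only stock paddle ops (shapes are arbitrary illustration values); note that the PR keeps the result as [e*c, m], and the final reshape here is only so it can be compared against the einsum output.

import paddle

s, e, c, m = 8, 4, 2, 16                                   # tokens, experts, capacity, hidden dim
mask = paddle.randint(0, 2, [s, e, c]).astype("float32")   # "sec" dispatch mask
x = paddle.randn([s, m])                                    # "sm" token features

# einsum path taken on non-XPU devices
ref = paddle.einsum("sec,sm->ecm", mask, x)

# matmul path mirroring the XPU branch: fold (e, c) into one axis and contract
# over s with a transposed matmul, then restore the expert/capacity axes.
alt = paddle.matmul(mask.reshape([s, e * c]), x, transpose_x=True).reshape([e, c, m])

print(paddle.allclose(ref, alt))  # True up to floating-point tolerance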
@@ -266,7 +278,16 @@
             expert_output = _AllToAll.apply(expert_output, self.moe_group)

         # combine withe expert weights
-        combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)
+        if get_env_device() == "xpu":
+            from paddle_xpu.layers.nn import xpu_matmul
+
+            combined_output = xpu_matmul()(
+                combine_weights.reshape([combine_weights.shape[0], -1]).cast(hidden_state[0].dtype),
+                expert_output.reshape([-1, expert_output.shape[-1]]),
+                training=is_train,
+            )
+        else:
+            combined_output = paddle.einsum("sec,ecm->sm", combine_weights.cast(hidden_state[0].dtype), expert_output)

         a = combined_output.reshape(hidden_state.shape)

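The combine step is the same trick in reverse: the "sec,ecm->sm" einsum becomes a single matmul once the expert and capacity axes of both operands are flattened, which is what the reshapes before xpu_matmul do in the diff. A small check with stock paddle ops (shapes are arbitrary illustration values):

import paddle

s, e, c, m = 8, 4, 2, 16
w = paddle.rand([s, e, c])      # "sec" combine weights
out = paddle.randn([e, c, m])   # "ecm" expert outputs

ref = paddle.einsum("sec,ecm->sm", w, out)

# Flatten (e, c) on both operands and contract them with one matmul, mirroring
# the reshapes the XPU branch performs before calling xpu_matmul.
alt = paddle.matmul(w.reshape([s, e * c]), out.reshape([e * c, m]))

print(paddle.allclose(ref, alt))  # True up to floating-point tolerance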