Refine deepseekv2 modeling for to_static #9851

Open

wants to merge 25 commits into base: develop

Changes from 20 commits

Commits (25)
2f9956a
refine log
zhangbo9674 Nov 8, 2024
ee5f151
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
zhangbo9674 Nov 29, 2024
377962a
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
zhangbo9674 Nov 29, 2024
3aa70a8
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
zhangbo9674 Dec 6, 2024
3caaac7
Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleNLP i…
zhangbo9674 Feb 11, 2025
796bff7
fix
zhangbo9674 Feb 12, 2025
87c33ac
add model args
zhangbo9674 Feb 12, 2025
29b05b0
suppoort_deepseekv2_autoparallel_with_DP/MP
xuxinyi389 Feb 13, 2025
20be84b
poolish
xuxinyi389 Feb 13, 2025
31ec76c
remove_env_set
xuxinyi389 Feb 13, 2025
34009e7
update_code
xuxinyi389 Feb 13, 2025
e553a3a
add_v3
xuxinyi389 Feb 14, 2025
0c96a56
support_sharding
xuxinyi389 Feb 14, 2025
0aa9fb0
move_to_v3
xuxinyi389 Feb 14, 2025
3e84fc6
fix_typo
xuxinyi389 Feb 14, 2025
495d123
update_v3_config
xuxinyi389 Feb 14, 2025
14470cf
Merge commit 'refs/pull/9862/head' of https://github.com/PaddlePaddle…
zhangbo9674 Feb 17, 2025
5835f1e
refine
zhangbo9674 Feb 17, 2025
909ffe8
refine
zhangbo9674 Feb 17, 2025
93cfe79
refine
zhangbo9674 Feb 17, 2025
c48d6c2
fix
zhangbo9674 Feb 18, 2025
7f7a486
fix
zhangbo9674 Feb 19, 2025
428bd09
fix
zhangbo9674 Feb 19, 2025
41b9107
fix
zhangbo9674 Feb 24, 2025
4b83dee
fix
zhangbo9674 Feb 24, 2025
37 changes: 37 additions & 0 deletions llm/run_pretrain.py
@@ -179,6 +179,26 @@ class ModelArguments:
default=None,
metadata={"help": "num_hidden_layers."},
)
first_k_dense_replace: Optional[int] = field(
default=None,
metadata={"help": "first_k_dense_replace."},
)
n_routed_experts: Optional[int] = field(
default=None,
metadata={"help": "n_routed_experts."},
)
num_experts_per_tok: Optional[int] = field(
default=None,
metadata={"help": "num_experts_per_tok."},
)
hidden_size: Optional[int] = field(
default=None,
metadata={"help": "hidden_size."},
)
topk_group: Optional[int] = field(
default=None,
metadata={"help": "topk_group."},
)


def create_pretrained_dataset(
@@ -418,6 +438,23 @@ def main():
config.num_hidden_layers = (
model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers
)
config.num_hidden_layers = (
model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers
)
config.first_k_dense_replace = (
model_args.first_k_dense_replace
if model_args.first_k_dense_replace is not None
else config.first_k_dense_replace
)
config.n_routed_experts = (
model_args.n_routed_experts if model_args.n_routed_experts is not None else config.n_routed_experts
)
config.num_experts_per_tok = (
model_args.num_experts_per_tok if model_args.num_experts_per_tok is not None else config.num_experts_per_tok
)
config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size
config.topk_group = model_args.topk_group if model_args.topk_group is not None else config.topk_group

# Config for model using dropout, such as GPT.
if hasattr(config, "hidden_dropout_prob"):
config.hidden_dropout_prob = model_args.hidden_dropout_prob
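As a side note on how these optional fields behave, the sketch below reproduces the override pattern used in main() in isolation: a config attribute is replaced only when the corresponding argument was actually provided. DummyConfig and apply_overrides are hypothetical names used purely for illustration; the PR itself spells the overrides out field by field.

```python
from dataclasses import dataclass, field
from typing import Optional


# Hypothetical stand-ins used only for illustration; the PR writes the
# overrides out field by field inside main().
@dataclass
class ModelArguments:
    first_k_dense_replace: Optional[int] = field(default=None, metadata={"help": "first_k_dense_replace."})
    n_routed_experts: Optional[int] = field(default=None, metadata={"help": "n_routed_experts."})
    num_experts_per_tok: Optional[int] = field(default=None, metadata={"help": "num_experts_per_tok."})


class DummyConfig:
    # Values as they would come from the pretrained model's config file.
    first_k_dense_replace = 1
    n_routed_experts = 64
    num_experts_per_tok = 6


def apply_overrides(config, model_args):
    # Keep the config value unless the argument was explicitly set, i.e. the same
    # pattern as `config.x = model_args.x if model_args.x is not None else config.x`.
    for name in ("first_k_dense_replace", "n_routed_experts", "num_experts_per_tok"):
        value = getattr(model_args, name)
        if value is not None:
            setattr(config, name, value)
    return config


config = apply_overrides(DummyConfig(), ModelArguments(n_routed_experts=8))
print(config.first_k_dense_replace, config.n_routed_experts, config.num_experts_per_tok)  # 1 8 6
```

Since ModelArguments is an argument dataclass, these fields presumably surface as --first_k_dense_replace, --n_routed_experts, --num_experts_per_tok, --hidden_size and --topk_group flags on the run_pretrain.py command line, making it easy to shrink DeepSeek-V2 for to_static and auto-parallel debugging without editing the config file.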
2 changes: 1 addition & 1 deletion paddlenlp/transformers/deepseek_v2/modeling.py
@@ -543,7 +543,7 @@

t = paddle.arange(seq_len, dtype=paddle.float32)

freqs = paddle.outer(t, self.inv_freq)
freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32"))

Codecov warning: added line paddlenlp/transformers/deepseek_v2/modeling.py#L546 is not covered by tests.
_mscale = float(
yarn_get_mscale(self.scaling_factor, self.mscale)
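A minimal sketch of what this one-line cast changes, assuming the motivation is dtype consistency: self.inv_freq may be held in a lower precision (for example under mixed-precision training), and the cast pins both operands of paddle.outer to float32. The bfloat16 tensor below is an illustrative stand-in, not a value taken from the model.

```python
import paddle

# Illustrative stand-in for self.inv_freq; in the model it may not be float32.
inv_freq = paddle.to_tensor([1.0, 0.1, 0.01], dtype="bfloat16")
seq_len = 8

t = paddle.arange(seq_len, dtype=paddle.float32)

# Both operands of paddle.outer are now float32, so the rotary frequency table
# is computed in full precision with a single, consistent dtype.
freqs = paddle.outer(t, paddle.cast(inv_freq, dtype="float32"))
print(freqs.shape, freqs.dtype)  # [8, 3] paddle.float32
```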
6 changes: 1 addition & 5 deletions paddlenlp/transformers/deepseek_v2/modeling_auto.py
@@ -64,7 +64,6 @@
_make_causal_mask,
apply_rotary_pos_emb,
get_triangle_upper_mask,
is_casual_mask,
yarn_get_mscale,
)

@@ -705,7 +704,7 @@
inputs_embeds = self.embed_tokens(input_ids)

# embed positions
if attn_mask_startend_row_indices is not None or get_use_casual_mask():
if attn_mask_startend_row_indices is not None or get_use_casual_mask() or self.config.use_flash_attention:

Codecov warning: added line paddlenlp/transformers/deepseek_v2/modeling_auto.py#L707 is not covered by tests.
attention_mask = None
else:
# [bs, seq_len]
@@ -717,9 +716,6 @@
attention_mask = self._prepare_decoder_attention_mask(
Collaborator:

If the attention_mask is the variant that supports use_cache, it is not a causal mask; likewise, if left-padding is used at inference time, the attention_mask is not a causal mask either. This change does not cover those earlier cases.

Contributor Author:

This change was made mainly because dynamic-to-static conversion (to_static) does not support the following scenario:

if self.config.use_flash_attention:
    attention_mask = None if is_casual_mask(attention_mask) else attention_mask

The current change relies on the fact that, in the scenarios we target, attention_mask is always None when use_flash_attention is enabled. An alternative would be to move the control-flow check down to the call site, but auto-parallel sharding propagation still has some issues supporting control flow, so to avoid blocking the downstream work we implemented the first approach for now:

if is_casual_mask(attention_mask):
    layer_outputs = decoder_layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=None,
                output_attentions=output_attentions,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attn_mask_startend_row_indices=attn_mask_startend_row_indices)
else:
    layer_outputs = decoder_layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attn_mask_startend_row_indices=attn_mask_startend_row_indices)

attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype
) # [bs, 1, seq_len, seq_len]
if self.config.use_flash_attention:
attention_mask = None if is_casual_mask(attention_mask) else attention_mask

# embed positions
hidden_states = inputs_embeds

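Putting the modeling_auto.py change together with the review discussion above, here is a hedged sketch of the resulting mask selection. The names config, attention_mask, attn_mask_startend_row_indices and the causal-mask flag come from the diff; build_decoder_attention_mask is a hypothetical stand-in for self._prepare_decoder_attention_mask.

```python
def select_attention_mask(config, attention_mask, attn_mask_startend_row_indices,
                          use_casual_mask, build_decoder_attention_mask):
    # After this PR, use_flash_attention joins the cases that drop the mask up
    # front, replacing the later data-dependent check
    #     attention_mask = None if is_casual_mask(attention_mask) else attention_mask
    # which, per the author, to_static cannot handle in this code path.
    if attn_mask_startend_row_indices is not None or use_casual_mask or config.use_flash_attention:
        return None
    # Otherwise build the dense [bs, 1, seq_len, seq_len] mask as before.
    return build_decoder_attention_mask(attention_mask)
```

This shortcut is only safe as long as flash-attention runs never need a non-causal mask (for example use_cache with left-padding, as the reviewer points out), which the author states holds for the training scenarios this PR targets.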