
add finetune fused & add mc2 #8139

Merged: 1 commit merged into PaddlePaddle:develop on Apr 3, 2024

Conversation

NINGBENZHE (Contributor)

No description provided.

paddle-bot bot commented Mar 18, 2024

Thanks for your contribution!

Contributor

Should the MC2 code go into a separate file of its own?

Contributor Author

These two MC2 ops are part of sequence parallel (sp), so leave them where they are for now?

@@ -228,6 +228,11 @@ def scaled_dot_product_attention(
            alibi = alibi.reshape([bsz, num_heads, 1, -1])
            attention_mask = attention_mask.cast(alibi.dtype) + alibi
        if get_env_device() == "npu":
            if attention_mask is not None:
                attention_mask = attention_mask.astype("bool")
Contributor

Consider checking attn_mask once outside and casting only when it is not already bool; that improves performance. Rather than casting inside flash_attn, check and cast once before the mask is passed into the Transformer?

Contributor Author

Resolved.
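
A minimal sketch of the suggested change (the helper name prepare_attention_mask is hypothetical, not from this PR): cast the mask to bool once, before it enters the transformer, so flash_attn does not re-cast it on every call.

import paddle

def prepare_attention_mask(attention_mask):
    # Hypothetical helper, per the review suggestion: cast once up front
    # instead of inside every scaled_dot_product_attention call.
    if attention_mask is not None and attention_mask.dtype != paddle.bool:
        attention_mask = attention_mask.astype("bool")
    return attention_mask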

@@ -239,6 +244,7 @@ def scaled_dot_product_attention(
            attention_mask is None,
            True,
            False,
            is_casual
Contributor

Suggest not inferring the causal flag here; it is error-prone for SFT/LoRA. Just pass False directly, unless you can guarantee the causal check is always correct.

Contributor Author

Changed.

Contributor

@SylarTiaNII SylarTiaNII left a comment

Needs changes.


@dataclass
@add_start_docstrings(ModelArgument.__doc__)
class SFTModelArguments(ModelArgument):
Contributor

Move all of this into arguments.py, and watch out for pylint.
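
A minimal sketch of that layout (illustrative; import paths and the exact field set are assumptions based on the snippets in this PR):

# arguments.py (sketch)
from dataclasses import dataclass, field


@dataclass
class ModelArgument:
    model_name_or_path: str = field(
        default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
    )
    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})


@dataclass
class SFTModelArguments(ModelArgument):
    # SFT-specific fields would live here (omitted in this sketch); hosting
    # everything in one arguments.py keeps definitions out of run scripts.
    pass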

@@ -240,7 +240,7 @@ def scaled_dot_product_attention(
            attention_mask is None,
            True,
            False,
            False,
            False
Contributor

Can is_causal_mask be fixed?

        if is_casual and alibi is None:
            attention_mask = None
        if get_env_device() != "npu":
            is_casual = is_casual_mask(attention_mask)
Contributor

If this part is correct, you could derive a variable here and pass it down into the FA (flash attention) input.
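
A hedged sketch of that suggestion (resolve_is_casual is a hypothetical helper, and the import path for is_casual_mask is an assumption): decide the flag once where the mask is known, then forward it to the flash-attention call instead of hard-coding False.

from paddlenlp.transformers.llama.modeling import is_casual_mask  # assumed export

def resolve_is_casual(attention_mask, alibi, device):
    # Hypothetical helper: compute the causal flag once, upstream.
    is_casual = False
    if device != "npu" and attention_mask is not None:
        is_casual = is_casual_mask(attention_mask)
    if is_casual and alibi is None:
        # The kernel can regenerate a causal mask itself, so drop ours.
        attention_mask = None
    return is_casual, attention_mask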

@@ -0,0 +1,550 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Contributor

Why is a new file being added?

codecov bot commented Mar 30, 2024

Codecov Report

Attention: Patch coverage is 3.15789%, with 92 lines in your changes missing coverage. Please review.

Project coverage is 55.10%. Comparing base (2273ee7) to head (768f465).

❗ Current head 768f465 differs from the pull request's most recent head 150c01e. Consider uploading reports for commit 150c01e to get more accurate results.

Files                                                    Patch %   Lines
...dlenlp/transformers/mc2_seqence_parallel_linear.py      0.00%   71 Missing ⚠️
paddlenlp/transformers/llama/modeling.py                  12.50%   21 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8139      +/-   ##
===========================================
- Coverage    55.15%   55.10%   -0.05%     
===========================================
  Files          601      602       +1     
  Lines        91764    91850      +86     
===========================================
+ Hits         50611    50613       +2     
- Misses       41153    41237      +84     


import os

import paddle
import paddle_custom_device
Collaborator

This isn't installed by default, is it?

Contributor Author

This module is only imported when MC2 runs on an NPU.

Collaborator

I know, but if anything else imports this module, it will raise an error.
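
One way to address that, as a purely illustrative sketch (not the PR's final code): guard the optional dependency so importing the module on a non-NPU install does not fail.

import os

import paddle

try:
    import paddle_custom_device  # present only on NPU builds
except ImportError:
    paddle_custom_device = None

def _require_mc2():
    # Hypothetical guard: fail loudly only when MC2 is actually used.
    if paddle_custom_device is None:
        raise RuntimeError("MC2 requires paddle_custom_device (NPU build).")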

    ScatterOp,
    mark_as_sequence_parallel_parameter,
)

if int(os.getenv("MC2", 0)):
Collaborator

@gongweibao gongweibao Apr 2, 2024

Make the flag name a bit more elaborate? FLAGS_NPU_MC2

Collaborator

Check the device first, then check the FLAGS.
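
A hedged sketch combining both suggestions (the flag name follows the review; the get_env_device import path is an assumption):

import os

from paddlenlp.utils.tools import get_env_device  # assumed import path

def mc2_enabled():
    # Per the review: gate on the device first, then on the flag.
    # FLAGS_NPU_MC2 is the reviewer-suggested name, not necessarily final.
    return get_env_device() == "npu" and bool(int(os.getenv("FLAGS_NPU_MC2", "0")))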

@NINGBENZHE NINGBENZHE closed this Apr 2, 2024
@NINGBENZHE NINGBENZHE reopened this Apr 2, 2024
@NINGBENZHE NINGBENZHE force-pushed the develop branch 3 times, most recently from a356945 to 768f465 Compare April 3, 2024 08:39
@@ -102,6 +106,64 @@ class ModelArgument:
        default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
    )
    use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
    tokenizer_name_or_path: Optional[str] = field(
Collaborator

As a follow-up, suggest letting run_pretrain.py import its options directly from arguments as well.
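
A sketch of what that follow-up could look like (the arguments module is the one proposed in this thread; its path is an assumption):

# run_pretrain.py (illustrative)
from paddlenlp.trainer import PdArgumentParser

from arguments import ModelArgument  # hypothetical shared module from this review

parser = PdArgumentParser((ModelArgument,))
(model_args,) = parser.parse_args_into_dataclasses()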

@wawltor wawltor merged commit ae7dc15 into PaddlePaddle:develop Apr 3, 2024
5 of 8 checks passed