add finetune fused & add mc2 #8139
Conversation
Thanks for your contribution!
Should the MC2 part go into a separate file of its own?
These two MC2 ops are part of sequence parallel (sp); leave them where they are for now?
@@ -228,6 +228,11 @@ def scaled_dot_product_attention(
    alibi = alibi.reshape([bsz, num_heads, 1, -1])
    attention_mask = attention_mask.cast(alibi.dtype) + alibi
if get_env_device() == "npu":
    if attention_mask is not None:
        attention_mask = attention_mask.astype("bool")
Here, consider checking attn_mask once on the outside and only casting when it is not already bool; that improves performance. Rather than casting inside flash_attn, could you check and cast once before the mask is passed into the Transformer?
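Roughly something like this, as a sketch only (the helper name and call site are illustrative, not the code in this PR):

import paddle

def prepare_npu_attention_mask(attention_mask):
    # Cast to bool once, and only when the mask is not already bool, so the
    # per-layer flash-attention path does not repeat the cast.
    if attention_mask is not None and attention_mask.dtype != paddle.bool:
        attention_mask = attention_mask.astype("bool")
    return attention_mask

# Hypothetical call site, before the hidden states enter the decoder layers:
# attention_mask = prepare_npu_attention_mask(attention_mask)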
Resolved.
@@ -239,6 +244,7 @@ def scaled_dot_product_attention(
    attention_mask is None,
    True,
    False,
    is_casual
Suggest not trying to detect causal here; it is easy to get wrong for SFT/LoRA. Just pass False directly, unless you can guarantee the causal detection is correct.
Fixed.
This needs to be changed.
llm/finetune_generation.py (outdated)

@dataclass
@add_start_docstrings(ModelArgument.__doc__)
class SFTModelArguments(ModelArgument):
Move the whole block into arguments.py, and mind pylint.
@@ -240,7 +240,7 @@ def scaled_dot_product_attention(
    attention_mask is None,
    True,
    False,
    False,
    False
Can is_causal_mask be fixed?
if is_casual and alibi is None:
    attention_mask = None
if get_env_device != "npu":
    is_casual = is_casual_mask(attention_mask)
If this part is correct, you can set a variable here and pass it down into the FA (flash-attention) inputs.
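As a sketch of that idea, computing the flag once at the call site and threading it through (the keyword form below is for clarity only; the actual signature in this PR may differ):

is_casual = False
if alibi is None:
    # is_casual_mask returns True only for a plain lower-triangular (causal) mask
    is_casual = is_casual_mask(attention_mask)
    if is_casual:
        attention_mask = None  # the kernel can apply causal masking itself

attn_output = scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attention_mask=attention_mask,
    is_casual=is_casual,  # propagated explicitly instead of re-derived inside
)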
@@ -0,0 +1,550 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Why is a new file being added?
Codecov Report
Attention: Patch coverage is …
Additional details and impacted files

@@            Coverage Diff             @@
##           develop    #8139      +/-   ##
===========================================
- Coverage    55.15%   55.10%    -0.05%
===========================================
  Files          601      602        +1
  Lines        91764    91850       +86
===========================================
+ Hits         50611    50613        +2
- Misses       41153    41237       +84

☔ View full report in Codecov by Sentry.
import os

import paddle
import paddle_custom_device
This isn't installed by default, is it?
This file is only imported when running MC2 on NPU.
I know, but if this gets imported from anywhere else, it will raise an error.
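One way to keep the import from breaking other call sites, sketched under the assumption that the NPU wheel may simply be absent (the guard variable name is hypothetical, not part of this PR):

import importlib.util

# Import the NPU wheel only when it is actually installed, so importing this
# module from a non-NPU environment does not raise ImportError.
_HAS_NPU_CUSTOM_DEVICE = importlib.util.find_spec("paddle_custom_device") is not None

if _HAS_NPU_CUSTOM_DEVICE:
    import paddle_custom_device  # noqa: F401  # only needed for MC2 on NPU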
    ScatterOp,
    mark_as_sequence_parallel_parameter,
)

if int(os.getenv("MC2", 0)):
Could this check be a bit more elaborate?
FLAGS_NPU_MC2
Check the device first, then check the FLAGS.
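Something along these lines, as a sketch (get_env_device is the helper already used in this diff, FLAGS_NPU_MC2 is the flag name suggested above, and the wiring itself is an assumption):

import os

def mc2_enabled():
    # get_env_device() is the device helper used elsewhere in this diff; its
    # import path is assumed here. FLAGS_NPU_MC2 is the proposed flag name.
    if get_env_device() != "npu":
        return False
    return bool(int(os.getenv("FLAGS_NPU_MC2", "0")))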
Force-pushed from a356945 to 768f465 (Compare)
@@ -102,6 +106,64 @@ class ModelArgument:
    default=None, metadata={"help": "Build-in pretrained model name or the path to local model."}
)
use_flash_attention: bool = field(default=False, metadata={"help": "Whether to use flash attention"})
tokenizer_name_or_path: Optional[str] = field(
As a follow-up, it would be good if run_pretrain.py could also pull its options directly from arguments.py.
No description provided.