Add RingFlashAttention for context parallel #8383
Conversation
Thanks for your contribution!
Force-pushed from 09fd62e to fbd16a1
Codecov Report
Attention: Patch coverage is

Additional details and impacted files
@@            Coverage Diff            @@
##           develop    #8383      +/-   ##
===========================================
- Coverage    54.22%   53.87%   -0.35%
===========================================
  Files          617      620       +3
  Lines        96203    97068     +865
===========================================
+ Hits         52164    52295     +131
- Misses       44039    44773     +734

View full report in Codecov by Sentry.
# if step != cp_size - 1:
#     comm_buffer.wait()
paddle.device.synchronize()
TODO: under the async stream used by batch_isend_irecv, wait() cannot be called; this needs to be fixed and currently hurts performance.
done~
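For context, here is a minimal sketch of what one send/receive step of the ring looks like under the constraint discussed above: the tasks returned by batch_isend_irecv are not waited on (waiting does not work on the async communication stream yet), so a full device synchronize is used as a conservative fallback. The function and variable names are assumptions for illustration, not the PR's exact code.

```python
import paddle
import paddle.distributed as dist


def ring_send_recv(tensor, group):
    """Send `tensor` to the next rank in the ring group and receive from the previous one."""
    group_rank = dist.get_rank(group)
    group_size = dist.get_world_size(group)
    send_dst = group.ranks[(group_rank + 1) % group_size]
    recv_src = group.ranks[(group_rank - 1) % group_size]

    recv_buffer = paddle.empty_like(tensor)
    ops = [
        dist.P2POp(dist.isend, tensor, send_dst, group),
        dist.P2POp(dist.irecv, recv_buffer, recv_src, group),
    ]
    tasks = dist.batch_isend_irecv(ops)

    # TODO: replace with per-task wait() once waiting works under the async comm stream;
    # synchronizing the whole device here costs performance.
    paddle.device.synchronize()
    return recv_buffer
```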
Force-pushed from f94a915 to 4e88520
Force-pushed from cf7d334 to 88bc460
block_out, _, block_lse, _ = _C_ops.flash_attn(
    local_query,
    block_k[:, : local_q_seq_len // 2, :, :],
    block_v[:, : local_q_seq_len // 2, :, :],
This approach may be slow; see whether the op can be called directly instead.
done~
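As background on why the raw flash attention op is used here: ring flash attention needs the per-block softmax log-sum-exp (block_lse) in addition to each block's output, so partial results from successive K/V blocks can be merged in a numerically stable way. A minimal sketch of that merge step follows; the names are illustrative and lse is assumed already reshaped to broadcast against the output, so the PR's actual implementation may differ.

```python
import paddle


def merge_block_attention(out, lse, block_out, block_lse):
    """Fold one K/V block's partial attention (block_out, block_lse) into the
    running accumulators (out, lse) using a stable log-sum-exp update."""
    if out is None:
        return block_out, block_lse
    # stable logaddexp: new_lse = log(exp(lse) + exp(block_lse))
    new_lse = paddle.maximum(lse, block_lse) + paddle.log1p(
        paddle.exp(-paddle.abs(lse - block_lse))
    )
    # rescale the old accumulator and the new block to the common normalizer
    out = paddle.exp(lse - new_lse) * out + paddle.exp(block_lse - new_lse) * block_out
    return out, new_lse
```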
if attn_mask is not None:
    attn_masks_list = paddle.split(attn_mask, num_or_sections=cp_size * 2, axis=3)
if is_causal:
    local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
contiguous? This is probably not needed. Prefer the tensor-splitting APIs rather than slicing via operator overloading.
done~
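As a follow-up to the splitting-API suggestion, a tiny sketch (hypothetical shapes) of taking the second half of the local query along the sequence axis with paddle.split rather than slice-and-clone:

```python
import paddle

# [batch, local_seq_len, num_heads, head_dim]; the shapes are illustrative
local_query = paddle.randn([2, 1024, 16, 64])
_, local_query_second_chunk = paddle.split(local_query, num_or_sections=2, axis=1)
print(local_query_second_chunk.shape)  # [2, 512, 16, 64]
```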
grad_comm_buffer = RingCommunicator(group, key_grad_buffer, value_grad_buffer)

if is_causal:
    local_query_second_chunk = local_query[:, local_q_seq_len // 2 :, :, :].clone().contiguous()
This was already computed in the forward pass; can we avoid recomputing it here?
This has been optimized.
def wait(self):
    # for req in self._reqs:
    #     req.wait()
    # self._reqs = None
Please make this a TODO rather than leaving commented-out code.
done~
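A short sketch of how the resolved wait() can read once the commented-out code becomes a TODO, as requested; the wording and the enclosing class are illustrative, following the synchronize() fallback shown earlier in this thread.

```python
import paddle


class RingCommunicator:
    def wait(self):
        # TODO: wait on the batch_isend_irecv tasks directly once waiting works under
        # the async communication stream; a full device synchronize is used until then.
        paddle.device.synchronize()
```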
Force-pushed from a360468 to ab562b7
LGTM
paddlenlp/trainer/training_args.py (Outdated)
@@ -583,6 +587,15 @@ class TrainingArguments:
        )
    },
)
cp_parallel_degree: int = field(
Rename this to context_parallel_degree?
Suggested change:
- cp_parallel_degree: int = field(
+ context_parallel_degree: int = field(
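For reference, a hedged illustration of the renamed argument as a dataclass field; the help text below is paraphrased, not the PR's exact wording.

```python
from dataclasses import dataclass, field


@dataclass
class TrainingArgumentsSketch:
    context_parallel_degree: int = field(
        default=-1,
        metadata={
            "help": "Degree of context parallelism. The sequence (context) dimension is "
            "split across this many ranks and attention runs as ring flash attention."
        },
    )
```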
paddlenlp/trainer/trainer.py (Outdated)
@@ -763,6 +764,8 @@ def train(
trainable_numel = int(trainable_numel_tensor.item()) // self.args.dataset_world_size
if self.args.sep_parallel_degree > 0:
    trainable_numel = trainable_numel // self.args.sep_parallel_degree
if self.args.cp_parallel_degree > 0:
    trainable_numel = trainable_numel // self.args.cp_parallel_degree
Which parameters does cp_parallel_degree shard?
paddlenlp/trainer/training_args.py (Outdated)
@@ -230,6 +230,10 @@ class TrainingArguments:
    The paddle sequence parallel strategy. It can reduce the GPU memory of activation to 1/sep, and it is orthogonal to
    data parallel, sharding stage1, tensor parallel and pipeline parallel strategy.
)
cp_parallel_degree (`int`, *optional*, defaults to `-1`)(
Please also document this argument in docs/trainer.md.
self.tensor_parallel_degree
* self.sep_parallel_degree
* self.cp_parallel_degree
* self.pipeline_parallel_degree
Has checkpoint saving been considered for this? Does it need an extra communication group?
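A small worked example (illustrative numbers; the PR computes this inside TrainingArguments) of how the implied data-parallel degree follows once context_parallel_degree joins the product above:

```python
world_size = 8
tensor_parallel_degree = 2
sep_parallel_degree = 1
context_parallel_degree = 2
pipeline_parallel_degree = 1

model_parallel_ranks = (
    tensor_parallel_degree
    * sep_parallel_degree
    * context_parallel_degree
    * pipeline_parallel_degree
)
assert world_size % model_parallel_ranks == 0
data_parallel_degree = world_size // model_parallel_ranks
print(data_parallel_degree)  # 2
```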
@@ -918,6 +931,7 @@
if world_size > 1:
    tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
    sep_parallel_degree = max(self.sep_parallel_degree, 1)
    context_parallel_degree = max(self.context_parallel_degree, 1)
One more question: are context parallel and sep parallel mutually exclusive? Should we add a check for that, or can they be used together?
@@ -897,6 +900,8 @@
for step, inputs in enumerate(epoch_iterator):
    if self.args.use_hybrid_parallel and self.args.sep_parallel_degree > 1:
        inputs = split_inputs_sequence_dim(inputs)
    if self.args.use_hybrid_parallel and self.args.context_parallel_degree > 1:
Hmm, so with cp enabled, is it effectively one data stream that now maps to multiple full copies of the parameters?
cp=2, tp=2, 4 cards: two copies of the parameters, one data stream.
Forward and backward: with two parameter copies, what happens to the grads? Are they summed?
Warm-start the model to align precision: check the first-step precision, then compare the loss diff at the second step.
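To make the "one data stream, split context" point concrete, a minimal sketch of splitting one batch along the sequence dimension across the context-parallel group, so every cp rank consumes the same data stream but only a slice of the context. The helper name is hypothetical and the even 1/cp split is a simplification of the PR's causal load balancing.

```python
import paddle


def split_inputs_sequence_dim_for_cp(input_ids, cp_rank, cp_size):
    # input_ids: [batch, seq_len]; seq_len is assumed divisible by cp_size
    chunks = paddle.split(input_ids, num_or_sections=cp_size, axis=1)
    return chunks[cp_rank]


input_ids = paddle.arange(2 * 8).reshape([2, 8])
local_ids = split_inputs_sequence_dim_for_cp(input_ids, cp_rank=1, cp_size=2)
print(local_ids.shape)  # [2, 4]
```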
PR types
New features
PR changes
Models
Description
Add RingFlashAttention support for Fleet's context parallel.
Paddle compatibility:
Uses the existing sep group in Paddle; no changes to Paddle itself are required.
Convergence:
[Figure: convergence curves, cp (green) vs. sep (blue)]
cp is compared against sep. In theory the two should converge identically; in testing, their convergence curves are nearly identical. Green is cp, blue is sep.
Performance:
[Figure: throughput comparison, cp (green) vs. sep (blue)]
Tested on a single node with 8 GPUs and a small model at a 20k sequence length; the performance comparison is shown in the figure. Green is cp, blue is sep.
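A hedged usage sketch of enabling the new degree through TrainingArguments; only the parallelism-related fields are shown, a real run needs the rest of the training configuration, and the exact combination of degrees depends on the available cards.

```python
from paddlenlp.trainer import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    tensor_parallel_degree=2,
    context_parallel_degree=2,  # split the long context across 2 ranks with ring flash attention
)
```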