Support Sharding Overlap #8473
Conversation
Thanks for your contribution!
if self.config.use_flash_attention and get_env_device() != "gcu":
    is_casual = is_casual_mask(attention_mask)
Don't delete this. if hasattr(self.config, "casual_mask")
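A minimal sketch of the guard the reviewer is asking to keep: the causal-mask fast path should only be taken when the config actually defines the flag, since older configs may lack the attribute entirely. `SimpleNamespace` stands in for the real model config here; only `use_flash_attention` and `casual_mask` come from the diff.

```python
from types import SimpleNamespace


def should_use_casual_mask(config):
    # Take the causal fast path only when the config defines the flag;
    # older configs may not have the `casual_mask` attribute at all.
    return bool(
        getattr(config, "use_flash_attention", False)
        and hasattr(config, "casual_mask")
        and config.casual_mask
    )


old_cfg = SimpleNamespace(use_flash_attention=True)  # no casual_mask attr
new_cfg = SimpleNamespace(use_flash_attention=True, casual_mask=True)
```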
paddlenlp/trainer/trainer.py
Outdated
@@ -1908,6 +1907,13 @@ def get_expected_keys(inputs, keys):
                self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)

            if in_sharding_parallel_mode:
                sharding_parallel_config = set(self.args.sharding_parallel_config.split(" "))
Handle this in the training_args file instead; don't split the string here.
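A sketch of what the reviewer is suggesting: normalize `sharding_parallel_config` once in the arguments dataclass, so the trainer consumes a set directly instead of re-splitting the raw string. The dataclass below is a simplified stand-in for the real `TrainingArguments`; only the field name comes from the diff.

```python
from dataclasses import dataclass, field


@dataclass
class TrainingArgumentsSketch:
    # Space-separated flags, e.g. "split_param enable_stage1_overlap".
    sharding_parallel_config: str = ""

    def __post_init__(self):
        # Split once, up front; downstream code then does set membership
        # checks instead of substring matching on the raw string.
        self.sharding_parallel_config = set(self.sharding_parallel_config.split())


args = TrainingArgumentsSketch(
    sharding_parallel_config="split_param enable_stage1_overlap"
)
```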
llm/run_pretrain.py
Outdated
@@ -223,6 +223,10 @@ class ModelArguments:
        default=None,
        metadata={"help": "num_hidden_layers."},
    )
    casual_mask: Optional[bool] = field(
Suggested change:
- casual_mask: Optional[bool] = field(
+ use_casual_mask: Optional[bool] = field(
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@           Coverage Diff            @@
##           develop    #8473   +/-   ##
===========================================
- Coverage    54.29%   54.25%   -0.05%
===========================================
  Files          617      617
  Lines        96339    96368      +29
===========================================
- Hits         52312    52288      -24
- Misses       44027    44080      +53

☔ View full report in Codecov by Sentry.
paddlenlp/trainer/trainer.py
Outdated
@@ -1908,6 +1907,12 @@ def get_expected_keys(inputs, keys):
                self.optimizer = mix_precision_utils.MixPrecisionOptimizer(self.optimizer)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)

            if in_sharding_parallel_mode:
                if "split_param" in self.args.sharding_parallel_config:
                    self.optimizer._set_all_gather_overlap_forward(True, model)
Does this interface need to account for version compatibility?
LGTM
LGTM
LGTM
This reverts commit 7aaa788.
PR types
Performance optimization
PR changes
Models
Description
1. Support sharding overlap