FSDP + Tensor Parallelism #7897
Conversation
Force-pushed from 8993f8c to fd08420.
@erhoo82 Thanks for the hard work. Do you have a timeline in mind for when this PR will get merged? Thanks!
@stu1130 I have been caught up with high-priority work. Trying to close this out soon.
Force-pushed from dc8dba7 to b738f9d.
@athitten Is the error below relevant?
@@ -64,17 +82,55 @@
    optim_state_to_sharding_state,
)
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
@erhoo82 @ericharper Do we want to move this import to a separate try/except block? Otherwise HAVE_MEGATRON_CORE gets set to False if the import fails with a "TransformerEngine not found" error, even though Megatron Core itself is available, which can break code paths that depend only on mcore and not on TE.
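For illustration, a minimal sketch of the separate guard being suggested, following the usual two-flag pattern; the HAVE_TE_LAYER flag name is hypothetical and not part of the actual NeMo code:

```python
try:
    from megatron.core.transformer.module import Float16Module as MCoreFloat16Module

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False

try:
    # This import can fail with a TransformerEngine-related error even when
    # megatron.core itself is installed, so it gets its own guard instead of
    # flipping HAVE_MEGATRON_CORE to False.
    from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer

    HAVE_TE_LAYER = True  # hypothetical flag name, for illustration only
except (ImportError, ModuleNotFoundError):
    HAVE_TE_LAYER = False
```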
We can leave it as is for now. There is a separate effort to improve the import guarding.
@@ -728,6 +731,21 @@ def allreduce_sequence_parallel_gradients(self):
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
    buf.copy_(synced)

def allreduce_fsdp_sharding_omitted_gradients(self):
Doesn't have to be in this PR, but shouldn't these calls go into an FSDP strategy rather than the model code? Maybe @athitten, you can address this in a future PR.
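For context, a hedged sketch of what an all-reduce of sharding-omitted gradients could look like, following the coalesced flatten/unflatten pattern of allreduce_sequence_parallel_gradients above; the _fsdp_sharding_omitted parameter attribute is an assumption for illustration, not the actual marker used in NeMo:

```python
import torch
from megatron.core import parallel_state


def allreduce_fsdp_sharding_omitted_gradients(self):
    # Gradients of parameters omitted from FSDP sharding (e.g. tensor-parallel
    # duplicates) are not reduce-scattered by FSDP, so they are all-reduced
    # across the data-parallel group here.
    grads = [
        p.grad.data
        for p in self.parameters()
        if getattr(p, "_fsdp_sharding_omitted", False) and p.grad is not None
    ]
    if not grads:
        return
    coalesced = torch._utils._flatten_dense_tensors(grads)
    torch.distributed.all_reduce(coalesced, group=parallel_state.get_data_parallel_group())
    for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```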
@@ -54,7 +54,8 @@
HAVE_MEGATRON_CORE = False


-def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+@torch.no_grad()
+def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False):
@athitten, we should be getting this from mcore (for a future PR).
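As background, a minimal sketch of why an FSDP flag matters in gradient clipping: with FSDP each rank only holds a shard of every gradient, so the per-rank norm contributions must be summed across the sharding group before clipping. This is an illustration only, not the NeMo or mcore implementation:

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False):
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    if not grads:
        return 0.0
    # Norm contribution of the gradient (shards) owned by this rank.
    local_norm_sq = torch.tensor(
        sum(g.float().norm(norm_type).item() ** norm_type for g in grads),
        device=grads[0].device,
    )
    if use_fsdp:
        # Sum the shard contributions across ranks to get the global norm.
        dist.all_reduce(local_norm_sq, op=dist.ReduceOp.SUM)
    total_norm = local_norm_sq.item() ** (1.0 / norm_type)
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for g in grads:
            g.mul_(clip_coeff)
    return total_norm
```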
sharding_strategy: FSDP parameter sharding strategy.
grad_reduce_dtype: Data type for FSDP gradient shard ReduceScatter.
sharded_checkpoint: Store/load FSDP-sharded checkpoints.
precision: Precision recipe to be used with FSDP.
This can go in a future PR, but I'd like to see more detail here: what the possible values are, plus links to more information.
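As an illustration of the kind of detail being requested, the sharding_strategy values would presumably map onto torch's FSDP enum; the accepted strings below are assumptions, not the documented NeMo options:

```python
from torch.distributed.fsdp import ShardingStrategy

# Hypothetical mapping; see the torch FSDP docs for what each strategy shards.
SHARDING_STRATEGY_MAP = {
    "full": ShardingStrategy.FULL_SHARD,     # shard params, grads, and optimizer state
    "grad": ShardingStrategy.SHARD_GRAD_OP,  # shard grads and optimizer state only
    "none": ShardingStrategy.NO_SHARD,       # DDP-style replication, no sharding
}
```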
if precision in ["16-true", "16-mixed", 16]:
    param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif precision in ["bf16-true", "bf16-mixed", "bf16"]:
    param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif precision == 32:
    param_dtype = reduce_dtype = buffer_dtype = torch.float
else:
    raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
Again, for a future PR, but we need a utility function for this that can be shared across all of NeMo.
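A sketch of the kind of shared utility being suggested; the function name and the exact set of accepted precision values are assumptions, not an existing NeMo API:

```python
import torch


def infer_fsdp_dtypes(precision):
    """Map a precision flag to (param_dtype, reduce_dtype, buffer_dtype) for FSDP."""
    if precision in ["16-true", "16-mixed", 16, "16"]:
        dtype = torch.float16
    elif precision in ["bf16-true", "bf16-mixed", "bf16"]:
        dtype = torch.bfloat16
    elif precision in [32, "32", "32-true"]:
        dtype = torch.float32
    else:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
    return dtype, dtype, dtype
```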
return state_dict_type_context


class NLPFSDPStrategy(FSDPStrategy):
Also noting here for a future PR that nlp_overrides.py is too long now and we should refactor it a bit.
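For reference, a hedged illustration of how the new strategy class might be constructed; the exact constructor signature is an assumption based on the docstring fields shown above, and in practice the strategy is typically built by NeMo's trainer setup rather than directly in user code:

```python
from pytorch_lightning import Trainer
from nemo.collections.nlp.parts.nlp_overrides import NLPFSDPStrategy

strategy = NLPFSDPStrategy(
    sharding_strategy="full",    # assumed default value
    grad_reduce_dtype="bf16",    # dtype for gradient shard ReduceScatter
    sharded_checkpoint=True,     # store/load FSDP-sharded checkpoints
    precision="bf16-mixed",
)
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=8)
```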
LGTM. Thanks!
* FSDP with Tensor Parallelism
  * Omit parameters with sequence-parallel updates and tensor-parallel duplicates from FSDP sharding
  * Merge the list of grads to clip
  * Pass cpu_initialization to TE transformer layer
  * Support FSDP sharded checkpoint
  * Do not use DistCkpt when using FSDP
  * Clean up
* torch FSDP patch
* cleanup
* Remove torch FSDP patch
* fix FSDP mixed precision settings
* Disable pipelined tensor-parallel communication overlap when using FSDP
* Add import of MCoreTransformerLayer to try, except
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

Signed-off-by: Sangkug Lym <[email protected]>
Signed-off-by: Abhishree Thittenamane <[email protected]>
Co-authored-by: Abhishree Thittenamane <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Piotr Żelasko <[email protected]>
What does this PR do?
This PR adds support for torch FSDP training combined with tensor parallelism.
Set model.fsdp=True to enable FSDP. The gradient-reduction data type can be set with model.fsdp_grad_reduce_dtype=bf16. FSDP mode is not compatible with the distributed optimizer or the O2 optimizer wrapper, so set model.megatron_amp_O2=False and do not use model.optim.name=distributed_fused_adam. To store checkpoints in FSDP-sharded format, set model.fsdp_sharded_checkpoint=True.
Changelog
Usage
# Add a code snippet demonstrating how to use this
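A hypothetical sketch of the config overrides described above, expressed via OmegaConf; the surrounding training script and full model config are omitted, and the values shown are illustrative:

```python
from omegaconf import OmegaConf

# fsdp, fsdp_grad_reduce_dtype, and fsdp_sharded_checkpoint are the new knobs
# introduced by this PR; the values here are example settings only.
overrides = OmegaConf.create(
    {
        "model": {
            "fsdp": True,                      # enable torch FSDP
            "fsdp_grad_reduce_dtype": "bf16",  # dtype for gradient reduction
            "fsdp_sharded_checkpoint": True,   # store/load FSDP-sharded checkpoints
            "megatron_amp_O2": False,          # O2 wrapper is incompatible with FSDP
        }
    }
)
print(OmegaConf.to_yaml(overrides))
```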
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines contain specific people who can review PRs in various areas.
Additional Information