
FSDP + Tensor Parallelism #7897

Merged: 10 commits merged into NVIDIA:main from slym/fsdp_patch_merge on Dec 16, 2023
Conversation

erhoo82 (Collaborator) commented on Nov 16, 2023:

What does this PR do?

This PR adds support for torch FSDP training combined with tensor parallelism.

Set model.fsdp=True to enable FSDP. The gradient-reduction data type can be set with model.fsdp_grad_reduce_dtype=bf16.

FSDP is not compatible with the distributed optimizer or the O2 optimizer wrapper, so set model.megatron_amp_O2=False and do not use model.optim.name=distributed_fused_adam.

To store checkpoints in FSDP-sharded format, set model.fsdp_sharded_checkpoint=True.

Changelog

  • Add a model.fsdp config option to enable torch FSDP training with tensor parallelism.
  • Omit parameters with sequence-parallel updates and tensor-parallel duplicates from FSDP sharding, and all-reduce their gradients separately.
  • Merge the list of gradients to clip and make clip_grad_norm_fp32 FSDP-aware.
  • Pass cpu_initialization to the TE transformer layer.
  • Support FSDP-sharded checkpoints (model.fsdp_sharded_checkpoint); do not use distributed checkpointing (DistCkpt) together with FSDP.
  • Fix the FSDP mixed-precision settings and remove the torch FSDP patch (the underlying bug was fixed in the PyT 23.09 container).
  • Disable pipelined tensor-parallel communication overlap when using FSDP.

Usage

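A minimal usage sketch (not taken from the PR itself) of the configuration overrides described above, expressed with OmegaConf. The FSDP-related keys come from this PR's description; the tensor-parallel size and optimizer name are illustrative placeholders.

```python
# Hypothetical usage sketch: only the FSDP-related keys are from this PR's description;
# the tensor-parallel size and optimizer choice are illustrative.
from omegaconf import OmegaConf

fsdp_overrides = OmegaConf.create(
    {
        "model": {
            "fsdp": True,                      # enable torch FSDP sharding
            "fsdp_grad_reduce_dtype": "bf16",  # dtype for the gradient reduce-scatter
            "fsdp_sharded_checkpoint": True,   # save/load checkpoints in FSDP-sharded format
            "megatron_amp_O2": False,          # O2 optimizer wrapper is incompatible with FSDP
            "tensor_model_parallel_size": 2,   # combine FSDP with tensor parallelism
            "optim": {"name": "fused_adam"},   # distributed_fused_adam is not supported with FSDP
        }
    }
)

# Merge into your base training config before building the model, e.g.:
# cfg = OmegaConf.merge(base_cfg, fsdp_overrides)
```

The same settings can also be passed as Hydra command-line overrides to a NeMo training script (model.fsdp=True model.fsdp_grad_reduce_dtype=bf16 and so on).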

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g., Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the items above, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

The github-actions bot added the NLP label on Nov 16, 2023.
athitten (Collaborator) commented: jenkins
@erhoo82 force-pushed the slym/fsdp_patch_merge branch 2 times, most recently from 8993f8c to fd08420, on November 22, 2023.
Code-scanning alerts in nemo/collections/nlp/parts/nlp_overrides.py and nemo/collections/nlp/parts/fsdp_patch.py were marked as fixed.
athitten (Collaborator) commented: jenkins

ericharper (Collaborator) commented: jenkins

stu1130 commented on Dec 4, 2023: @erhoo82 Thanks for the hard work. Do you have a timeline in mind for when this PR will be merged? Thanks!

erhoo82 (Collaborator, Author) commented on Dec 9, 2023: @stu1130 I've been caught up with high-priority work; trying to close this soon.

@erhoo82 changed the title from "FSDP + Tensor Parallelism with patch" to "FSDP + Tensor Parallelism" on Dec 9, 2023.
Commits (all signed off by Sangkug Lym <[email protected]>; [pre-commit.ci] auto-fix and clean-up commits omitted from this list):

  • FSDP with Tensor Parallelism
  • Omit parameters with sequence-parallel updates and tensor-parallel duplicates from FSDP sharding
  • Merge the list of grads to clip
  • Pass cpu_initialization to TE transformer layer
  • Support FSDP sharded checkpoint
  • Do not use DistCkpt when using FSDP
erhoo82 (Collaborator, Author) commented on Dec 10, 2023:

@athitten, @ericharper

  • I have removed the torch FSDP patch file, since the bug was fixed in the PyT 23.09 container.
  • I have fixed some confusion in the mixed-precision definition. NeMo currently sets the master parameter data type to FP32 when megatron_amp_O2=False, which is the case with FSDP. I removed the explicit casting of the model data type to FP32 after FSDP wrapping, because it is already FP32 when megatron_amp_O2=False. Also, param_dtype in FSDP sets only the compute data type for forward and backward propagation, which matches the original intent of the FSDP design.
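For context, a minimal sketch (not the PR's code) of how this maps onto torch FSDP's mixed-precision settings: param_dtype only controls the forward/backward compute dtype, while the FP32 master parameters are what FSDP shards and the optimizer updates.

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# bf16 compute and gradient reduction; master weights stay FP32 (the megatron_amp_O2=False case).
bf16_mixed = MixedPrecision(
    param_dtype=torch.bfloat16,   # dtype parameters are cast to for forward/backward compute
    reduce_dtype=torch.bfloat16,  # dtype used for the gradient reduce-scatter
    buffer_dtype=torch.bfloat16,  # dtype for module buffers
)
```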

athitten (Collaborator) commented: jenkins

erhoo82 (Collaborator, Author) commented on Dec 12, 2023: @athitten Is the error below relevant?
ModuleNotFoundError: No module named 'transformer_engine'

athitten (Collaborator) commented: jenkins

athitten (Collaborator) commented: jenkins

@@ -64,17 +82,55 @@
optim_state_to_sharding_state,
)
from megatron.core.transformer.module import Float16Module as MCoreFloat16Module
from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer
Review comment (Collaborator): @erhoo82 @ericharper, do we want to move this import into a separate try/except block? If it fails with a "TransformerEngine not found" error, HAVE_MEGATRON_CORE would be set to False even though megatron-core itself is available, which could break code paths that depend only on mcore and not on TE.
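A sketch of the split import guard discussed here (not the code merged in this PR; the HAVE_MCORE_TRANSFORMER_LAYER flag name is hypothetical), so that a missing transformer_engine does not mark megatron-core itself as unavailable:

```python
try:
    from megatron.core.transformer.module import Float16Module as MCoreFloat16Module

    HAVE_MEGATRON_CORE = True
except (ImportError, ModuleNotFoundError):
    HAVE_MEGATRON_CORE = False

try:
    # This import can pull in transformer_engine, so guard it separately.
    from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer

    HAVE_MCORE_TRANSFORMER_LAYER = True
except (ImportError, ModuleNotFoundError):
    HAVE_MCORE_TRANSFORMER_LAYER = False
```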

Reply (Collaborator): We can leave it as is for now; there is a separate effort to improve the import guarding.

@athitten
Copy link
Collaborator

jenkins

@athitten
Copy link
Collaborator

jenkins

@@ -728,6 +731,21 @@ def allreduce_sequence_parallel_gradients(self):
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

def allreduce_fsdp_sharding_omitted_gradients(self):
Review comment (Collaborator): This doesn't have to happen in this PR, but shouldn't these calls go into an FSDP strategy rather than the model code? Maybe @athitten can address this in a future PR.
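For reference, a hedged sketch of the idea behind allreduce_fsdp_sharding_omitted_gradients: parameters kept out of FSDP sharding (those with sequence-parallel updates or tensor-parallel duplicates) still need their gradients synchronized across data-parallel ranks. The attribute used to tag such parameters and the process-group argument are assumptions, not the PR's exact names.

```python
import torch
import torch.distributed as dist


def allreduce_fsdp_sharding_omitted_gradients(module, data_parallel_group):
    # Collect gradients of parameters excluded from FSDP sharding
    # (the `_fsdp_sharding_omitted` tag is hypothetical).
    grads = [
        p.grad.data
        for p in module.parameters()
        if p.grad is not None and getattr(p, "_fsdp_sharding_omitted", False)
    ]
    if not grads:
        return
    # Coalesce into one flat buffer, all-reduce once, then copy the results back,
    # mirroring the coalescing pattern used for sequence-parallel gradients above.
    coalesced = torch._utils._flatten_dense_tensors(grads)
    dist.all_reduce(coalesced, group=data_parallel_group)
    for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
        buf.copy_(synced)
```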

@@ -54,7 +54,8 @@
HAVE_MEGATRON_CORE = False


def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
@torch.no_grad()
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, use_fsdp=False):
Review comment (Collaborator): @athitten, we should be getting this from mcore (for a future PR).
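As background, a hedged sketch of why clip_grad_norm_fp32 needs a use_fsdp flag: with FSDP each rank holds only a shard of every parameter, so the local squared norm has to be summed over the FSDP sharding group in addition to the model-parallel group before taking the square root. The group arguments below are illustrative, not the PR's exact helpers.

```python
import torch
import torch.distributed as dist


@torch.no_grad()
def _global_grad_l2_norm(grads, use_fsdp, model_parallel_group, sharding_group=None):
    if not grads:
        return torch.tensor(0.0)
    local_sq = torch.zeros(1, dtype=torch.float32, device=grads[0].device)
    for g in grads:
        local_sq += g.float().norm(2) ** 2
    # Sum the partial squared norms over model-parallel ranks (grads are split across them).
    dist.all_reduce(local_sq, group=model_parallel_group)
    if use_fsdp:
        # Each FSDP rank only sees its own parameter shard; include the other shards too.
        dist.all_reduce(local_sq, group=sharding_group)
    return local_sq.sqrt()
```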

Comment on lines +501 to +504
sharding_strategy: FSDP parameter sharding strategy.
grad_reduce_dtype: Data type for FSDP gradient shard ReduceScatter.
sharded_checkpoint: Store/load FSDP-sharded checkpoints.
precision: Precision recipe to be used with FSDP.
Review comment (Collaborator): This can go in a future PR, but I'd like to see more detail here, such as the possible values and links to more information.
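A sketch of the extra detail requested for this docstring; the enum values come from torch.distributed.fsdp, while the wording and the stub function name are assumptions, not the PR's text.

```python
from torch.distributed.fsdp import ShardingStrategy  # FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, NO_SHARD


def _fsdp_args_doc_sketch():
    """Expanded argument descriptions (sketch).

    sharding_strategy: FSDP parameter sharding strategy, typically one of
        ShardingStrategy.FULL_SHARD (ZeRO-3-like: shard params, grads, and optimizer state),
        ShardingStrategy.SHARD_GRAD_OP (ZeRO-2-like: shard grads and optimizer state only),
        or ShardingStrategy.HYBRID_SHARD (full shard within a node, replicate across nodes).
        See https://pytorch.org/docs/stable/fsdp.html.
    grad_reduce_dtype: Data type for the gradient-shard reduce-scatter,
        e.g. bf16, fp16, or 32 for FP32 reduction.
    sharded_checkpoint: If True, store/load checkpoints in FSDP-sharded format
        (one shard per rank) instead of a consolidated full state dict.
    precision: Precision recipe used for compute, e.g. '16-mixed', 'bf16-mixed', or 32.
    """
```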

Comment on lines +571 to +578
if precision in ["16-true", "16-mixed", 16]:
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif precision in ["bf16-true", "bf16-mixed", "bf16"]:
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif precision == 32:
param_dtype = reduce_dtype = buffer_dtype = torch.float
else:
raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
Review comment (Collaborator): Again, this can wait for a future PR, but we need a utility function for this that can be shared across all of NeMo.
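A sketch of the shared utility suggested here (the function name and the exact accepted aliases are assumptions); it centralizes the Lightning-style precision-to-dtype mapping instead of duplicating the if/elif chain above.

```python
import torch


def precision_to_dtype(precision):
    """Map a Lightning-style precision setting to a torch dtype (sketch)."""
    if precision in ("16-true", "16-mixed", 16, "16"):
        return torch.float16
    if precision in ("bf16-true", "bf16-mixed", "bf16"):
        return torch.bfloat16
    if precision in (32, "32", "32-true"):
        return torch.float32
    raise ValueError(f"Unable to infer a dtype from precision {precision!r}.")
```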

return state_dict_type_context


class NLPFSDPStrategy(FSDPStrategy):
ericharper (Collaborator) commented on Dec 15, 2023: Also noting here for a future PR that nlp_overrides.py is too long now and we should refactor it a bit.

ericharper (Collaborator) left a review: LGTM. Thanks!

@ericharper merged commit 10edd11 into NVIDIA:main on Dec 16, 2023; 11 checks passed.
@erhoo82 deleted the slym/fsdp_patch_merge branch on January 2, 2024.
pzelasko pushed a commit to pzelasko/NeMo that referenced this pull request on Jan 3, 2024, with the following squashed commit message (sign-off, clean-up, and [pre-commit.ci] auto-fix lines condensed):

* FSDP with Tensor Parallelism: omit parameters with sequence-parallel updates and tensor-parallel duplicates from FSDP sharding; merge the list of grads to clip; pass cpu_initialization to TE transformer layer; support FSDP sharded checkpoint; do not use DistCkpt when using FSDP
* torch FSDP patch
* Remove torch FSDP patch
* fix FSDP mixed precision settings
* Disable pipelined tensor-parallel communication overlap when using FSDP
* Add import of MCoreTransformerLayer to try, except

Signed-off-by: Sangkug Lym <[email protected]>
Signed-off-by: Abhishree Thittenamane <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: Abhishree Thittenamane <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ssh-meister pushed a commit to ssh-meister/NeMo that referenced this pull request on Feb 15, 2024, with the same commit message as above (additionally signed off by Sasha Meister <[email protected]>).
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request on Jun 25, 2024, with the same commit message as above.