
Optimizing distributed Adam when running with one work queue #1551

Merged: 9 commits merged into NVIDIA:master on Jan 20, 2023

Conversation

@timmoon10 (Contributor) commented on Dec 7, 2022

When @erhoo82 ran NeMo-Megatron with CUDA_DEVICE_MAX_CONNECTIONS=1, he observed poor overlap between the model's backward compute and the distributed Adam optimizer's gradient reduce-scatters. In particular, if multiple reductions are launched in a row, only the last one can overlap with compute. This PR makes several changes to optimize performance:

  • When dist Adam launches multiple grad reduce-scatters at the same time, it coalesces them with NCCL group calls (see the sketch after this list).
  • Support initializing multiple params together, so their grad reductions are launched together.
  • Support variable-sized buckets. This is not currently used (large buckets tend to increase memory overheads), but it may be a useful feature in the future.
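To illustrate the coalescing idea: with CUDA_DEVICE_MAX_CONNECTIONS=1, work from all CUDA streams is funneled through a single hardware queue, so a string of individually launched reduce-scatters can sit in front of the backward compute instead of overlapping with it; issuing them inside one NCCL group lets them go out together. The sketch below is not the apex implementation — the function and argument names are placeholders, and `torch.distributed._coalescing_manager` is a private PyTorch API whose signature has changed across releases (the very change flagged in review below).

```python
import torch
import torch.distributed as dist


def reduce_scatter_buckets(bucket_grads, process_group):
    """Reduce-scatter several gradient buckets inside one NCCL group.

    ``bucket_grads`` is a list of flattened gradient tensors whose sizes
    are divisible by the process-group size. Placeholder names, not the
    actual apex API.
    """
    world_size = dist.get_world_size(process_group)
    outputs = []
    # _coalescing_manager is a private PyTorch API and its signature has
    # changed between releases, so treat this as illustrative only.
    with dist._coalescing_manager(group=process_group):
        for grad in bucket_grads:
            out = torch.empty(
                grad.numel() // world_size, dtype=grad.dtype, device=grad.device
            )
            # Each call is recorded and issued as part of a single
            # ncclGroupStart/ncclGroupEnd region when the context exits.
            dist.reduce_scatter_tensor(out, grad, group=process_group)
            outputs.append(out)
    return outputs
```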

@crcrpar (Collaborator) left a comment


Excuse me for the delay, but there seems to be a quite recent change in PyTorch with which this pull request fails.

apex/contrib/optimizers/distributed_fused_adam.py: two review threads (outdated, resolved)
@crcrpar modified the milestones: 23.01, 23.02 on Jan 19, 2023
@crcrpar merged commit 75f401e into NVIDIA:master on Jan 20, 2023
timmoon10 added a commit to timmoon10/apex that referenced this pull request on Mar 2, 2023
Handles checkpoints generated before NVIDIA#1551.
yuanzhedong pushed a commit to yuanzhedong/apex that referenced this pull request on Jul 14, 2023
…1551)

* Coalesce reduce-scatters in distributed Adam

* Support variable-size param buckets in dist Adam optimizer

* Support contiguous grad buffer with variable-size param buckets

* Add dist Adam unit test with contiguous grad buffers

* Optimize compute/communication overlap in dist Adam optim step

* Restore Dist Adam default of splitting params across default-sized buckets

* Support initializing multiple dist Adam param buckets together

The buckets perform communication together, so they are effectively a large bucket.

* Handle recent change in PyTorch API for coalescing NCCL calls
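The final commit above handles a change to PyTorch's private coalescing API. A hypothetical compatibility shim for that kind of change might look like the following; the helper name, the keyword-only call, and the un-coalesced fallback are assumptions for illustration, not what apex actually does.

```python
import contextlib

import torch.distributed as dist


@contextlib.contextmanager
def nccl_coalescing(group):
    """Best-effort wrapper around torch.distributed._coalescing_manager.

    The private API's signature has changed between PyTorch releases, so
    fall back to running the collectives un-coalesced if it cannot be
    constructed. Hypothetical helper, not the actual apex code.
    """
    manager = getattr(dist, "_coalescing_manager", None)
    if manager is None:
        cm = contextlib.nullcontext()  # very old PyTorch: no coalescing
    else:
        try:
            cm = manager(group=group)
        except TypeError:
            # Signature mismatch with this PyTorch version.
            cm = contextlib.nullcontext()
    with cm:
        yield


# Hypothetical usage:
#   with nccl_coalescing(process_group):
#       for grad, out in zip(grad_buckets, out_buckets):
#           dist.reduce_scatter_tensor(out, grad, group=process_group)
```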