(2/n) Support 2D Parallelism - Distributed Checkpoints #19852
Conversation
⚡ Required checks status: All passing 🟢
These checks are required after the changes: pytorch_lightning (Tests workflow, Azure GPU, Benchmarks, Docs), fabric (Docs), lightning_fabric (CPU workflow, Azure GPU), mypy, install.
Thank you for your contribution! 💜
Codecov Report
Attention: Patch coverage is
Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #19852     +/-  ##
=========================================
- Coverage      84%      59%     -25%
=========================================
  Files         425      420       -5
  Lines       34917    34912       -5
=========================================
- Hits        29292    20525    -8767
- Misses       5625    14387    +8762
LGTM. It's nice to be able to ignore the non-2.3 code paths.
Looks great!
What does this PR do?
Follow up to #19846
Adds the checkpointing logic to support all the features we have in Fabric. This uses the new torch.distributed.checkpoint APIs from PyTorch 2.3+. The caveat is that loading non-distributed checkpoints into a distributed model is now more difficult. I have not yet been able to resolve all edge cases around loading optimizer state, so loading optimizer states is not yet supported in this iteration of the PR.
📚 Documentation preview 📚: https://pytorch-lightning--19852.org.readthedocs.build/en/19852/
cc @Borda @awaelchli @carmocca @justusschock