Save and load sharded checkpoints with FSDP in Fabric #17323
Conversation
⚡ Required checks status: All passing 🟢

Groups summary:
- 🟢 pytorch_lightning: Tests workflow
- 🟢 pytorch_lightning: Azure GPU
- 🟢 fabric: Docs
- 🟢 lightning_fabric: CPU workflow
- 🟢 lightning_fabric: Azure GPU
- 🟢 mypy
- 🟢 install
- 🟢 link-check

These checks are required after the changes. Thank you for your contribution! 💜
Awesome!
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
I don't think that porting this to support torch versions below 2.0 is important. If you are worried about silent errors, we can raise an error at the start of FSDP setup when the torch version is lower than 2.0, suggesting that the user upgrade.
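A minimal sketch of such a guard, assuming the check runs eagerly during strategy setup; the helper name and error wording are illustrative, not the actual implementation:

```python
import torch
from packaging.version import Version


def _validate_torch_version_for_fsdp() -> None:
    """Hypothetical guard: fail fast instead of failing silently on older PyTorch."""
    # Strip any local version suffix such as "+cu118" before comparing
    if Version(torch.__version__.split("+")[0]) < Version("2.0.0"):
        raise RuntimeError(
            "Saving and loading sharded checkpoints with FSDP requires PyTorch >= 2.0."
            " Please upgrade your `torch` installation."
        )
```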
Codecov Report
Additional details and impacted files:

@@            Coverage Diff            @@
##           master   #17323     +/-   ##
=========================================
- Coverage      83%      59%      -24%
=========================================
  Files         415      410        -5
  Lines       31437    31427       -10
=========================================
- Hits        26048    18596     -7452
- Misses       5389    12831     +7442
What does this PR do?
Fixes #14816
This PR enables saving and loading sharded (distributed) checkpoints with the FSDP strategy in Fabric, covering the model and optimizer state as well as additional user data.
The checkpoint file structure looks like this (if devices=2):
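A sketch of the layout described below, assuming the default shard file names produced by PyTorch's distributed checkpoint file writer (the exact names may differ):

```
path/to/checkpoint/
├── .metadata        # index written by the distributed checkpoint file writer
├── __0_0.distcp     # tensor shards written by rank 0
├── __1_0.distcp     # tensor shards written by rank 1
└── meta.pt          # extra user data saved by Fabric's FSDPStrategy
```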
The ".metadata" file is from the FSDP file writer, the "*.distcp" are the distributed checkpoint files holding the tensors, and the "meta.pt" is a file that Fabric's FSDPStrategy saves with all user dict data next to model and optimizer (from the example above:
{"other": "anything}
Future Work
This is a minimal implementation of sharded checkpointing and loading. It is the best choice for large models and the most memory-efficient option FSDP can offer right now: offload to CPU, sharded state dicts, and a chunk-wise file writer (a rough sketch of this mechanism is included at the end of this description). In the future, we need to:
While testing, I stumbled upon this bug in PyTorch: pytorch/pytorch#99079
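For context, a rough sketch of the PyTorch >= 2.0 mechanism referred to above (sharded state dicts with CPU offload, streamed to disk by a file writer). This illustrates the underlying `torch.distributed.fsdp` and `torch.distributed.checkpoint` APIs under those assumptions, not necessarily how Fabric wires them up internally:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import ShardedOptimStateDictConfig, ShardedStateDictConfig


def save_sharded(model: FSDP, optimizer, path: str) -> None:
    # Each rank materializes only its own shards, offloaded to CPU to cap GPU memory
    with FSDP.state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
        optim_state_dict_config=ShardedOptimStateDictConfig(offload_to_cpu=True),
    ):
        state = {
            "model": model.state_dict(),
            "optimizer": FSDP.optim_state_dict(model, optimizer),
        }
    # The file writer streams the shards to disk, one set of .distcp files per rank
    dist_cp.save_state_dict(state_dict=state, storage_writer=dist_cp.FileSystemWriter(path))
```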
cc @Borda @awaelchli @carmocca @justusschock