[TPU] Add support for Pytorch XLA FSDP #17421
Conversation
cc @Liyang90 for visibility
TODO: implement checkpointing for XLAFSDP, implement XLAFSDP in Pytorch
According to the xla_fsdp documentation, checkpoints are saved as sharded checkpoints, which can be combined into a single checkpoint using consolidate_sharded_model_checkpoints. To save the sharded checkpoints, the user will need to specify device-specific paths (i.e., for 4 devices, 4 separate paths). Here is a good reference example of how to save checkpoints and consolidate. This is the current idea I have: save one sharded checkpoint per device, then optionally consolidate them afterwards.
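A minimal sketch of that save flow, based on the torch_xla FSDP documentation (the toy module, the /tmp/checkpoints prefix, and the exact filename pattern are illustrative assumptions, and this only runs on an XLA/TPU device):

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
model = FSDP(nn.Linear(16, 4).to(device))  # toy module wrapped in XLA FSDP
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

# Each process saves its own shard, so the path must be unique per device
# (e.g. 4 devices -> 4 files). The /tmp/checkpoints prefix is a placeholder.
rank, world_size = xm.get_ordinal(), xm.xrt_world_size()
ckpt = {
    "model": model.state_dict(),
    "shard_metadata": model.get_shard_metadata(),  # required for consolidation later
    "optimizer": optimizer.state_dict(),
}
xm.save(
    ckpt,
    f"/tmp/checkpoints/ckpt_rank-{rank:08d}-of-{world_size:08d}.pth",
    master_only=False,  # every rank writes its shard, not just rank 0
)
```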
Would we also want to combine the sharded checkpoints within the strategy itself? What are everyone's thoughts? If we do not include checkpoint consolidation directly in the strategy, it would be left to the user as a separate step.
@gkroiz We should try to stay as close as possible to PyTorch's FSDP implementation. For example, see Fabric's FSDP implementation, which is our most up-to-date version at the moment: https://github.com/Lightning-AI/lightning/blob/a5c43d3b2b80f5fc769d7ed0ea511c0bd6733c6b/src/lightning/fabric/strategies/fsdp.py#L306-L368

XLA's FSDP implementation was forked from the original fairscale implementation. The PyTorch version is quickly evolving, so it would be useful to know what the plan for XLA's is. For instance, XLA's doesn't offer the

For consolidation, this would be done separately and as an opt-in. We are in the process of deciding the desired design for this in Lightning (cc @awaelchli): our goal is to support {saving,loading} {sharded,consolidated} checkpoints. But to limit the scope of this work, we can focus on saving and loading just sharded versions. See for example this comment: #16526 (comment)
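For reference, the opt-in consolidation step available in torch_xla today is a separate utility call; a hedged sketch, with placeholder paths matching the per-rank filenames used in the save sketch above:

```python
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# Run once, in a single process, after all ranks have written their shards.
# The prefix/suffix glob must match the per-rank filenames used when saving.
consolidate_sharded_model_checkpoints(
    ckpt_prefix="/tmp/checkpoints/ckpt_",          # placeholder path
    ckpt_suffix="rank-*-of-*.pth",
    save_path="/tmp/checkpoints/consolidated.pth",
)
```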
Thank you for #17421 (comment). I agree that we want to stay as close as possible to PyTorch's FSDP implementation, but difficulties arise since the XLA FSDP implementation is somewhat behind. I'll make some adjustments to try to match https://github.com/Lightning-AI/lightning/blob/a5c43d3b2b80f5fc769d7ed0ea511c0bd6733c6b/src/lightning/fabric/strategies/fsdp.py#L306-L368 and then we can adjust further from there.
- changed optimizer_step in Fabric xla_fsdp strategy
- removed xla_fsdp mnist E2E example
- minor cleanup
- added testing file for Pytorch fsdp_xla strategy
This is looking great. I'll do a thorough pass this week.
src/lightning/pytorch/trainer/connectors/accelerator_connector.py
- removed remaining instances of TPUAccelerator
- minor adjustments throughout
Checking in here to make sure this PR does not become too stale and that it eventually lands. Other than updating these changes with master, what else is needed?
Due to the large size of this PR, I think it is best to split it into two separate PRs: one for Fabric and one for Trainer.
Fabric support for PyTorch XLA FSDP on TPUs: #18126
What does this PR do?
Addresses feature request #13209 to implement FSDP in PyTorch XLA for TPUs. Use the XLAFSDPStrategy to run FSDP on TPUs.

Fixes #13209
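As a rough illustration of the intended API (a sketch only; the exact import path, constructor arguments, and the placeholder LightningModule are assumptions and may differ from what finally lands):

```python
import lightning.pytorch as pl
from lightning.pytorch.strategies import XLAFSDPStrategy  # assumed import path

# Sketch: enable XLA FSDP on a single TPU host with 8 cores.
trainer = pl.Trainer(
    accelerator="tpu",
    devices=8,                   # e.g. a v3-8 / v4-8 host
    strategy=XLAFSDPStrategy(),  # shard parameters across TPU cores
    max_epochs=1,
)
# trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder
```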
Before submitting
PR review
Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist