
[TPU] Add support for Pytorch XLA FSDP #17421

Closed
wants to merge 25 commits

Conversation

gkroiz
Contributor

@gkroiz gkroiz commented Apr 20, 2023

What does this PR do?

Addresses feature request #13209 by implementing FSDP via PyTorch XLA for TPUs. Use the XLAFSDPStrategy to run FSDP on TPUs.

Fixes #13209

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@gkroiz gkroiz marked this pull request as draft April 20, 2023 07:34
@github-actions github-actions bot added the fabric (lightning.fabric.Fabric) label Apr 20, 2023
@carmocca carmocca added the feature (Is an improvement or enhancement), pl (Generic label for PyTorch Lightning package), and strategy: xla labels Apr 20, 2023
@carmocca carmocca self-assigned this Apr 20, 2023
@carmocca carmocca added this to the 2.1 milestone Apr 20, 2023
@gkroiz
Contributor Author

gkroiz commented Apr 20, 2023

cc @Liyang90 for visibility

TODO: implement checkpointing for XLAFSDP, implement XLAFSDP in Pytorch
@github-actions github-actions bot removed the pl (Generic label for PyTorch Lightning package) label Apr 21, 2023
@gkroiz
Contributor Author

gkroiz commented Apr 21, 2023

Looking at the xla_fsdp documentation, checkpoints are saved as sharded checkpoints, which can be combined into a single checkpoint using consolidate_sharded_model_checkpoints. To save the sharded checkpoints, the user needs to specify device-specific paths (i.e., for 4 devices, 4 separate paths). Here is a good reference example of how to save checkpoints and consolidate them.
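
For reference, here is a minimal sketch of that save-and-consolidate flow with the raw torch_xla API, based on the docs linked above; MyModel and the file names are placeholders, and exact signatures may differ between torch_xla versions:

    # Runs inside each XLA process (e.g. spawned via xmp.spawn).
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import (
        XlaFullyShardedDataParallel as XlaFSDP,
        consolidate_sharded_model_checkpoints,
    )

    model = XlaFSDP(MyModel().to(xm.xla_device()))  # MyModel is a placeholder

    # Every device writes its own shard, so the path must be rank specific.
    rank, world_size = xm.get_ordinal(), xm.xrt_world_size()
    ckpt = {
        "model": model.state_dict(),
        "shard_metadata": model.get_shard_metadata(),  # required later for consolidation
    }
    xm.save(ckpt, f"ckpt_rank-{rank:08d}-of-{world_size:08d}.pth", master_only=False)

    # Wait until all shards are written, then optionally stitch them into one full checkpoint.
    xm.rendezvous("ckpt_consolidation")
    if rank == 0:
        consolidate_sharded_model_checkpoints(
            ckpt_prefix="ckpt_",
            ckpt_suffix="rank-*-of-*.pth",
            save_path="ckpt_consolidated.pth",
        )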

This is the current idea I have:

    def save_checkpoint(
        self, path: _PATH, state: Dict[str, Union[Module, Optimizer, Any]], storage_options: Optional[Any] = None
    ) -> None:
        """Save model, optimizer, and other state as a checkpoint file.

        Args:
            path: A path to where the file(s) should be saved
            state: A dictionary with contents to be saved. If the dict contains modules or optimizers, their
                state-dict will be retrieved and converted automatically.
            storage_options: Additional options for the ``CheckpointIO`` plugin
        """
        rank_zero_warn(
            "Saving checkpoints in the XLAFSDPStrategy writes one sharded checkpoint per device."
            " Please make sure the path specified is device specific!"
        )

        state = self._convert_stateful_objects_in_state(state)

        # The sharded checkpoint must carry the shard metadata so it can be consolidated later.
        if "model" not in state:
            raise ValueError("XLAFSDPStrategy requires the saved state to include 'model'.")
        if "shard_metadata" not in state:
            raise ValueError("XLAFSDPStrategy requires the saved state to include 'shard_metadata'.")
        if "optimizer" not in state:
            raise ValueError("XLAFSDPStrategy requires the saved state to include 'optimizer'.")

        self.checkpoint_io.save_checkpoint(state, path, storage_options=storage_options)
Would we also want to combine the sharded checkpoints in save_checkpoint()? My only concern here is that the user would need to specify several other function arguments like ckpt_prefix, ckpt_suffix, and save_path. Combining sharded checkpoints is beneficial because the end result is a single checkpoint that can be loaded for inference, etc. The downside is that this strategy's save_checkpoint() would require more work from the user to supply the correct function arguments.

What are everyone's thoughts? If we do not include checkpoint consolidation directly in the save_checkpoint() function, we can also (1) create a separate function that does this, or (2) reference some documentation on the matter.
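
For option (1), the separate function could be a thin wrapper around the torch_xla utility, roughly like the sketch below (hypothetical name and defaults, just to illustrate the shape of the API):

    from typing import Optional

    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints


    def consolidate_checkpoints(
        ckpt_prefix: str,
        ckpt_suffix: str = "_rank-*-of-*.pth",
        save_path: Optional[str] = None,
    ) -> None:
        """Hypothetical opt-in helper: merge the per-device shards into a single checkpoint.

        The defaults mirror what the torch_xla utility expects, so the user only needs to
        pass the prefix they used when saving the shards.
        """
        consolidate_sharded_model_checkpoints(
            ckpt_prefix=ckpt_prefix,
            ckpt_suffix=ckpt_suffix,
            save_path=save_path or f"{ckpt_prefix}consolidated.pth",
        )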

cc @carmocca @Liyang90

@carmocca
Contributor

@gkroiz We should try to stay as close as possible to PyTorch's FSDP implementation. For example, see Fabric's FSDP implementation, which is our most up-to-date version at the moment: https://github.com/Lightning-AI/lightning/blob/a5c43d3b2b80f5fc769d7ed0ea511c0bd6733c6b/src/lightning/fabric/strategies/fsdp.py#L306-L368
Here, the user specifies the directory where the shards are saved, and the shard names are managed automatically by the strategy internals.
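
To make that concrete, the directory-based convention could boil down to something like this inside the strategy (hypothetical helper, naming scheme for illustration only):

    from pathlib import Path

    import torch_xla.core.xla_model as xm


    def _shard_path(directory: Path) -> Path:
        # The user only passes a directory; the strategy derives a unique per-device file
        # name, so no device-specific path has to be constructed by hand.
        directory.mkdir(parents=True, exist_ok=True)
        rank, world_size = xm.get_ordinal(), xm.xrt_world_size()
        return directory / f"checkpoint_rank-{rank:08d}-of-{world_size:08d}.pth"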

XLA's FSDP implementation was forked from the original fairscale implementation. The PyTorch version is evolving quickly, so it would be useful to know what the plan is for XLA's. For instance, XLA's doesn't offer the use_orig_params argument or the new saving APIs. This means unifying might not be easy (so it's not a hard requirement), but if we do it, it should be easier to update if XLA's version catches up in the future.

For consolidation, this would be done separately and as an opt-in. We are in the process of deciding the desired design for this in Lightning (cc @awaelchli): Our goal is to support {saving,loading} {sharded,consolidated} checkpoints. But to limit the scope of this work we can focus on saving and loading just sharded versions. See for example this comment: #16526 (comment)

Resolved review threads:
  • src/lightning/fabric/strategies/xla_fsdp.py
  • examples/fabric/tpu/train_fsdp_mnist.py
@gkroiz
Contributor Author

gkroiz commented Apr 24, 2023

Thank you for #17421 (comment). I agree that we want to stay as close as possible to PyTorch's FSDP implementation, but difficulties arise since the XLA FSDP implementation is somewhat behind. I'll make some adjustments to try to match https://github.com/Lightning-AI/lightning/blob/a5c43d3b2b80f5fc769d7ed0ea511c0bd6733c6b/src/lightning/fabric/strategies/fsdp.py#L306-L368 and then we can adjust further from there.

gkroiz added 4 commits April 25, 2023 06:49
- changed optimizer_step in Fabric xla_fsdp strategy
- removed xla_fsdp mnist E2E example
- minor cleanup
- added testing file for Pytorch fsdp_xla strategy
@github-actions github-actions bot added the pl (Generic label for PyTorch Lightning package) label Apr 26, 2023
Contributor

@carmocca carmocca left a comment


This is looking great. I'll do a thorough pass this week

Resolved review threads:
  • src/lightning/fabric/connector.py
@gkroiz
Contributor Author

gkroiz commented Jun 6, 2023

Want to check in here to make sure this PR does not become too stale and that it eventually lands. Other than updating these changes with master, what else is needed?

gkroiz added a commit to gkroiz/lightning that referenced this pull request Jul 21, 2023
@gkroiz
Contributor Author

gkroiz commented Jul 21, 2023

Due to the large size of this PR, I think it is best to split this into two separate PRs, one for Fabric and one for Trainer.

@gkroiz
Contributor Author

gkroiz commented Jul 21, 2023

Fabric support for Pytorch XLA FSDP on TPUs: #18126

Labels
fabric (lightning.fabric.Fabric), feature (Is an improvement or enhancement), pl (Generic label for PyTorch Lightning package), strategy: xla
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Feature request: FSDP native strategy for TPUs
2 participants