Revised partial checkpoint support for Sagemaker Model Parallel #16950
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
sgugger left a comment
Thanks for your PR. Two comments on it:
- This breaks the current behavior of the `Trainer` where each checkpoint can be loaded as a model. In particular, this will push the partial checkpoints with no config to the Hub during training when `push_to_hub=True` (whereas a regular training pushes models that can be used).
- The feature is always on. Maybe we should let the user decide if they want it or not?
```python
if is_sagemaker_mp_enabled():
    smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
else:
    self.model.save_pretrained(output_dir, state_dict=state_dict)
```
Not calling `save_pretrained` here means the config will not be saved, and the checkpoint won't be loadable with `from_pretrained` independently of the training. It's not a regular checkpoint anyway, so maybe it's okay. Flagging this here anyway.
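For illustration, if loadability of the checkpoint directory turns out to matter, one option (a sketch only, not what this PR does) would be to write the config next to the partial weights; `PretrainedConfig.save_pretrained` only writes `config.json`:

```python
# Sketch: make the checkpoint directory self-describing by also saving the
# config. from_pretrained still cannot consume the partial weights, but tools
# that only need config.json would then work on this directory.
if is_sagemaker_mp_enabled():
    smp.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME), partial=True)
    self.model.config.save_pretrained(output_dir)
else:
    self.model.save_pretrained(output_dir, state_dict=state_dict)
```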
SMP checkpoints are saved partially, hence we do not want to shard them. In order to use `save_pretrained` for SMP, we would need to skip `shard_checkpoint`: it is called regardless of `max_shard_size`, and we hit errors inside it because SMP checkpoints are structured differently. If `shard_checkpoint` could be made optional in `save_pretrained`, we could call `save_pretrained` with `save_function=smp.save`. In my previous PR I tried to skip `shard_checkpoint` for SMP, but the feedback was not to change `save_pretrained`.
`from_pretrained` won't work for SMP models; we are working on how to support fine-tuning. In this PR, I added support for partial checkpoint saving/loading during training.
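For reference, here is roughly what that path would look like if `shard_checkpoint` could be skipped for SMP — a hypothetical sketch of what the previous PR attempted, not working code on main:

```python
import functools

# Hypothetical: assumes save_pretrained skips shard_checkpoint when SMP is
# enabled. save_function is an existing save_pretrained argument (it replaces
# torch.save); functools.partial is needed so each file is written as a
# partial SMP checkpoint.
model.save_pretrained(
    output_dir,
    state_dict=model.local_state_dict(),
    save_function=functools.partial(smp.save, partial=True),
)
```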
```python
checkpoint_file_exists = (
    glob.glob(os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + "_*")
    if is_sagemaker_mp_enabled()
    else os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME))
)
```
This is used several times, could we refactor it into a util function that takes the filename?
I will do that.
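Something like this, as a sketch (the helper name and exact placement are to be settled in the update):

```python
import glob
import os

def checkpoint_file_exists(folder, filename):
    # SMP writes partial checkpoints with rank suffixes (e.g. filename_0_0),
    # so a plain isfile check misses them; glob for the suffixed files instead.
    # is_sagemaker_mp_enabled is already imported in trainer.py.
    if is_sagemaker_mp_enabled():
        return len(glob.glob(os.path.join(folder, filename) + "_*")) > 0
    return os.path.isfile(os.path.join(folder, filename))
```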
Thanks for reviewing. For the user to decide whether to save/load partial checkpoints, we would need new training args. In my previous PR, I got feedback not to introduce new HF training args, so we decided to support partial checkpointing by default.
There are plenty of other ways to control whether a feature is on or off. For instance, you could use an environment variable. Since this partial checkpointing is completely incompatible with `push_to_hub`, …
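For illustration, an opt-out via environment variable could look like this (the variable name `SMP_PARTIAL_CHECKPOINT` is hypothetical, not an existing setting):

```python
import os

# Hypothetical opt-out switch; defaults to on, matching the PR's behavior.
use_partial_checkpoint = os.environ.get("SMP_PARTIAL_CHECKPOINT", "1").lower() not in ("0", "false", "no")
```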
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
What does this PR do?
- Uses `smp.rdp_rank()` instead of `smp.rank()` for partial checkpoint saving in `should_save`.
- Uses `local_state_dict()` with partial checkpoint saving.
- Uses `smp.save` for SMP.
- Uses `smp.load` for SMP. Reorders partial checkpoint loading to happen after the model is wrapped, since `smp.load` can only load into an SMP model.
- Checks for checkpoint files by globbing, since SMP partial checkpoints carry rank suffixes (e.g. `filename_0_0` or `filename_0_0_0`).
- Adds `load_best_model_at_end` support for SMP.

Fixes # (issue)
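As a rough sketch of the save/load flow this enables (following the documented `smdistributed.modelparallel.torch` partial-checkpoint pattern; the path and variable names are illustrative):

```python
import smdistributed.modelparallel.torch as smp

# Saving: rank 0 of each reduced-data-parallel group writes its shard;
# SMP appends rank suffixes (e.g. _0_0) to the given path.
if smp.rdp_rank() == 0:
    smp.save(model.local_state_dict(), "checkpoint/pytorch_model.bin", partial=True)
smp.barrier()

# Loading: only valid after the model has been wrapped with smp.DistributedModel.
state_dict = smp.load("checkpoint/pytorch_model.bin", partial=True)
model.load_state_dict(state_dict)
```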
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.