-
Notifications
You must be signed in to change notification settings - Fork 580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Serialization] support loading torch state dict from disk #2687
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice PR @hanouticelina ! That's promising for our other libs :) I've made a first pass on the PR and left a few comments. Overall looks good though I think we can update a few things to expose only 1 or 2 methods in the library.
Also, would it be possible to open a PR on transformers
to showcase how it would be used? No need to make updates everywhere in the lib', just 1 example is enough for now
@Wauplin thanks for the review! I've addressed the comments and created a draft PR hanouticelina/transformers#1 to illustrate the integration (I've opened it on my personal fork for now while this PR is still WIP) –would love to have your feedback on that! |
Sorry for the policing, but does a similar PR need to be opened in |
@sayakpaul, yes! we plan to open a PR in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great job @hanouticelina ! This is a super well documented and tested PR which is much appreciated for such a key-part of the library! Tested it locally and it seems to work as I'd expect it^^
@LysandreJik @ArthurZucker would it be possible to take a closer look at this PR and especially hanouticelina/transformers#1 to confirm everything's fine for you as well (worth case we ship and make hot-fixes if necessary). We'd like to ship this quickly :)
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " | ||
"you save your model with the `save_torch_model` method." | ||
) | ||
return load_file(checkpoint_file) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
safetensors supports loading directly to a device: https://huggingface.co/docs/safetensors/api/torch#safetensors.torch.load_file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indeed! thanks for pointing this out.
Note that in transformers, we always load the state dict in cpu or in a meta device and then the weights get moved to their respective devices during model dispatch in the from_pretrained() method.
Also meta device is not supported with safetensors
. i guess we can fall back to cpu with a warning when meta device is specified in this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From a quick look these two helpers could indeed replace a few codepaths in transformers. There is quite a bit of complexity there so I would recommend merging an initial version, setting up a branch in transformers to switch to these helpers and running the CI on it.
Thanks for your work @hanouticelina ! Any work that makes our from_pretrained
method and our modeling_utils
module simpler are welcome
thanks @LysandreJik, indeed, the best way to test these helpers is to run transformers CI on it! I will merge this and then will create the branch directly in transformers |
Implement helpers to load a torch state dict from disk. For the implementation, it's mostly an importing from transformers and diffusers implementation with additional error handling and some refactoring. the loading can be done from a single file or from shards. It handles both safetensors and pickle files. Saving torch state dict has been previously added in #2314.
This PR:
load_torch_model()
helper function that takes ann.Module
and a checkpoint path (either a single file or a directory) as input and loads the weights into the model.usage example:
load_state_dict_from_file()
: loads a single checkpoint file.Note: PRs will be opened in transformers, diffusers and accelerate to integrate these helpers once huggingface_hub v0.27.0 is released.
here is example of an integration in transformers: hanouticelina/transformers#1
cc @sayakpaul for diffusers, @muellerzr and @SunMarc for accelerate and @ArthurZucker for transformers. Happy to get any feedback on this. The goal is the same as for the saving helpers: standardize things across our libraries and establish consistent conventions!