
[Serialization] support loading torch state dict from disk #2687

Merged: 18 commits into main, Dec 13, 2024

Conversation

@hanouticelina (Contributor) commented Dec 2, 2024

Implement helpers to load a torch state dict from disk. The implementation is mostly imported from the transformers and diffusers implementations, with additional error handling and some refactoring. Loading can be done from a single file or from shards, and both safetensors and pickle files are handled. Saving a torch state dict was previously added in #2314.


This PR:

  • adds load_torch_model(), a helper function that takes an nn.Module and a checkpoint path (either a single file or a directory) as input and loads the weights into the model.

usage example:

from huggingface_hub import load_torch_model
model = ... # A PyTorch model

# load the weights into model
load_torch_model(model, "path/to/checkpoint")
  • adds a low-level helper that can be used directly by transformers, diffusers and accelerate:
    • load_state_dict_from_file(): loads a single checkpoint file.
  • tests have been added and the documentation has also been updated.
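For illustration, the behavior described above (a single checkpoint file or a sharded directory, with safetensors or pickle files handled per extension) can be sketched as below. This is a hypothetical, framework-agnostic sketch, not the PR's actual implementation: the function and parameter names are illustrative, and loader callables are injected, whereas the real helpers call safetensors.torch.load_file and torch.load under the hood.

```python
import os

def resolve_checkpoint_files(checkpoint_path):
    """Return the list of weight files for a single file or a sharded directory."""
    if os.path.isfile(checkpoint_path):
        return [checkpoint_path]
    # A directory is assumed to hold shards (plus e.g. an index file, skipped here).
    shards = [
        os.path.join(checkpoint_path, name)
        for name in sorted(os.listdir(checkpoint_path))
        if name.endswith((".safetensors", ".bin", ".pt", ".pth"))
    ]
    if not shards:
        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_path!r}")
    return shards

def load_state_dict(checkpoint_path, safetensors_loader, pickle_loader):
    """Merge every shard into one flat state dict, choosing a loader per file."""
    state_dict = {}
    for path in resolve_checkpoint_files(checkpoint_path):
        loader = safetensors_loader if path.endswith(".safetensors") else pickle_loader
        state_dict.update(loader(path))
    return state_dict
```

In the real helpers the loaders are fixed rather than injected; the injection here just keeps the sketch runnable without torch or safetensors installed.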

Note: PRs will be opened in transformers, diffusers and accelerate to integrate these helpers once huggingface_hub v0.27.0 is released.
Here is an example of an integration in transformers: hanouticelina/transformers#1

cc @sayakpaul for diffusers, @muellerzr and @SunMarc for accelerate and @ArthurZucker for transformers. Happy to get any feedback on this. The goal is the same as for the saving helpers: standardize things across our libraries and establish consistent conventions!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor @Wauplin left a comment


Nice PR @hanouticelina! That's promising for our other libs :) I've made a first pass on the PR and left a few comments. Overall it looks good, though I think we can update a few things to expose only one or two methods in the library.

Also, would it be possible to open a PR on transformers to showcase how it would be used? No need to make updates everywhere in the lib, just one example is enough for now.

@hanouticelina (Contributor, Author)
@Wauplin thanks for the review! I've addressed the comments and created a draft PR, hanouticelina/transformers#1, to illustrate the integration (I've opened it on my personal fork for now while this PR is still WIP). I'd love to have your feedback on that!

@sayakpaul (Member)

Sorry for the policing, but does a similar PR need to be opened in diffusers too? 👀

@hanouticelina (Contributor, Author)

Sorry for the policing, but does a similar PR need to be opened in diffusers too? 👀

@sayakpaul, yes! We plan to open a PR in diffusers once these helpers are released. For now, I've created hanouticelina/transformers#1 just as an example integration.

Contributor @Wauplin left a comment


Great job @hanouticelina! This is a super well documented and tested PR, which is much appreciated for such a key part of the library! I tested it locally and it seems to work as I'd expect ^^

@LysandreJik @ArthurZucker, would it be possible to take a closer look at this PR, and especially hanouticelina/transformers#1, to confirm everything's fine for you as well (worst case, we ship and make hot-fixes if necessary)? We'd like to ship this quickly :)

Quoted code (the safetensors loading path):

    f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
    "you save your model with the `save_torch_model` method."
)
return load_file(checkpoint_file)
Contributor:

safetensors supports loading directly to a device: https://huggingface.co/docs/safetensors/api/torch#safetensors.torch.load_file

@hanouticelina (Contributor, Author)

Indeed! Thanks for pointing this out.
Note that in transformers, we always load the state dict on CPU or on the meta device, and the weights are then moved to their respective devices during model dispatch in the from_pretrained() method.

Also, the meta device is not supported by safetensors. I guess we can fall back to CPU with a warning when the meta device is specified in this case.
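The fallback discussed here could look roughly like this. This is a hedged sketch, not the helper's actual API: resolve_safetensors_device and map_location are illustrative names, and only the device-resolution step is shown.

```python
import logging

# Sketch of the device handling discussed above: safetensors'
# load_file(filename, device=...) can place tensors directly on a device,
# but the "meta" device is not supported, so we fall back to "cpu" with a
# warning. `resolve_safetensors_device` is a hypothetical name.

def resolve_safetensors_device(map_location: str) -> str:
    if map_location == "meta":
        logging.warning(
            "Loading a safetensors checkpoint on the 'meta' device is not "
            "supported by safetensors; falling back to 'cpu'."
        )
        return "cpu"
    return map_location

# Inside the loading helper, this would be used as (sketch):
# from safetensors.torch import load_file
# state_dict = load_file(checkpoint_file, device=resolve_safetensors_device(map_location))
```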

Member @LysandreJik left a comment


From a quick look, these two helpers could indeed replace a few code paths in transformers. There is quite a bit of complexity there, so I would recommend merging an initial version, setting up a branch in transformers that switches to these helpers, and running the CI on it.

Thanks for your work @hanouticelina! Any work that makes our from_pretrained method and our modeling_utils module simpler is welcome.

@hanouticelina (Contributor, Author)

Thanks @LysandreJik! Indeed, the best way to test these helpers is to run the transformers CI on them. I will merge this and then create the branch directly in transformers.

@hanouticelina hanouticelina merged commit b75f8d9 into main Dec 13, 2024
17 checks passed