Making wrapper tensor subclasses work in serialization #2440
Conversation
Force-pushed from 40413f5 to d6c4256.
@SunMarc can you find someone to review this?
…on (non-safetensor)

Summary: huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but wrapper tensor subclasses do not have storage, so we unflatten the tensor and derive the storage_id by combining the storage_ids of all internal plain tensors. This is a bit hacky; open to more robust ideas.

Test Plan: tested with the script in huggingface/transformers#32364
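A minimal sketch of the recursion described in this commit message, assuming PyTorch's traceable wrapper subclass protocol (`__tensor_flatten__`, torch >= 2.1); the function name matches the PR's terminology, but the body is illustrative rather than the merged implementation:

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def get_storage_id(tensor: torch.Tensor):
    """Return a hashable storage id: an int for a plain tensor, a nested
    tuple of inner ids for a wrapper tensor subclass."""
    if is_traceable_wrapper_subclass(tensor):
        # Subclasses expose their inner tensors via __tensor_flatten__;
        # recurse until we reach plain tensors that actually own storage.
        inner_names, _ = tensor.__tensor_flatten__()
        return tuple(get_storage_id(getattr(tensor, name)) for name in inner_names)
    return tensor.untyped_storage().data_ptr()
```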
Force-pushed from d6c4256 to 40d9bda.
Force-pushed from 40d9bda to 4ebe8aa.
Summary: Similar to huggingface/huggingface_hub#2440, we want to allow safetensors to handle wrapper tensor subclasses. We mainly added:

1. tensor storage size: computed by flattening the wrapper tensor subclass and recursively summing the storage sizes of all sub-tensors.
2. storage_ptr: computed by constructing a tuple from the "storage_ptr" of each flattened tensor; this can be a nested tuple of ints, e.g. ((1, 2), 3, (4, (5, 6))).

Test Plan: added a test in test_pt_model.py; will also test manually.
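A hedged sketch of the two quantities described above, reusing the same recursive flattening idea (again assuming the traceable wrapper subclass protocol; this is not the exact safetensors code):

```python
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def get_storage_size(tensor: torch.Tensor) -> int:
    """Size in bytes: for subclasses, the sum over all inner plain tensors."""
    if is_traceable_wrapper_subclass(tensor):
        inner_names, _ = tensor.__tensor_flatten__()
        return sum(get_storage_size(getattr(tensor, name)) for name in inner_names)
    return tensor.untyped_storage().nbytes()


def storage_ptr(tensor: torch.Tensor):
    """Plain tensor -> int pointer; subclass -> nested tuple such as ((1, 2), 3)."""
    if is_traceable_wrapper_subclass(tensor):
        inner_names, _ = tensor.__tensor_flatten__()
        return tuple(storage_ptr(getattr(tensor, name)) for name in inner_names)
    return tensor.untyped_storage().data_ptr()
```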
Hi @jerryzh168, thanks for opening this PR! And thanks @SunMarc for pulling me into this convo. The general logic of this PR looks OK to me (even though I'm missing some broader context, I think). The tests look good as well. However, […] In the end, the only thing that I want […]
@Wauplin I can add this to pytorch, but it will only be available in nightlies or torch 2.5+, so don't we still need to add this in huggingface_hub or other places for now? What is the version requirement for huggingface_hub/safetensors etc.? Do they plan to work with all the different torch versions?
Summary: Currently [huggingface_hub](https://github.com/huggingface/huggingface_hub/blob/main/src/huggingface_hub/serialization/_torch.py), [safetensors](https://github.com/huggingface/safetensors/blob/main/bindings/python/py_src/safetensors/torch.py#L11), and [accelerate](https://github.com/huggingface/accelerate/blob/a452327e8e04b20779882dc491e00de602d554cb/src/accelerate/utils/modeling.py#L175) each have their own implementation of `get_storage_id`, `get_storage_size`, and `storage_ptr`, which make assumptions about internal implementation details of torch.Tensor and its storage, and do not work for wrapper tensor subclasses. Motivated by huggingface/huggingface_hub#2440 (comment), it may make more sense to add these as utils in pytorch so they can be maintained by us instead.

This PR adds:

`get_storage_id`: returns a unique identifier for the tensor storage; for tensor subclasses, it returns a nested tuple of unique ids from the underlying plain tensors.

`get_storage_size`: returns the size in bytes of the underlying storage; for tensor subclasses, it returns the sum of the sizes of all underlying plain tensors.

Test Plan: python test/test_utils.py TestStorageUtils

Pull Request resolved: #133524
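A small usage sketch (hypothetical, building on the `get_storage_id` sketch earlier) of the property these utils are meant to provide: tied or viewed tensors map to the same id, which is what lets the sharding logic avoid serializing shared weights twice.

```python
import torch

a = torch.zeros(4)
b = a.view(2, 2)  # a view shares its underlying storage with `a`

assert get_storage_id(a) == get_storage_id(b)               # same storage -> same id
assert get_storage_id(a) != get_storage_id(torch.zeros(4))  # fresh storage -> new id
```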
Hi @jerryzh168, thanks for the changes! The implementation looks good to me at first glance. I've added a comment regarding the version parsing thing. In the meantime, I'll run some tests locally on my side.
In a follow-up PR, we'll add CI tests for torch 2.0 and 2.5 (for instance) to be sure both versions are compatible. I can take care of that part.
@Wauplin please take a look again, thanks! |
Thanks for the changes @jerryzh168! I left some comments below
Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
"""
return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
If I understand correctly, two "meta" tensors can have the exact same _get_unique_id(tensor) and the exact same tensor.device but still be different, correct? If so, how can we be sure their storage sizes distinguish them? Can it happen that they randomly have the same storage size?
Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I think we'd need to reimplement the higher-level sharding logic in pytorch in the end. I added some PoC in the slack; let me make a quick intro there.
> Yeah, it just means the current approach does not generalize to meta tensors. Did it work previously?

I don't think so, since we never had to serialize meta tensors. The only use case that could benefit from this is in accelerate (finding tied parameters from the meta model). Right now, this is how we handle meta tensors: https://github.com/huggingface/accelerate/blob/726140cad2f2361d79da7786a7b96d0bee591c48/src/accelerate/utils/modeling.py#L677
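As a side illustration (my own sketch, not code from the PR) of why the pointer-based approach degenerates on the meta device: there is no real allocation, so the pointer carries no information, and only the device and storage size remain to tell tensors apart, which is exactly the collision concern raised above.

```python
import torch


def meta_safe_storage_ptr(t: torch.Tensor) -> int:
    # Mirrors the fallback style of the HF implementations: meta storages may
    # not expose a real pointer, so treat them as 0.
    try:
        return t.untyped_storage().data_ptr()
    except Exception:
        return 0


m1 = torch.empty(4, device="meta")
m2 = torch.empty(8, device="meta")
print(meta_safe_storage_ptr(m1), meta_safe_storage_ptr(m2))          # 0 0 -> identical
print(m1.untyped_storage().nbytes(), m2.untyped_storage().nbytes())  # 16 32 -> only size differs
```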
tests/test_serialization.py (outdated)
# TODO: need to fix safetensor support for tensor subclasses before we can add this
# to test
# shared_layer = TwoTensor(torch.tensor([4]), torch.tensor([4]))
@Wauplin this seems to fail because safetensors does not support wrapper tensor subclasses yet; we can enable this when we add similar support in safetensors.
Actually, we don't have to test save_torch_state_dict using safetensors. It is possible to simply test with split_torch_state_dict_into_shards (needs to be imported), since what we really care about is the grouping of the tensors, not necessarily the serialization (for now, at least). Could you update the fixture and test in that direction? (See the illustrative sketch below.)

Also, torch_state_dict_shared_layers is already taken as a name for a fixture, so code quality is complaining in the CI.
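A rough illustration of the suggested test direction: `split_torch_state_dict_into_shards` is the real huggingface_hub helper, but the fixture values and printed attributes are my own sketch, not the PR's test code.

```python
import torch
from huggingface_hub import split_torch_state_dict_into_shards

shared = torch.tensor([4.0])
state_dict = {
    "shared_a.weight": shared,
    "shared_b.weight": shared,  # tied: same underlying storage as shared_a
    "other.weight": torch.zeros(2),
}

# We only care about how tensors are grouped into shards, not about actually
# writing safetensors files to disk.
split = split_torch_state_dict_into_shards(state_dict)
print(split.is_sharded)           # False for such a tiny state dict
print(split.filename_to_tensors)  # mapping of shard filename -> tensor names
```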
I pushed a commit to fix some linting + merge from main, so we are now testing the pipeline on both torch versions. Last remaining thing is the test to complete, IMO (see #2440 (comment)). Otherwise, everything looks good!
Thanks @jerryzh168! Everything looks good to me now! I fixed a few formatting issues so we're now ready to merge this. Thanks again and thanks @SunMarc as well for the inputs :)
Looking forward to seeing deeper integration into pytorch directly!
Enable non-safetensor serialization and deserialization for TorchAoConfig quantized models (huggingface/transformers#33456)

Summary: After huggingface/huggingface_hub#2440 added non-safetensor serialization and deserialization in huggingface_hub, we can now add the corresponding support in transformers. Note that we don't plan to add safetensor serialization, due to the different goals of wrapper tensor subclasses and safetensors; see the README for more details.

Test Plan: tested locally
Summary:
huggingface_hub serialization relies on the storage_ptr of a tensor to implement its sharding logic, but wrapper tensor subclasses do not have storage, so we unflatten the tensor and compute the storage_id by returning a tuple constructed from the storage_ids of all internal plain (or nested subclass) tensors.

Note: this PR only adds non-safetensor serialization support for tensor subclasses.

Test Plan:
tested with the script in huggingface/transformers#32364
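As a closing illustration (my own hedged sketch, not part of the PR): with this change, a state dict containing a wrapper tensor subclass such as the TwoTensor test class can go through the non-safetensors (pickle-based) save path.

```python
import torch
from huggingface_hub import save_torch_state_dict
from torch.testing._internal.two_tensor import TwoTensor  # test-only subclass

state_dict = {"layer.weight": TwoTensor(torch.tensor([4.0]), torch.tensor([4.0]))}

# safe_serialization=False selects the pickle (.bin) path, which is the one
# this PR teaches to handle wrapper tensor subclasses.
save_torch_state_dict(state_dict, save_directory="checkpoint", safe_serialization=False)
```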