Serialization: take into account meta tensor when splitting the `state_dict` (#2591)

* Enable meta tensor serialization

* getattr is better

* style

* skip meta tensors

* update doc

* Update src/huggingface_hub/serialization/_torch.py

Co-authored-by: Lucain <[email protected]>

* oups

---------

Co-authored-by: Lucain <[email protected]>
SunMarc and Wauplin authored Oct 10, 2024
1 parent 2c7c19d commit 8cb81ac
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/huggingface_hub/serialization/_torch.py
@@ -368,18 +368,21 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     return unique_id


-def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
+def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
     """
     Return unique identifier to a tensor storage.

-    Multiple different tensors can share the same underlying storage. For
-    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    Multiple different tensors can share the same underlying storage. This identifier is
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
+    In the case of meta tensors, we return None since we can't tell if they share the same storage.

     Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
     """
-    return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
+    if tensor.device.type == "meta":
+        return None
+    else:
+        return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)


 def get_torch_storage_size(tensor: "torch.Tensor") -> int:
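
To illustrate why returning None for meta tensors matters, here is a minimal sketch of how a caller might group tensors by storage identifier. It is not the code from `huggingface_hub`: `FakeTensor`, `storage_id`, and `group_shared` are hypothetical names, and `FakeTensor` is a stand-in for `torch.Tensor` so the example runs without PyTorch installed. The assumption is that the caller treats None as "unknown storage" and skips it rather than using it as a grouping key.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor with only the fields the sketch needs."""

    device_type: str   # e.g. "cpu", "cuda", "meta"
    storage_ptr: int   # stand-in for the value _get_unique_id would return
    storage_size: int  # stand-in for the value get_torch_storage_size would return


def storage_id(tensor: FakeTensor) -> Optional[Tuple[str, int, int]]:
    # Mirrors the behaviour introduced by the diff: meta tensors get no identifier.
    if tensor.device_type == "meta":
        return None
    return tensor.device_type, tensor.storage_ptr, tensor.storage_size


def group_shared(state_dict: Dict[str, FakeTensor]) -> Dict[Tuple[str, int, int], List[str]]:
    """Return only the storage ids that are used by more than one tensor name."""
    groups: Dict[Tuple[str, int, int], List[str]] = defaultdict(list)
    for name, tensor in state_dict.items():
        sid = storage_id(tensor)
        if sid is None:
            # Unknown storage (meta tensor): never treat it as shared.
            continue
        groups[sid].append(name)
    return {sid: names for sid, names in groups.items() if len(names) > 1}


state_dict = {
    "embed.weight": FakeTensor("cpu", 1000, 64),
    "lm_head.weight": FakeTensor("cpu", 1000, 64),  # tied to embed.weight
    "layer.a": FakeTensor("meta", 0, 64),
    "layer.b": FakeTensor("meta", 0, 64),  # same placeholder values, but not known to be shared
}
shared = group_shared(state_dict)
```

Under the old behaviour, `layer.a` and `layer.b` would have produced equal identifiers and been grouped as one shared storage. With the None return they are left out of the grouping, while the two tied CPU tensors are still detected.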