Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making wrapper tensor subclass to work in serialization #2440

Merged
merged 18 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 59 additions & 12 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union, Any

from .. import constants, logging
from ._base import MAX_SHARD_SIZE, StateDictSplit, split_state_dict_into_shards_factory
Expand Down Expand Up @@ -335,18 +335,23 @@ def split_torch_state_dict_into_shards(
get_storage_id=get_torch_storage_id,
)


def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, int]:
def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
"""Returns a unique id for plain tensor
or a (potentially nested) Tuple of unique id for the flattened Tensor
if the input is a wrapper tensor subclass Tensor
"""
Return unique identifier to a tensor storage.

Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__()
return tuple(_get_unique_id(getattr(tensor, attr)) for attr in attrs)

except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
"""
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
Expand All @@ -358,13 +363,36 @@ def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", int, i
else:
unique_id = storage_ptr(tensor)

return tensor.device, unique_id, get_torch_storage_size(tensor)
return unique_id

def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
"""
Return unique identifier to a tensor storage.

Multiple different tensors can share the same underlying storage. For
example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.

Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
"""
return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, two "meta" tensors can have the exact same _get_unique_id(tensor), the exact same tensor.device but still be different, correct? If different, how can we be sure their storage size distinguish them? Can it happen that they randomly happen to have the same storage size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it just means the current approach does not generalize to meta tensor, does it work previously?

I think we'd need to reimplement the higher level sharding logic in the end in pytorch, I added some PoC in the slack, let me make a quick intro there

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it just means the current approach does not generalize to meta tensor, does it work previously?

I don't think so since we never had to serialize meta tensors. The only use case that could benefit from that is in accelerate (find tied parameters from the meta model). Right now, this is how we do for meta tensors: https://github.com/huggingface/accelerate/blob/726140cad2f2361d79da7786a7b96d0bee591c48/src/accelerate/utils/modeling.py#L677



def get_torch_storage_size(tensor: "torch.Tensor") -> int:
"""
Taken from https://github.com/huggingface/safetensors/blob/08db34094e9e59e2f9218f2df133b7b4aaff5a99/bindings/python/py_src/safetensors/torch.py#L31C1-L41C59
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__()
return sum(get_torch_storage_size(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

try:
return tensor.untyped_storage().nbytes()
except AttributeError:
Expand Down Expand Up @@ -398,10 +426,19 @@ def is_torch_tpu_available(check_device=True):
return False


def storage_ptr(tensor: "torch.Tensor") -> int:
def storage_ptr(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L11.
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
return _get_unique_id(tensor)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

try:
return tensor.untyped_storage().data_ptr()
except Exception:
Expand Down Expand Up @@ -496,6 +533,16 @@ def _is_complete(tensor: "torch.Tensor") -> bool:
"""
Taken from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L80
"""
try:
# for torch 2.1 and above we can also handle tensor subclasses
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if is_traceable_wrapper_subclass(tensor):
attrs, _ = tensor.__tensor_flatten__()
return all(_is_complete(getattr(tensor, attr)) for attr in attrs)
except ImportError:
# for torch version less than 2.1, we can fallback to original implementation
pass

return tensor.data_ptr() == storage_ptr(tensor) and tensor.nelement() * _get_dtype_size(
tensor.dtype
) == get_torch_storage_size(tensor)
Expand Down
113 changes: 113 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def _dummy_get_storage_size(item):
return sum(item)


# util functions for checking the version for pytorch
def is_wrapper_tensor_subclass_available():
try:
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
return True
except ImportError:
return False

@pytest.fixture
def dummy_state_dict() -> Dict[str, List[int]]:
return {
Expand Down Expand Up @@ -58,6 +66,25 @@ def torch_state_dict() -> Dict[str, "torch.Tensor"]:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([4])
return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
"layer_6": TwoTensor(t, t),
}
except ImportError:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
try:
Expand All @@ -75,6 +102,55 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([4])
tensor_subclass_tensor = TwoTensor(t, t)

t = torch.tensor([4])
shared_tensor_subclass_tensor = TwoTensor(t, t)
return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
"layer_6": tensor_subclass_tensor,
"ts_shared_1": shared_tensor_subclass_tensor,
"ts_shared_2": shared_tensor_subclass_tensor,
}
except ImportError:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

if is_wrapper_tensor_subclass_available():
# TODO: need to fix safetensor support for tensor subclasses before we can add this
# to test
# shared_layer = TwoTensor(torch.tensor([4]), torch.tensor([4]))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wauplin this seem to fail because safetensor does not support wrapper tensor subclass yet, we can enable this when we add the similar support in safetensors

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we don't have to test on save_torch_state_dict using safetensors. It is possible to simply test with split_torch_state_dict_into_shards (needs to be imported) since what we really care about is the grouping of the tensors, not necessarily the serialization -for now at least-. Could you update the fixture and test in that direction?

Also, torch_state_dict_shared_layers is already taken as a name for a fixture so code quality is complaining in the CI.

shared_layer = torch.tensor([4])
else:
shared_layer = torch.tensor([4])

return {
"shared_1": shared_layer,
"unique_1": torch.tensor([10]),
"unique_2": torch.tensor([30]),
"shared_2": shared_layer,
}
except ImportError:
pytest.skip("torch is not available")


def test_single_shard(dummy_state_dict):
state_dict_split = split_state_dict_into_shards_factory(
dummy_state_dict,
Expand Down Expand Up @@ -170,6 +246,17 @@ def test_get_torch_storage_size():
assert get_torch_storage_size(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)) == 5 * 2


@requires("torch")
@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_get_torch_storage_size_wrapper_tensor_subclass():
import torch
from torch.testing._internal.two_tensor import TwoTensor
t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float64)
assert get_torch_storage_size(TwoTensor(t, t)) == 5 * 8 * 2
t = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float16)
assert get_torch_storage_size(TwoTensor(t, TwoTensor(t, t))) == 5 * 2 * 3


def test_parse_size_to_int():
assert parse_size_to_int("1KB") == 1 * 10**3
assert parse_size_to_int("2MB") == 2 * 10**6
Expand Down Expand Up @@ -247,6 +334,32 @@ def test_save_torch_state_dict_unsafe_not_sharded(
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(torch_state_dict_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


@pytest.mark.skipif(not is_wrapper_tensor_subclass_available(), reason="requires torch 2.1 or higher")
def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(torch_state_dict_shared_layers_tensor_subclass, tmp_path, max_shard_size="1GB", safe_serialization=False)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_unsafe_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"]
) -> None:
Expand Down
Loading