Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryzh168 committed Aug 9, 2024
1 parent a9911b8 commit 40d9bda
Showing 1 changed file with 63 additions and 3 deletions.
66 changes: 63 additions & 3 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,22 @@ def dummy_state_dict() -> Dict[str, List[int]]:

@pytest.fixture
def torch_state_dict() -> Dict[str, "torch.Tensor"]:
try:
import torch

return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
}
except ImportError:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor
Expand All @@ -69,14 +85,34 @@ def torch_state_dict_shared_layers() -> Dict[str, "torch.Tensor"]:

shared_layer = torch.tensor([4])

t = torch.tensor([4])
shared_tensor_subclass_tensor = TwoTensor(t, t)

return {
"shared_1": shared_layer,
"unique_1": torch.tensor([10]),
"unique_2": torch.tensor([30]),
"shared_2": shared_layer,
}
except ImportError:
pytest.skip("torch is not available")


@pytest.fixture
def torch_state_dict_shared_layers_tensor_subclass() -> Dict[str, "torch.Tensor"]:
try:
import torch
from torch.testing._internal.two_tensor import TwoTensor

t = torch.tensor([4])
tensor_subclass_tensor = TwoTensor(t, t)

t = torch.tensor([4])
shared_tensor_subclass_tensor = TwoTensor(t, t)
return {
"layer_1": torch.tensor([4]),
"layer_2": torch.tensor([10]),
"layer_3": torch.tensor([30]),
"layer_4": torch.tensor([2]),
"layer_5": torch.tensor([2]),
"layer_6": tensor_subclass_tensor,
"ts_shared_1": shared_tensor_subclass_tensor,
"ts_shared_2": shared_tensor_subclass_tensor,
}
Expand Down Expand Up @@ -285,6 +321,30 @@ def test_save_torch_state_dict_unsafe_not_sharded(
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_tensor_subclass_unsafe_not_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB", safe_serialization=False)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_shared_layers_tensor_subclass_unsafe_not_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict_shared_layers_tensor_subclass: Dict[str, "torch.Tensor"]
) -> None:
"""Save as pickle without sharding."""
with caplog.at_level("WARNING"):
save_torch_state_dict(torch_state_dict, tmp_path, max_shard_size="1GB", safe_serialization=False)
assert "we strongly recommend using safe serialization" in caplog.text

assert (tmp_path / "pytorch_model.bin").is_file()
assert not (tmp_path / "pytorch_model.bin.index.json").is_file()


def test_save_torch_state_dict_unsafe_sharded(
tmp_path: Path, caplog: pytest.LogCaptureFixture, torch_state_dict: Dict[str, "torch.Tensor"]
) -> None:
Expand Down

0 comments on commit 40d9bda

Please sign in to comment.