Split test_torchsnapshot_saver into two test files #548

Open · wants to merge 1 commit into base: master
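For context on the diff below: this change moves the GPU-only FSDP save/restore test (and the spawn_multi_process helper it needs) out of test_torchsnapshot_saver.py into a new test_torchsnapshot_saver_gpu.py. The remaining CPU tests launch their worker processes with get_pet_launch_config and launcher.elastic_launch, and the distributed skip guards now read an annotated class attribute instead of calling torch.distributed.is_available() inline, which removes the pyre-fixme[56] suppressions. A minimal sketch of that skip-guard pattern, mirroring the decorators in the diff (the test body is only a placeholder):

import unittest

import torch


class ExampleGPUTest(unittest.TestCase):
    # Annotated class attributes let Pyre infer the decorator argument types,
    # so no pyre-fixme[56] suppression is needed on the skip guards.
    cuda_available: bool = torch.cuda.is_available()
    distributed_available: bool = torch.distributed.is_available()

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=cuda_available, reason="This test needs a GPU host to run."
    )
    def test_gpu_only_path(self) -> None:
        # Placeholder body; the real test spawns multiple ranks and exercises
        # FSDP save/restore with TorchSnapshotSaver.
        self.assertTrue(torch.cuda.is_available())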
90 changes: 12 additions & 78 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -32,12 +32,11 @@
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank, PGWrapper
from torchtnt.utils.env import init_from_env, seed
from torchtnt.utils.test_utils import get_pet_launch_config, spawn_multi_process
from torchtnt.utils.test_utils import get_pet_launch_config


class TorchSnapshotSaverTest(unittest.TestCase):
# pyre-fixme[4]: Attribute must be annotated.
cuda_available = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

def test_save_every_n_train_steps(self) -> None:
input_dim = 2
@@ -302,17 +301,15 @@ def test_save_on_train_end(self) -> None:
)
self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
distributed_available, reason="Torch distributed is needed to run"
)
def test_directory_sync_collective(self) -> None:
spawn_multi_process(
2,
"gloo",
self._directory_sync_collective,
)
config = get_pet_launch_config(2)
launcher.elastic_launch(
config,
entrypoint=self._directory_sync_collective,
)()

@staticmethod
def _directory_sync_collective() -> None:
@@ -332,62 +329,6 @@ def _directory_sync_collective() -> None:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp,
)

@staticmethod
def _save_restore_fsdp() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
replicated=["**"],
)
temp_dir = snapshot_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
snapshot_cb.restore(ckpt_path, my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory

def test_saver_invalid_args(self) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
with self.assertRaisesRegex(
@@ -427,10 +368,8 @@ def test_latest_checkpoint_path(self) -> None:
os.mkdir(path_4)
self.assertEqual(get_latest_checkpoint_path(temp_dir), path_3)

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
distributed_available, reason="Torch distributed is needed to run"
)
def test_latest_checkpoint_path_distributed(self) -> None:
config = get_pet_launch_config(2)
@@ -474,17 +413,12 @@ def _latest_checkpoint_path_distributed() -> None:
if is_rank0:
shutil.rmtree(temp_dir) # delete temp directory

# pyre-fixme[56]: Pyre was not able to infer the type of argument
# `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
@unittest.skipUnless(
torch.distributed.is_available(), reason="Torch distributed is needed to run"
distributed_available, reason="Torch distributed is needed to run"
)
def test_save_restore_ddp(self) -> None:
spawn_multi_process(
2,
"gloo",
self._save_restore_ddp,
)
config = get_pet_launch_config(2)
launcher.elastic_launch(config, entrypoint=self._save_restore_ddp)()

@staticmethod
def _save_restore_ddp() -> None:
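The CPU-only distributed tests kept in this file use the elastic-launch pattern shown above. A minimal standalone sketch of that pattern, with import paths taken from the diff and a placeholder worker body:

from torch.distributed import launcher

from torchtnt.utils.test_utils import get_pet_launch_config


def _worker() -> None:
    # Each launched process runs this entrypoint; the real tests initialize a
    # "gloo" process group here and exercise the TorchSnapshotSaver collectives.
    pass


def run_two_process_test() -> None:
    # get_pet_launch_config(2) is assumed to build a local torchelastic launch
    # config for two workers; elastic_launch then runs _worker once per rank.
    config = get_pet_launch_config(2)
    launcher.elastic_launch(config, entrypoint=_worker)()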
78 changes: 78 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver_gpu.py
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import unittest

import torch

from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.test_utils import spawn_multi_process


class TorchSnapshotSaverGPUTest(unittest.TestCase):
cuda_available: bool = torch.cuda.is_available()
distributed_available: bool = torch.distributed.is_available()

@unittest.skipUnless(
condition=distributed_available, reason="Torch distributed is needed to run"
)
@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_save_restore_fsdp(self) -> None:
spawn_multi_process(
2,
"nccl",
self._save_restore_fsdp,
)

@staticmethod
def _save_restore_fsdp() -> None:
input_dim = 2
dataset_len = 10
batch_size = 2
max_epochs = 2
save_every_n_epochs = 1

my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
if get_global_rank() == 0:
temp_dir = tempfile.mkdtemp()
else:
temp_dir = ""

snapshot_cb = TorchSnapshotSaver(
temp_dir,
save_every_n_epochs=save_every_n_epochs,
replicated=["**"],
)
temp_dir = snapshot_cb.dirpath
train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

tc = unittest.TestCase()
try:
my_new_unit = DummyAutoUnit(
module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
)
tc.assertNotEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
# get latest checkpoint
ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
snapshot_cb.restore(ckpt_path, my_new_unit)
tc.assertEqual(
my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
)
finally:
if get_global_rank() == 0:
shutil.rmtree(temp_dir) # delete temp directory
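The new GPU file launches its ranks with spawn_multi_process instead, matching the call in test_save_restore_fsdp above. A minimal sketch of that usage; the (world_size, backend, fn) argument order is inferred from the diff, and the worker body is only a placeholder:

from torchtnt.utils.test_utils import spawn_multi_process


def _worker() -> None:
    # Each spawned rank runs this; the real test builds an FSDP-wrapped
    # DummyAutoUnit, trains with a TorchSnapshotSaver callback, and restores
    # the snapshot into a fresh unit.
    pass


if __name__ == "__main__":
    # Two ranks over the NCCL backend, as in the FSDP test above.
    spawn_multi_process(2, "nccl", _worker)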