Commit 8554fae

split test_torchsnapshot_saver to two test files
Summary: Extract the GPU-requiring test case from `test_torchsnapshot_saver` into `test_torchsnapshot_saver_gpu`.

Differential Revision: D49482055
galrotem authored and facebook-github-bot committed Sep 21, 2023
1 parent 0b926c3 commit 8554fae
Showing 2 changed files with 90 additions and 78 deletions.
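
The diff below touches two ways of launching multi-process tests. As a minimal, self-contained sketch of both patterns, with signatures taken from their usage in the diff (`_entrypoint` is illustrative; `spawn_multi_process` and `get_pet_launch_config` come from `torchtnt.utils.test_utils`, and `launcher` is `torch.distributed.launcher`):

import torch.distributed as dist
from torch.distributed import launcher

from torchtnt.utils.test_utils import get_pet_launch_config, spawn_multi_process


def _entrypoint() -> None:
    # Illustrative per-process body; the real tests init a process group,
    # exercise the callback, and assert on the result.
    assert dist.is_available()


if __name__ == "__main__":
    # Style used by the extracted GPU test file: spawn processes directly,
    # passing world size, backend, and the callable.
    spawn_multi_process(2, "gloo", _entrypoint)

    # Style the CPU tests use after this diff: torchelastic with a pet
    # launch config; note the trailing () that invokes the launched agent.
    config = get_pet_launch_config(2)
    launcher.elastic_launch(config, entrypoint=_entrypoint)()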
tests/framework/callbacks/test_torchsnapshot_saver.py (12 additions, 78 deletions)
@@ -32,12 +32,11 @@
 from torchtnt.framework.train import train
 from torchtnt.utils.distributed import get_global_rank, PGWrapper
 from torchtnt.utils.env import init_from_env, seed
-from torchtnt.utils.test_utils import get_pet_launch_config, spawn_multi_process
+from torchtnt.utils.test_utils import get_pet_launch_config
 
 
 class TorchSnapshotSaverTest(unittest.TestCase):
-    # pyre-fixme[4]: Attribute must be annotated.
-    cuda_available = torch.cuda.is_available()
+    distributed_available: bool = torch.distributed.is_available()
 
     def test_save_every_n_train_steps(self) -> None:
         input_dim = 2
@@ -302,17 +301,15 @@ def test_save_on_train_end(self) -> None:
         )
         self.assertTrue(os.path.exists(os.path.join(temp_dir, expected_path)))
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
     @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+        distributed_available, reason="Torch distributed is needed to run"
     )
     def test_directory_sync_collective(self) -> None:
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._directory_sync_collective,
-        )
+        config = get_pet_launch_config(2)
+        launcher.elastic_launch(
+            config,
+            entrypoint=self._directory_sync_collective,
+        )()
 
     @staticmethod
     def _directory_sync_collective() -> None:
@@ -332,62 +329,6 @@ def _directory_sync_collective() -> None:
         if get_global_rank() == 0:
             shutil.rmtree(temp_dir)  # delete temp directory
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
-    @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
-    )
-    @unittest.skipUnless(
-        condition=cuda_available, reason="This test needs a GPU host to run."
-    )
-    def test_save_restore_fsdp(self) -> None:
-        spawn_multi_process(
-            2,
-            "nccl",
-            self._save_restore_fsdp,
-        )
-
-    @staticmethod
-    def _save_restore_fsdp() -> None:
-        input_dim = 2
-        dataset_len = 10
-        batch_size = 2
-        max_epochs = 2
-        save_every_n_epochs = 1
-
-        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
-        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        if get_global_rank() == 0:
-            temp_dir = tempfile.mkdtemp()
-        else:
-            temp_dir = ""
-
-        snapshot_cb = TorchSnapshotSaver(
-            temp_dir,
-            save_every_n_epochs=save_every_n_epochs,
-            replicated=["**"],
-        )
-        temp_dir = snapshot_cb.dirpath
-        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])
-
-        tc = unittest.TestCase()
-        try:
-            my_new_unit = DummyAutoUnit(
-                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
-            )
-            tc.assertNotEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-            # get latest checkpoint
-            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
-            snapshot_cb.restore(ckpt_path, my_new_unit)
-            tc.assertEqual(
-                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
-            )
-        finally:
-            if get_global_rank() == 0:
-                shutil.rmtree(temp_dir)  # delete temp directory
-
     def test_saver_invalid_args(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             with self.assertRaisesRegex(
@@ -427,10 +368,8 @@ def test_latest_checkpoint_path(self) -> None:
         os.mkdir(path_4)
         self.assertEqual(get_latest_checkpoint_path(temp_dir), path_3)
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
     @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+        distributed_available, reason="Torch distributed is needed to run"
     )
     def test_latest_checkpoint_path_distributed(self) -> None:
         config = get_pet_launch_config(2)
@@ -474,17 +413,12 @@ def _latest_checkpoint_path_distributed() -> None:
         if is_rank0:
             shutil.rmtree(temp_dir)  # delete temp directory
 
-    # pyre-fixme[56]: Pyre was not able to infer the type of argument
-    # `torch.distributed.is_available()` to decorator factory `unittest.skipUnless`.
     @unittest.skipUnless(
-        torch.distributed.is_available(), reason="Torch distributed is needed to run"
+        distributed_available, reason="Torch distributed is needed to run"
    )
     def test_save_restore_ddp(self) -> None:
-        spawn_multi_process(
-            2,
-            "gloo",
-            self._save_restore_ddp,
-        )
+        config = get_pet_launch_config(2)
+        launcher.elastic_launch(config, entrypoint=self._save_restore_ddp)()
 
     @staticmethod
     def _save_restore_ddp() -> None:
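
One side effect is visible in every hunk above: the `# pyre-fixme[56]` suppressions disappear. Storing the `torch.distributed.is_available()` result in an annotated class attribute gives Pyre a typed value at the decorator site, so no per-test suppression is needed. A minimal sketch of the pattern (class and test names here are illustrative):

import unittest

import torch.distributed as dist


class DistributedOnlyTest(unittest.TestCase):
    # Annotating the attribute as bool lets Pyre type the skipUnless
    # argument, making pyre-fixme[56] suppressions unnecessary.
    distributed_available: bool = dist.is_available()

    @unittest.skipUnless(
        distributed_available, reason="Torch distributed is needed to run"
    )
    def test_needs_distributed(self) -> None:
        self.assertTrue(dist.is_available())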
tests/framework/callbacks/test_torchsnapshot_saver_gpu.py (new file, 78 additions)

@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile
import unittest

import torch

from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
from torchtnt.framework.callbacks.torchsnapshot_saver import TorchSnapshotSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.test_utils import spawn_multi_process


class TorchSnapshotSaverGPUTest(unittest.TestCase):
    cuda_available: bool = torch.cuda.is_available()
    distributed_available: bool = torch.distributed.is_available()

    @unittest.skipUnless(
        condition=distributed_available, reason="Torch distributed is needed to run"
    )
    @unittest.skipUnless(
        condition=cuda_available, reason="This test needs a GPU host to run."
    )
    def test_save_restore_fsdp(self) -> None:
        spawn_multi_process(
            2,
            "nccl",
            self._save_restore_fsdp,
        )

    @staticmethod
    def _save_restore_fsdp() -> None:
        input_dim = 2
        dataset_len = 10
        batch_size = 2
        max_epochs = 2
        save_every_n_epochs = 1

        my_unit = DummyAutoUnit(module=torch.nn.Linear(input_dim, 2), strategy="fsdp")
        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
        if get_global_rank() == 0:
            temp_dir = tempfile.mkdtemp()
        else:
            temp_dir = ""

        snapshot_cb = TorchSnapshotSaver(
            temp_dir,
            save_every_n_epochs=save_every_n_epochs,
            replicated=["**"],
        )
        temp_dir = snapshot_cb.dirpath
        train(my_unit, dataloader, max_epochs=max_epochs, callbacks=[snapshot_cb])

        tc = unittest.TestCase()
        try:
            my_new_unit = DummyAutoUnit(
                module=torch.nn.Linear(input_dim, 2), strategy="fsdp"
            )
            tc.assertNotEqual(
                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
            )
            # get latest checkpoint
            ckpt_path = os.path.join(temp_dir, f"epoch_{max_epochs}_step_10")
            snapshot_cb.restore(ckpt_path, my_new_unit)
            tc.assertEqual(
                my_new_unit.optimizer.state_dict(), my_unit.optimizer.state_dict()
            )
        finally:
            if get_global_rank() == 0:
                shutil.rmtree(temp_dir)  # delete temp directory
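
Two details of `_save_restore_fsdp` carry over unchanged from the original file: only rank 0 creates the temporary directory (the other rank passes an empty string), and every rank then reads `snapshot_cb.dirpath` back, which relies on `TorchSnapshotSaver` syncing the directory across processes; the `test_directory_sync_collective` case that stays behind in the CPU file exercises that sync. Assuming the repository layout shown in the diff, the GPU suite should be runnable on a multi-GPU host with the standard unittest entry point, e.g. `python -m unittest tests.framework.callbacks.test_torchsnapshot_saver_gpu`.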
