diff --git a/dev-requirements.txt b/dev-requirements.txt index 202f51218b..60cc076eda 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -4,3 +4,4 @@ pytest-cov torchsnapshot-nightly pyre-check torchvision +expecttest diff --git a/tests/utils/test_env.py b/tests/utils/test_env.py index 144a4b9403..aae7d00ed2 100644 --- a/tests/utils/test_env.py +++ b/tests/utils/test_env.py @@ -6,6 +6,7 @@ # LICENSE file in the root directory of this source tree. import unittest +from unittest.mock import Mock, patch import numpy as np @@ -82,14 +83,14 @@ def test_seed_range(self) -> None: # should not raise any exceptions seed(42) - def test_deterministic_true(self) -> None: + @patch("torchtnt.utils.env._set_cudnn_determinstic_mode") + def test_deterministic_true(self, set_cudnn_determinstic_mode_mock: Mock) -> None: for det_debug_mode, det_debug_mode_str in [(1, "warn"), (2, "error")]: warn_only = det_debug_mode == 1 for deterministic in (det_debug_mode, det_debug_mode_str): with self.subTest(deterministic=deterministic): seed(42, deterministic=deterministic) - self.assertTrue(torch.backends.cudnn.deterministic) - self.assertFalse(torch.backends.cudnn.benchmark) + set_cudnn_determinstic_mode_mock.assert_called_with(True) self.assertEqual( det_debug_mode, torch.get_deterministic_debug_mode() ) @@ -98,12 +99,12 @@ def test_deterministic_true(self) -> None: warn_only, torch.is_deterministic_algorithms_warn_only_enabled() ) - def test_deterministic_false(self) -> None: + @patch("torchtnt.utils.env._set_cudnn_determinstic_mode") + def test_deterministic_false(self, set_cudnn_determinstic_mode_mock: Mock) -> None: for deterministic in ("default", 0): with self.subTest(deterministic=deterministic): seed(42, deterministic=deterministic) - self.assertFalse(torch.backends.cudnn.deterministic) - self.assertTrue(torch.backends.cudnn.benchmark) + set_cudnn_determinstic_mode_mock.assert_called_with(False) self.assertEqual(0, torch.get_deterministic_debug_mode()) self.assertFalse(torch.are_deterministic_algorithms_enabled()) self.assertFalse(torch.is_deterministic_algorithms_warn_only_enabled()) diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py index 8505671201..9f27973b98 100644 --- a/tests/utils/test_timer.py +++ b/tests/utils/test_timer.py @@ -11,11 +11,11 @@ from datetime import timedelta from random import random from unittest import mock +from unittest.mock import patch import torch import torch.distributed as dist -import torch.distributed.launcher as launcher -from torchtnt.utils.test_utils import get_pet_launch_config +from torch.testing._internal.common_distributed import spawn_threads_and_init_comms from torchtnt.utils.timer import ( FullSyncPeriodicTimer, get_durations_histogram, @@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None: mock_info.assert_called_once() self.assertTrue("Testing timer took" in mock_info.call_args.args[0]) - def test_timer_context_manager(self) -> None: + @patch("torch.cuda.synchronize") + def test_timer_context_manager(self, _) -> None: """Test the context manager in the timer class""" # Generate 3 intervals between 0.5 and 2 seconds @@ -81,7 +82,8 @@ def test_timer_context_manager(self) -> None: @unittest.skipUnless( condition=torch.cuda.is_available(), reason="This test needs a GPU host to run." ) - def test_timer_synchronize(self) -> None: + @patch("torch.cuda.synchronize") + def test_timer_synchronize(self, mock_synchornize) -> None: """Make sure that torch.cuda.synchronize() is called when GPU is present.""" start_event = torch.cuda.Event(enable_timing=True) @@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None: # torch.cuda.synchronize() has to be called to compute the elapsed time. # Otherwise, there will be runtime error. - elapsed_time_ms = start_event.elapsed_time(end_event) - self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5) - self.assert_within_tolerance( - timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000 - ) + self.assertEqual(mock_synchornize.call_count, 2) def test_get_timer_summary(self) -> None: """Test the get_timer_summary function""" @@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None: @staticmethod def _get_synced_durations_histogram_multi_process() -> None: - dist.init_process_group("gloo") rank = dist.get_rank() if rank == 0: recorded_durations = { @@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None: condition=dist.is_available(), reason="This test should only run if torch.distributed is available.", ) + @spawn_threads_and_init_comms(world_size=2) def test_get_synced_durations_histogram_multi_process(self) -> None: - config = get_pet_launch_config(2) - launcher.elastic_launch( - config, entrypoint=self._get_synced_durations_histogram_multi_process - )() + self._get_synced_durations_histogram_multi_process() def test_timer_fn(self) -> None: with log_elapsed_time("test"): @@ -232,7 +227,6 @@ def test_timer_fn(self) -> None: class FullSyncPeriodicTimerTest(unittest.TestCase): @classmethod def _full_sync_worker_without_timeout(cls) -> bool: - dist.init_process_group("gloo") process_group = dist.group.WORLD interval_threshold = timedelta(seconds=5) fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group) @@ -240,7 +234,6 @@ def _full_sync_worker_without_timeout(cls) -> bool: @classmethod def _full_sync_worker_with_timeout(cls, timeout: int) -> bool: - dist.init_process_group("gloo") process_group = dist.group.WORLD interval_threshold = timedelta(seconds=5) fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group) @@ -248,42 +241,18 @@ def _full_sync_worker_with_timeout(cls, timeout: int) -> bool: fsp_timer.check() # self._prev_work is assigned, next time the check is called, it will be executed return fsp_timer.check() # Since 8>5, we will see flag set to True + @spawn_threads_and_init_comms(world_size=2) def test_full_sync_pt_multi_process_check_false(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check if time diff > interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_without_timeout - )() + result = self._full_sync_worker_without_timeout() # Both processes should return False - self.assertFalse(result[0]) - self.assertFalse(result[1]) - - def test_full_sync_pt_multi_process_check_true(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check time diff > interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(8) - # Both processes should return True - self.assertTrue(result[0]) - self.assertTrue(result[1]) + self.assertFalse(result) + @spawn_threads_and_init_comms(world_size=2) def test_full_sync_pt_multi_process_edgecase(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check time diff >= interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(5) - + result = self._full_sync_worker_with_timeout(5) # Both processes should return True - self.assertTrue(result[0]) - self.assertTrue(result[1]) - - # Launch 2 worker processes. Each will check time diff >= interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(4) + self.assertTrue(result) + result = self._full_sync_worker_with_timeout(4) # Both processes should return False - self.assertFalse(result[0]) - self.assertFalse(result[1]) + self.assertFalse(result) diff --git a/torchtnt/utils/env.py b/torchtnt/utils/env.py index d8ec383a6c..8fe5e3c948 100644 --- a/torchtnt/utils/env.py +++ b/torchtnt/utils/env.py @@ -143,11 +143,12 @@ def seed(seed: int, deterministic: Optional[Union[str, int]] = None) -> None: _log.debug(f"Setting deterministic debug mode to {deterministic}") torch.set_deterministic_debug_mode(deterministic) deterministic_debug_mode = torch.get_deterministic_debug_mode() - if deterministic_debug_mode == 0: - _log.debug("Disabling cuDNN deterministic mode") - torch.backends.cudnn.deterministic = False - torch.backends.cudnn.benchmark = True - else: - _log.debug("Enabling cuDNN deterministic mode") - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False + _set_cudnn_determinstic_mode(deterministic_debug_mode != 0) + + +def _set_cudnn_determinstic_mode(is_determinstic: bool = True) -> None: + _log.debug( + f"{'Enabling' if is_determinstic else 'Disabling'} cuDNN deterministic mode" + ) + torch.backends.cudnn.deterministic = is_determinstic + torch.backends.cudnn.benchmark = not is_determinstic