diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py index 8505671201..8e6a103560 100644 --- a/tests/utils/test_timer.py +++ b/tests/utils/test_timer.py @@ -11,11 +11,11 @@ from datetime import timedelta from random import random from unittest import mock +from unittest.mock import patch import torch import torch.distributed as dist -import torch.distributed.launcher as launcher -from torchtnt.utils.test_utils import get_pet_launch_config +from torch.testing._internal.common_distributed import spawn_threads_and_init_comms from torchtnt.utils.timer import ( FullSyncPeriodicTimer, get_durations_histogram, @@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None: mock_info.assert_called_once() self.assertTrue("Testing timer took" in mock_info.call_args.args[0]) - def test_timer_context_manager(self) -> None: + @patch("torch.cuda.synchronize") + def test_timer_context_manager(self, _) -> None: """Test the context manager in the timer class""" # Generate 3 intervals between 0.5 and 2 seconds @@ -81,7 +82,8 @@ def test_timer_context_manager(self) -> None: @unittest.skipUnless( condition=torch.cuda.is_available(), reason="This test needs a GPU host to run." ) - def test_timer_synchronize(self) -> None: + @patch("torch.cuda.synchronize") + def test_timer_synchronize(self, mock_synchornize) -> None: """Make sure that torch.cuda.synchronize() is called when GPU is present.""" start_event = torch.cuda.Event(enable_timing=True) @@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None: # torch.cuda.synchronize() has to be called to compute the elapsed time. # Otherwise, there will be runtime error. - elapsed_time_ms = start_event.elapsed_time(end_event) - self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5) - self.assert_within_tolerance( - timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000 - ) + self.assertEqual(mock_synchornize.call_count, 2) def test_get_timer_summary(self) -> None: """Test the get_timer_summary function""" @@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None: @staticmethod def _get_synced_durations_histogram_multi_process() -> None: - dist.init_process_group("gloo") rank = dist.get_rank() if rank == 0: recorded_durations = { @@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None: condition=dist.is_available(), reason="This test should only run if torch.distributed is available.", ) + @spawn_threads_and_init_comms(world_size=2) def test_get_synced_durations_histogram_multi_process(self) -> None: - config = get_pet_launch_config(2) - launcher.elastic_launch( - config, entrypoint=self._get_synced_durations_histogram_multi_process - )() + self._get_synced_durations_histogram_multi_process() def test_timer_fn(self) -> None: with log_elapsed_time("test"): @@ -232,7 +227,6 @@ def test_timer_fn(self) -> None: class FullSyncPeriodicTimerTest(unittest.TestCase): @classmethod def _full_sync_worker_without_timeout(cls) -> bool: - dist.init_process_group("gloo") process_group = dist.group.WORLD interval_threshold = timedelta(seconds=5) fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group) @@ -240,7 +234,6 @@ def _full_sync_worker_without_timeout(cls) -> bool: @classmethod def _full_sync_worker_with_timeout(cls, timeout: int) -> bool: - dist.init_process_group("gloo") process_group = dist.group.WORLD interval_threshold = timedelta(seconds=5) fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group) @@ -248,42 +241,24 @@ def _full_sync_worker_with_timeout(cls, timeout: int) -> bool: fsp_timer.check() # self._prev_work is assigned, next time the check is called, it will be executed return fsp_timer.check() # Since 8>5, we will see flag set to True + @spawn_threads_and_init_comms(world_size=2) def test_full_sync_pt_multi_process_check_false(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check if time diff > interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_without_timeout - )() + result = self._full_sync_worker_without_timeout() # Both processes should return False - self.assertFalse(result[0]) - self.assertFalse(result[1]) + self.assertFalse(result) + @spawn_threads_and_init_comms(world_size=2) def test_full_sync_pt_multi_process_check_true(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check time diff > interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(8) + result = self._full_sync_worker_with_timeout(8) # Both processes should return True - self.assertTrue(result[0]) - self.assertTrue(result[1]) + self.assertTrue(result) + @spawn_threads_and_init_comms(world_size=2) def test_full_sync_pt_multi_process_edgecase(self) -> None: - config = get_pet_launch_config(2) - # Launch 2 worker processes. Each will check time diff >= interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(5) - + result = self._full_sync_worker_with_timeout(5) # Both processes should return True - self.assertTrue(result[0]) - self.assertTrue(result[1]) - - # Launch 2 worker processes. Each will check time diff >= interval threshold - result = launcher.elastic_launch( - config, entrypoint=self._full_sync_worker_with_timeout - )(4) + self.assertTrue(result) + result = self._full_sync_worker_with_timeout(4) # Both processes should return False - self.assertFalse(result[0]) - self.assertFalse(result[1]) + self.assertFalse(result)