From efab62028559aa5161cc44f40b230a9ee200d048 Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Fri, 18 Aug 2023 21:08:09 -0700
Subject: [PATCH] UT RE - enable for timer util (#506)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/506

# Context

The multi-process tests in tests/utils/test_timer.py spawn worker processes
through torch.distributed.launcher (elastic_launch with a PET launch config),
which keeps these unit tests from running in the remote execution (RE) test
environment.

# This diff

- Replaces process-based elastic_launch with the spawn_threads_and_init_comms
  decorator from torch.testing._internal.common_distributed, which runs each
  rank on a thread of the test process with a process group already
  initialized; the explicit dist.init_process_group("gloo") calls in the
  worker functions are dropped.
- Patches torch.cuda.synchronize in the timer tests. test_timer_synchronize
  now asserts on the mock's call count instead of comparing recorded durations
  against CUDA event timings, and the separate
  test_full_sync_pt_multi_process_check_true test is removed; the edge-case
  test still covers both the above-threshold and below-threshold outcomes.
- Adds expecttest to dev-requirements.txt; it is imported by
  torch.testing._internal.
- Sets torch.backends.__allow_nonbracketed_mutation_flag = True in
  tests/utils/test_env.py's setUp to work around "RuntimeError: not allowed to
  set torch.backends.cudnn flags".

Differential Revision: D48475432

fbshipit-source-id: 607d36d9596b964e59143374f8b92ae4ca9f9b8b
---
 dev-requirements.txt      |  1 +
 tests/utils/test_env.py   |  4 +++
 tests/utils/test_timer.py | 65 ++++++++++-----------------------------
 3 files changed, 22 insertions(+), 48 deletions(-)

diff --git a/dev-requirements.txt b/dev-requirements.txt
index 202f51218b..60cc076eda 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -4,3 +4,4 @@ pytest-cov
 torchsnapshot-nightly
 pyre-check
 torchvision
+expecttest
diff --git a/tests/utils/test_env.py b/tests/utils/test_env.py
index 144a4b9403..5b1378d5f8 100644
--- a/tests/utils/test_env.py
+++ b/tests/utils/test_env.py
@@ -17,6 +17,10 @@


 class EnvTest(unittest.TestCase):
+    def setUp(self) -> None:
+        # workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
+        torch.backends.__allow_nonbracketed_mutation_flag = True
+
     def test_init_from_env(self) -> None:
         """Integration test to confirm consistency across device initialization utilities."""
         device = init_from_env()
diff --git a/tests/utils/test_timer.py b/tests/utils/test_timer.py
index 8505671201..9f27973b98 100644
--- a/tests/utils/test_timer.py
+++ b/tests/utils/test_timer.py
@@ -11,11 +11,11 @@
 from datetime import timedelta
 from random import random
 from unittest import mock
+from unittest.mock import patch

 import torch
 import torch.distributed as dist
-import torch.distributed.launcher as launcher
-from torchtnt.utils.test_utils import get_pet_launch_config
+from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
 from torchtnt.utils.timer import (
     FullSyncPeriodicTimer,
     get_durations_histogram,
@@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None:
             mock_info.assert_called_once()
             self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

-    def test_timer_context_manager(self) -> None:
+    @patch("torch.cuda.synchronize")
+    def test_timer_context_manager(self, _) -> None:
         """Test the context manager in the timer class"""

         # Generate 3 intervals between 0.5 and 2 seconds
@@ -81,7 +82,8 @@
     @unittest.skipUnless(
         condition=torch.cuda.is_available(), reason="This test needs a GPU host to run."
     )
-    def test_timer_synchronize(self) -> None:
+    @patch("torch.cuda.synchronize")
+    def test_timer_synchronize(self, mock_synchronize) -> None:
         """Make sure that torch.cuda.synchronize() is called when GPU is present."""

         start_event = torch.cuda.Event(enable_timing=True)
@@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None:

         # torch.cuda.synchronize() has to be called to compute the elapsed time.
         # Otherwise, there will be runtime error.
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5)
-        self.assert_within_tolerance(
-            timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000
-        )
+        self.assertEqual(mock_synchronize.call_count, 2)

     def test_get_timer_summary(self) -> None:
         """Test the get_timer_summary function"""
@@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None:

     @staticmethod
     def _get_synced_durations_histogram_multi_process() -> None:
-        dist.init_process_group("gloo")
         rank = dist.get_rank()
         if rank == 0:
             recorded_durations = {
@@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None:
         condition=dist.is_available(),
         reason="This test should only run if torch.distributed is available.",
     )
+    @spawn_threads_and_init_comms(world_size=2)
     def test_get_synced_durations_histogram_multi_process(self) -> None:
-        config = get_pet_launch_config(2)
-        launcher.elastic_launch(
-            config, entrypoint=self._get_synced_durations_histogram_multi_process
-        )()
+        self._get_synced_durations_histogram_multi_process()

     def test_timer_fn(self) -> None:
         with log_elapsed_time("test"):
@@ -232,7 +227,6 @@ def test_timer_fn(self) -> None:
 class FullSyncPeriodicTimerTest(unittest.TestCase):
     @classmethod
     def _full_sync_worker_without_timeout(cls) -> bool:
-        dist.init_process_group("gloo")
         process_group = dist.group.WORLD
         interval_threshold = timedelta(seconds=5)
         fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
@@ -240,7 +234,6 @@ def _full_sync_worker_without_timeout(cls) -> bool:

     @classmethod
     def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
-        dist.init_process_group("gloo")
         process_group = dist.group.WORLD
         interval_threshold = timedelta(seconds=5)
         fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
@@ -248,42 +241,18 @@ def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
         fsp_timer.check()  # self._prev_work is assigned, next time the check is called, it will be executed
         return fsp_timer.check()  # Since 8>5, we will see flag set to True

+    @spawn_threads_and_init_comms(world_size=2)
     def test_full_sync_pt_multi_process_check_false(self) -> None:
-        config = get_pet_launch_config(2)
-        # Launch 2 worker processes. Each will check if time diff > interval threshold
-        result = launcher.elastic_launch(
-            config, entrypoint=self._full_sync_worker_without_timeout
-        )()
+        result = self._full_sync_worker_without_timeout()
         # Both processes should return False
-        self.assertFalse(result[0])
-        self.assertFalse(result[1])
-
-    def test_full_sync_pt_multi_process_check_true(self) -> None:
-        config = get_pet_launch_config(2)
-        # Launch 2 worker processes. Each will check time diff > interval threshold
-        result = launcher.elastic_launch(
-            config, entrypoint=self._full_sync_worker_with_timeout
-        )(8)
-        # Both processes should return True
-        self.assertTrue(result[0])
-        self.assertTrue(result[1])
+        self.assertFalse(result)

+    @spawn_threads_and_init_comms(world_size=2)
     def test_full_sync_pt_multi_process_edgecase(self) -> None:
-        config = get_pet_launch_config(2)
-        # Launch 2 worker processes. Each will check time diff >= interval threshold
-        result = launcher.elastic_launch(
-            config, entrypoint=self._full_sync_worker_with_timeout
-        )(5)
-
+        result = self._full_sync_worker_with_timeout(5)
         # Both processes should return True
-        self.assertTrue(result[0])
-        self.assertTrue(result[1])
-
-        # Launch 2 worker processes. Each will check time diff >= interval threshold
-        result = launcher.elastic_launch(
-            config, entrypoint=self._full_sync_worker_with_timeout
-        )(4)
+        self.assertTrue(result)
+        result = self._full_sync_worker_with_timeout(4)
         # Both processes should return False
-        self.assertFalse(result[0])
-        self.assertFalse(result[1])
+        self.assertFalse(result)
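
Below is a minimal, self-contained sketch (not part of the patch) of the
threaded-comms pattern the diff switches to. The class and test names are
hypothetical; the decorator and the torch.distributed calls mirror the ones
used above. spawn_threads_and_init_comms runs the test body once per rank on
threads of a single process, with the process group already initialized,
which is what lets the patch drop dist.init_process_group("gloo") and the
elastic launcher:

    import unittest

    import torch.distributed as dist
    from torch.testing._internal.common_distributed import (
        spawn_threads_and_init_comms,
    )


    class ThreadedCommsExample(unittest.TestCase):
        @spawn_threads_and_init_comms(world_size=2)
        def test_ranks_share_a_group(self) -> None:
            # The decorator spins up 2 threads, one per rank, and initializes
            # a threaded process group before the body runs -- no
            # dist.init_process_group and no process launcher needed.
            self.assertTrue(dist.is_initialized())
            self.assertEqual(dist.get_world_size(), 2)
            self.assertIn(dist.get_rank(), (0, 1))


    if __name__ == "__main__":
        unittest.main()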
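A similar sketch of the torch.cuda.synchronize mocking pattern used in
test_timer_synchronize. The helper _sync_twice is a hypothetical stand-in for
the timer's enter/exit synchronization; the point is that unittest.mock.patch
replaces torch.cuda.synchronize with a mock for the duration of the test, so
the call-count assertion passes even on a CPU-only host:

    import unittest
    from unittest.mock import patch

    import torch


    def _sync_twice() -> None:
        # Hypothetical stand-in for the timed region: one synchronize before
        # timing starts and one after it stops.
        torch.cuda.synchronize()
        torch.cuda.synchronize()


    class MockedCudaSyncExample(unittest.TestCase):
        @patch("torch.cuda.synchronize")
        def test_sync_call_count(self, mock_synchronize) -> None:
            # torch.cuda.synchronize is a MagicMock here, so no GPU is needed;
            # we only verify that synchronization was attempted twice.
            _sync_twice()
            self.assertEqual(mock_synchronize.call_count, 2)


    if __name__ == "__main__":
        unittest.main()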