UT RE - enable for timer util (pytorch#506)
Summary:
Pull Request resolved: pytorch#506

# Context
[Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.]

# This diff
[List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.]

# What’s next
[If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.]

Differential Revision: D48475432

fbshipit-source-id: 7d4e4f0b33b85bc215534e142628c111ef0edb7d
galrotem authored and facebook-github-bot committed Aug 19, 2023
1 parent 95f6944 commit 590e13d
Showing 3 changed files with 24 additions and 48 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -4,3 +4,4 @@ pytest-cov
torchsnapshot-nightly
pyre-check
torchvision
expecttest
6 changes: 6 additions & 0 deletions tests/utils/test_env.py
@@ -83,6 +83,9 @@ def test_seed_range(self) -> None:
seed(42)

def test_deterministic_true(self) -> None:
# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
torch.backends.__allow_nonbracketed_mutation_flag = True

for det_debug_mode, det_debug_mode_str in [(1, "warn"), (2, "error")]:
warn_only = det_debug_mode == 1
for deterministic in (det_debug_mode, det_debug_mode_str):
@@ -99,6 +102,9 @@ def test_deterministic_true(self) -> None:
)

def test_deterministic_false(self) -> None:
# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
torch.backends.__allow_nonbracketed_mutation_flag = True

for deterministic in ("default", 0):
with self.subTest(deterministic=deterministic):
seed(42, deterministic=deterministic)
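Note on the cudnn-flags workaround added above: PyTorch refuses direct assignment to `torch.backends.cudnn` flags once the global backend flags have been frozen, which is assumed here to happen because the timer tests now import `torch.testing._internal` helpers that call `torch.backends.disable_global_flags()`. A minimal sketch of the behavior being worked around (illustrative only, not part of this diff):

```python
# Illustrative sketch, not part of this diff. Assumption: the new import of
# torch.testing._internal.common_distributed transitively calls
# torch.backends.disable_global_flags(), which freezes the backend flags.
import torch
import torch.backends
import torch.backends.cudnn

torch.backends.disable_global_flags()  # what the internal test utils are assumed to do
assert torch.backends.flags_frozen()

try:
    # Direct assignment to a cudnn flag is rejected once the flags are frozen.
    torch.backends.cudnn.benchmark = False
except RuntimeError as err:
    print(err)  # "not allowed to set torch.backends.cudnn flags ..."

# The supported, bracketed way to change the flags while they are frozen:
with torch.backends.cudnn.flags(deterministic=True, benchmark=False):
    pass  # flags apply inside the block and are restored on exit
```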
65 changes: 17 additions & 48 deletions tests/utils/test_timer.py
@@ -11,11 +11,11 @@
from datetime import timedelta
from random import random
from unittest import mock
from unittest.mock import patch

import torch
import torch.distributed as dist
import torch.distributed.launcher as launcher
from torchtnt.utils.test_utils import get_pet_launch_config
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
from torchtnt.utils.timer import (
FullSyncPeriodicTimer,
get_durations_histogram,
@@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None:
mock_info.assert_called_once()
self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

def test_timer_context_manager(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_context_manager(self, _) -> None:
"""Test the context manager in the timer class"""

# Generate 3 intervals between 0.5 and 2 seconds
@@ -81,7 +82,8 @@ def test_timer_context_manager(self) -> None:
@unittest.skipUnless(
condition=torch.cuda.is_available(), reason="This test needs a GPU host to run."
)
def test_timer_synchronize(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_synchronize(self, mock_synchronize) -> None:
"""Make sure that torch.cuda.synchronize() is called when GPU is present."""

start_event = torch.cuda.Event(enable_timing=True)
@@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None:

# torch.cuda.synchronize() has to be called to compute the elapsed time.
# Otherwise, there will be runtime error.
elapsed_time_ms = start_event.elapsed_time(end_event)
self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5)
self.assert_within_tolerance(
timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000
)
self.assertEqual(mock_synchronize.call_count, 2)

def test_get_timer_summary(self) -> None:
"""Test the get_timer_summary function"""
@@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None:

@staticmethod
def _get_synced_durations_histogram_multi_process() -> None:
dist.init_process_group("gloo")
rank = dist.get_rank()
if rank == 0:
recorded_durations = {
@@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None:
condition=dist.is_available(),
reason="This test should only run if torch.distributed is available.",
)
@spawn_threads_and_init_comms(world_size=2)
def test_get_synced_durations_histogram_multi_process(self) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(
config, entrypoint=self._get_synced_durations_histogram_multi_process
)()
self._get_synced_durations_histogram_multi_process()

def test_timer_fn(self) -> None:
with log_elapsed_time("test"):
@@ -232,58 +227,32 @@ def test_timer_fn(self) -> None:
class FullSyncPeriodicTimerTest(unittest.TestCase):
@classmethod
def _full_sync_worker_without_timeout(cls) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
return fsp_timer.check()

@classmethod
def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
time.sleep(timeout)
fsp_timer.check()  # assigns self._prev_work; the next check() call waits on it and evaluates the threshold
return fsp_timer.check()  # True once `timeout` has reached the 5-second interval threshold

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_check_false(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check if time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_without_timeout
)()
result = self._full_sync_worker_without_timeout()
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])

def test_full_sync_pt_multi_process_check_true(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(8)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])
self.assertFalse(result)

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_edgecase(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(5)

result = self._full_sync_worker_with_timeout(5)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])

# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(4)
self.assertTrue(result)

result = self._full_sync_worker_with_timeout(4)
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])
self.assertFalse(result)
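For the timer tests, the multi-process launches via `get_pet_launch_config` / `launcher.elastic_launch` are replaced by `torch.testing._internal.common_distributed.spawn_threads_and_init_comms`, and `torch.cuda.synchronize` is patched out so the synchronize test no longer needs a GPU host. A minimal sketch of both patterns (illustrative test, not part of this diff):

```python
# Illustrative sketch, not part of this diff: the threaded multi-rank test
# pattern this commit switches to, plus the torch.cuda.synchronize patch.
import unittest
from unittest import mock

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms


class ThreadedPatternsTest(unittest.TestCase):
    @spawn_threads_and_init_comms(world_size=2)
    def test_all_reduce_across_ranks(self) -> None:
        # The decorator runs this body once per rank on world_size threads,
        # each with an initialized default process group, so collectives work
        # without spawning worker processes or calling dist.init_process_group.
        value = torch.ones(1) * (dist.get_rank() + 1)
        dist.all_reduce(value)  # default op is SUM: 1 + 2
        self.assertEqual(value.item(), 3.0)

    @mock.patch("torch.cuda.synchronize")
    def test_synchronize_is_patched(self, mock_synchronize) -> None:
        # With synchronize patched out, timing code paths that call it can be
        # exercised on a CPU-only host; assertions shift to call counts.
        torch.cuda.synchronize()
        mock_synchronize.assert_called_once()


if __name__ == "__main__":
    unittest.main()
```

Running the ranks as threads inside the test process is presumably what lets these unit tests run in the environment referenced by the commit title, where spawning elastic worker processes is not available.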
