UT RE - enable for timer util #506

Open · wants to merge 1 commit into master
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -4,3 +4,4 @@ pytest-cov
torchsnapshot-nightly
pyre-check
torchvision
expecttest
13 changes: 7 additions & 6 deletions tests/utils/test_env.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import Mock, patch

import numpy as np

@@ -82,14 +83,14 @@ def test_seed_range(self) -> None:
# should not raise any exceptions
seed(42)

def test_deterministic_true(self) -> None:
@patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
def test_deterministic_true(self, set_cudnn_determinstic_mode_mock: Mock) -> None:
for det_debug_mode, det_debug_mode_str in [(1, "warn"), (2, "error")]:
warn_only = det_debug_mode == 1
for deterministic in (det_debug_mode, det_debug_mode_str):
with self.subTest(deterministic=deterministic):
seed(42, deterministic=deterministic)
self.assertTrue(torch.backends.cudnn.deterministic)
self.assertFalse(torch.backends.cudnn.benchmark)
set_cudnn_determinstic_mode_mock.assert_called_with(True)
self.assertEqual(
det_debug_mode, torch.get_deterministic_debug_mode()
)
@@ -98,12 +99,12 @@ def test_deterministic_true(self) -> None:
warn_only, torch.is_deterministic_algorithms_warn_only_enabled()
)

def test_deterministic_false(self) -> None:
@patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
def test_deterministic_false(self, set_cudnn_determinstic_mode_mock: Mock) -> None:
for deterministic in ("default", 0):
with self.subTest(deterministic=deterministic):
seed(42, deterministic=deterministic)
self.assertFalse(torch.backends.cudnn.deterministic)
self.assertTrue(torch.backends.cudnn.benchmark)
set_cudnn_determinstic_mode_mock.assert_called_with(False)
self.assertEqual(0, torch.get_deterministic_debug_mode())
self.assertFalse(torch.are_deterministic_algorithms_enabled())
self.assertFalse(torch.is_deterministic_algorithms_warn_only_enabled())
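The test_env.py changes above stop asserting on torch.backends.cudnn directly and instead patch the new _set_cudnn_determinstic_mode helper. A minimal, hypothetical sketch of that pattern (the class and test names below are invented for illustration, and it assumes torchtnt is importable in the test environment):

import unittest
from unittest.mock import Mock, patch

from torchtnt.utils.env import seed


class CudnnHelperPatchSketch(unittest.TestCase):
    # Illustrative only, not part of the PR. It mirrors how the updated
    # test_deterministic_true/false verify that seed() forwards the
    # deterministic setting to the patched helper.
    @patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
    def test_seed_routes_through_helper(self, set_mode_mock: Mock) -> None:
        seed(42, deterministic="warn")  # non-zero debug mode, expect True
        set_mode_mock.assert_called_with(True)

        seed(42, deterministic=0)  # debug mode 0, expect False
        set_mode_mock.assert_called_with(False)


if __name__ == "__main__":
    unittest.main()

With the helper patched, the assertion checks the exact boolean passed in rather than the live cuDNN backend state.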
65 changes: 17 additions & 48 deletions tests/utils/test_timer.py
@@ -11,11 +11,11 @@
from datetime import timedelta
from random import random
from unittest import mock
from unittest.mock import patch

import torch
import torch.distributed as dist
import torch.distributed.launcher as launcher
from torchtnt.utils.test_utils import get_pet_launch_config
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
from torchtnt.utils.timer import (
FullSyncPeriodicTimer,
get_durations_histogram,
@@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None:
mock_info.assert_called_once()
self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

def test_timer_context_manager(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_context_manager(self, _) -> None:
"""Test the context manager in the timer class"""

# Generate 3 intervals between 0.5 and 2 seconds
@@ -81,7 +82,8 @@ def test_timer_context_manager(self) -> None:
@unittest.skipUnless(
condition=torch.cuda.is_available(), reason="This test needs a GPU host to run."
)
def test_timer_synchronize(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_synchronize(self, mock_synchornize) -> None:
"""Make sure that torch.cuda.synchronize() is called when GPU is present."""

start_event = torch.cuda.Event(enable_timing=True)
@@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None:

# torch.cuda.synchronize() has to be called to compute the elapsed time.
# Otherwise, there will be runtime error.
elapsed_time_ms = start_event.elapsed_time(end_event)
self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5)
self.assert_within_tolerance(
timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000
)
self.assertEqual(mock_synchornize.call_count, 2)

def test_get_timer_summary(self) -> None:
"""Test the get_timer_summary function"""
@@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None:

@staticmethod
def _get_synced_durations_histogram_multi_process() -> None:
dist.init_process_group("gloo")
rank = dist.get_rank()
if rank == 0:
recorded_durations = {
@@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None:
condition=dist.is_available(),
reason="This test should only run if torch.distributed is available.",
)
@spawn_threads_and_init_comms(world_size=2)
def test_get_synced_durations_histogram_multi_process(self) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(
config, entrypoint=self._get_synced_durations_histogram_multi_process
)()
self._get_synced_durations_histogram_multi_process()

def test_timer_fn(self) -> None:
with log_elapsed_time("test"):
@@ -232,58 +227,32 @@ def test_timer_fn(self) -> None:
class FullSyncPeriodicTimerTest(unittest.TestCase):
@classmethod
def _full_sync_worker_without_timeout(cls) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
return fsp_timer.check()

@classmethod
def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
time.sleep(timeout)
fsp_timer.check() # self._prev_work is assigned, next time the check is called, it will be executed
return fsp_timer.check() # Since 8>5, we will see flag set to True

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_check_false(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check if time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_without_timeout
)()
result = self._full_sync_worker_without_timeout()
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])

def test_full_sync_pt_multi_process_check_true(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(8)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])
self.assertFalse(result)

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_edgecase(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(5)

result = self._full_sync_worker_with_timeout(5)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])

# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(4)
self.assertTrue(result)

result = self._full_sync_worker_with_timeout(4)
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])
self.assertFalse(result)
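The timer tests above replace get_pet_launch_config/elastic_launch with torch's internal spawn_threads_and_init_comms decorator, so the distributed helpers now run as ranks on threads inside the test process and no longer call dist.init_process_group themselves. A rough, hypothetical sketch of how the decorator is used (torch.testing._internal is not a public API, so treat the details as an assumption; the class and test names are invented here):

import unittest

import torch.distributed as dist
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms


class ThreadedCommsSketch(unittest.TestCase):
    # Illustrative only, not part of the PR. The decorated body runs once per
    # rank, each rank on its own thread of this process, with a process group
    # already initialized, so plain dist.* calls work without a launcher.
    @spawn_threads_and_init_comms(world_size=2)
    def test_runs_with_two_threaded_ranks(self) -> None:
        self.assertEqual(dist.get_world_size(), 2)
        self.assertIn(dist.get_rank(), (0, 1))


if __name__ == "__main__":
    unittest.main()

This is also why the worker helpers above drop their dist.init_process_group("gloo") calls, and why the updated assertions check a single result value instead of indexing result[0] and result[1] from separately launched processes.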
17 changes: 9 additions & 8 deletions torchtnt/utils/env.py
@@ -143,11 +143,12 @@ def seed(seed: int, deterministic: Optional[Union[str, int]] = None) -> None:
_log.debug(f"Setting deterministic debug mode to {deterministic}")
torch.set_deterministic_debug_mode(deterministic)
deterministic_debug_mode = torch.get_deterministic_debug_mode()
if deterministic_debug_mode == 0:
_log.debug("Disabling cuDNN deterministic mode")
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
else:
_log.debug("Enabling cuDNN deterministic mode")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
_set_cudnn_determinstic_mode(deterministic_debug_mode != 0)


def _set_cudnn_determinstic_mode(is_determinstic: bool = True) -> None:
_log.debug(
f"{'Enabling' if is_determinstic else 'Disabling'} cuDNN deterministic mode"
)
torch.backends.cudnn.deterministic = is_determinstic
torch.backends.cudnn.benchmark = not is_determinstic
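The refactor collapses the old if/else into a single helper that keeps the two cuDNN flags paired, which is the seam the updated env tests patch. A small usage sketch, importing the private helper purely for illustration:

import torch

from torchtnt.utils.env import _set_cudnn_determinstic_mode

# Mirrors the deterministic_debug_mode != 0 branch of seed():
_set_cudnn_determinstic_mode(True)
assert torch.backends.cudnn.deterministic
assert not torch.backends.cudnn.benchmark

# Mirrors the deterministic_debug_mode == 0 branch of seed():
_set_cudnn_determinstic_mode(False)
assert not torch.backends.cudnn.deterministic
assert torch.backends.cudnn.benchmark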