UT RE - enable for timer util (#506)
Summary:
Pull Request resolved: #506

# Context
[Describe motivations and existing situation that led to creating this diff. Don't be cheap with context, it is the basis for a good code review.]

# This diff
[List all the changes that this diff introduces and explain the ones that are not trivial. Give directions for the reviewer if needed.]

# What’s next
[If this diff is part of a stack or if it has direct continuation in a future diff, share these plans with your reviewer.]

Differential Revision: D48475432

fbshipit-source-id: 072f66e215fc261871f5bcbf90f93120f50fa5e5
galrotem authored and facebook-github-bot committed Aug 21, 2023
1 parent 2bcc5ca commit c79d026
Showing 4 changed files with 34 additions and 62 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
@@ -4,3 +4,4 @@ pytest-cov
torchsnapshot-nightly
pyre-check
torchvision
expecttest
13 changes: 7 additions & 6 deletions tests/utils/test_env.py
@@ -6,6 +6,7 @@
# LICENSE file in the root directory of this source tree.

import unittest
from unittest.mock import Mock, patch

import numpy as np

@@ -82,14 +83,14 @@ def test_seed_range(self) -> None:
# should not raise any exceptions
seed(42)

def test_deterministic_true(self) -> None:
@patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
def test_deterministic_true(self, set_cudnn_determinstic_mode_mock: Mock) -> None:
for det_debug_mode, det_debug_mode_str in [(1, "warn"), (2, "error")]:
warn_only = det_debug_mode == 1
for deterministic in (det_debug_mode, det_debug_mode_str):
with self.subTest(deterministic=deterministic):
seed(42, deterministic=deterministic)
self.assertTrue(torch.backends.cudnn.deterministic)
self.assertFalse(torch.backends.cudnn.benchmark)
set_cudnn_determinstic_mode_mock.assert_called_with(True)
self.assertEqual(
det_debug_mode, torch.get_deterministic_debug_mode()
)
@@ -98,12 +99,12 @@ def test_deterministic_true(self) -> None:
warn_only, torch.is_deterministic_algorithms_warn_only_enabled()
)

def test_deterministic_false(self) -> None:
@patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
def test_deterministic_false(self, set_cudnn_determinstic_mode_mock: Mock) -> None:
for deterministic in ("default", 0):
with self.subTest(deterministic=deterministic):
seed(42, deterministic=deterministic)
self.assertFalse(torch.backends.cudnn.deterministic)
self.assertTrue(torch.backends.cudnn.benchmark)
set_cudnn_determinstic_mode_mock.assert_called_with(False)
self.assertEqual(0, torch.get_deterministic_debug_mode())
self.assertFalse(torch.are_deterministic_algorithms_enabled())
self.assertFalse(torch.is_deterministic_algorithms_warn_only_enabled())
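For reviewers unfamiliar with the pattern, here is a minimal, self-contained sketch of the mocking approach the updated env tests rely on: the cuDNN helper is patched at its module path so the unit test asserts on the delegation rather than on real cuDNN state. Only the patch target and the expected call follow from the diff above; the test class and method names below are illustrative and not part of this commit.

import unittest
from unittest.mock import Mock, patch

from torchtnt.utils.env import seed


class SeedCudnnDelegationExample(unittest.TestCase):
    # Illustrative test mirroring tests/utils/test_env.py above.
    @patch("torchtnt.utils.env._set_cudnn_determinstic_mode")
    def test_seed_delegates_to_cudnn_helper(self, mock_helper: Mock) -> None:
        # "warn" maps to deterministic debug mode 1, so the helper should be
        # asked to enable cuDNN determinism.
        seed(42, deterministic="warn")
        mock_helper.assert_called_with(True)


if __name__ == "__main__":
    unittest.main()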
65 changes: 17 additions & 48 deletions tests/utils/test_timer.py
@@ -11,11 +11,11 @@
from datetime import timedelta
from random import random
from unittest import mock
from unittest.mock import patch

import torch
import torch.distributed as dist
import torch.distributed.launcher as launcher
from torchtnt.utils.test_utils import get_pet_launch_config
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms
from torchtnt.utils.timer import (
FullSyncPeriodicTimer,
get_durations_histogram,
@@ -44,7 +44,8 @@ def test_timer_verbose(self) -> None:
mock_info.assert_called_once()
self.assertTrue("Testing timer took" in mock_info.call_args.args[0])

def test_timer_context_manager(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_context_manager(self, _) -> None:
"""Test the context manager in the timer class"""

# Generate 3 intervals between 0.5 and 2 seconds
@@ -81,7 +82,8 @@ def test_timer_context_manager(self) -> None:
@unittest.skipUnless(
condition=torch.cuda.is_available(), reason="This test needs a GPU host to run."
)
def test_timer_synchronize(self) -> None:
@patch("torch.cuda.synchronize")
def test_timer_synchronize(self, mock_synchornize) -> None:
"""Make sure that torch.cuda.synchronize() is called when GPU is present."""

start_event = torch.cuda.Event(enable_timing=True)
@@ -97,11 +99,7 @@ def test_timer_synchronize(self) -> None:

# torch.cuda.synchronize() has to be called to compute the elapsed time.
# Otherwise, there will be runtime error.
elapsed_time_ms = start_event.elapsed_time(end_event)
self.assert_within_tolerance(timer.recorded_durations["action_1"][0], 0.5)
self.assert_within_tolerance(
timer.recorded_durations["action_1"][0], elapsed_time_ms / 1000
)
self.assertEqual(mock_synchornize.call_count, 2)

def test_get_timer_summary(self) -> None:
"""Test the get_timer_summary function"""
@@ -166,7 +164,6 @@ def test_get_synced_durations_histogram(self) -> None:

@staticmethod
def _get_synced_durations_histogram_multi_process() -> None:
dist.init_process_group("gloo")
rank = dist.get_rank()
if rank == 0:
recorded_durations = {
@@ -218,11 +215,9 @@ def _get_synced_durations_histogram_multi_process() -> None:
condition=dist.is_available(),
reason="This test should only run if torch.distributed is available.",
)
@spawn_threads_and_init_comms(world_size=2)
def test_get_synced_durations_histogram_multi_process(self) -> None:
config = get_pet_launch_config(2)
launcher.elastic_launch(
config, entrypoint=self._get_synced_durations_histogram_multi_process
)()
self._get_synced_durations_histogram_multi_process()

def test_timer_fn(self) -> None:
with log_elapsed_time("test"):
@@ -232,58 +227,32 @@ def test_timer_fn(self) -> None:
class FullSyncPeriodicTimerTest(unittest.TestCase):
@classmethod
def _full_sync_worker_without_timeout(cls) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
return fsp_timer.check()

@classmethod
def _full_sync_worker_with_timeout(cls, timeout: int) -> bool:
dist.init_process_group("gloo")
process_group = dist.group.WORLD
interval_threshold = timedelta(seconds=5)
fsp_timer = FullSyncPeriodicTimer(interval_threshold, process_group)
time.sleep(timeout)
fsp_timer.check() # self._prev_work is assigned, next time the check is called, it will be executed
return fsp_timer.check() # Since 8>5, we will see flag set to True

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_check_false(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check if time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_without_timeout
)()
result = self._full_sync_worker_without_timeout()
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])

def test_full_sync_pt_multi_process_check_true(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff > interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(8)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])
self.assertFalse(result)

@spawn_threads_and_init_comms(world_size=2)
def test_full_sync_pt_multi_process_edgecase(self) -> None:
config = get_pet_launch_config(2)
# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(5)

result = self._full_sync_worker_with_timeout(5)
# Both processes should return True
self.assertTrue(result[0])
self.assertTrue(result[1])

# Launch 2 worker processes. Each will check time diff >= interval threshold
result = launcher.elastic_launch(
config, entrypoint=self._full_sync_worker_with_timeout
)(4)
self.assertTrue(result)

result = self._full_sync_worker_with_timeout(4)
# Both processes should return False
self.assertFalse(result[0])
self.assertFalse(result[1])
self.assertFalse(result)
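A note on the launcher change in the distributed tests above: replacing launcher.elastic_launch plus get_pet_launch_config with torch.testing._internal.common_distributed.spawn_threads_and_init_comms runs each rank as a thread inside the test process with a process group already initialized, so no subprocess workers (and no explicit dist.init_process_group("gloo")) are needed and each rank asserts on its own return value. Below is a minimal sketch of the pattern, assuming the threaded backend supports all_reduce (the timer tests above already rely on its collectives); the test name and body are illustrative only, not part of this commit.

import unittest

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import spawn_threads_and_init_comms


class ThreadedCommsExample(unittest.TestCase):
    # Illustrative test mirroring the decorator usage in tests/utils/test_timer.py above.
    @spawn_threads_and_init_comms(world_size=2)
    def test_two_ranks_share_a_process_group(self) -> None:
        # The decorator runs this body once per rank, on threads in one process.
        self.assertEqual(dist.get_world_size(), 2)
        t = torch.tensor([float(dist.get_rank() + 1)])
        dist.all_reduce(t)  # defaults to SUM across both ranks
        self.assertEqual(t.item(), 3.0)  # 1.0 + 2.0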
17 changes: 9 additions & 8 deletions torchtnt/utils/env.py
@@ -143,11 +143,12 @@ def seed(seed: int, deterministic: Optional[Union[str, int]] = None) -> None:
_log.debug(f"Setting deterministic debug mode to {deterministic}")
torch.set_deterministic_debug_mode(deterministic)
deterministic_debug_mode = torch.get_deterministic_debug_mode()
if deterministic_debug_mode == 0:
_log.debug("Disabling cuDNN deterministic mode")
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
else:
_log.debug("Enabling cuDNN deterministic mode")
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
_set_cudnn_determinstic_mode(deterministic_debug_mode != 0)


def _set_cudnn_determinstic_mode(is_determinstic: bool = True) -> None:
_log.debug(
f"{'Enabling' if is_determinstic else 'Disabling'} cuDNN deterministic mode"
)
torch.backends.cudnn.deterministic = is_determinstic
torch.backends.cudnn.benchmark = not is_determinstic

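For context on why the extraction above enables the mock-based tests: seed() now routes all cuDNN flag toggling through a single helper, so unit tests can patch that seam instead of inspecting global torch.backends state, while runtime behavior is unchanged. A small usage sketch of the preserved behavior follows (illustrative only; seed and the flag semantics come from the code above, and the assertions assume nothing else touches the cuDNN flags in between).

import torch

from torchtnt.utils.env import seed

# "error" maps to deterministic debug mode 2, so cuDNN determinism is enabled
# and benchmarking is disabled.
seed(42, deterministic="error")
assert torch.backends.cudnn.deterministic
assert not torch.backends.cudnn.benchmark

# "default" maps to mode 0, so the flags are flipped back.
seed(42, deterministic="default")
assert not torch.backends.cudnn.deterministic
assert torch.backends.cudnn.benchmark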