diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index c734f20b05..2601a7c509 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -5,7 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import math
 import unittest
 from typing import Any, Tuple
 from unittest.mock import MagicMock, patch
@@ -19,7 +18,6 @@
     DYNAMO_AVAIL = True
     import torch._dynamo
 
-from parameterized import parameterized
 from torch.distributed import GradBucket, launcher
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -168,44 +166,6 @@ def test_mixed_precision_invalid_str(self) -> None:
             precision="foo",
         )
 
-    @parameterized.expand(
-        [
-            [1],
-            [2],
-            [4],
-            [5],
-        ]
-    )
-    def test_num_optimizer_steps_completed(self, gradient_accumulation_steps) -> None:
-        """
-        Test the num_optimizer_steps_completed property of AutoUnit
-        """
-        my_module = torch.nn.Linear(2, 2)
-
-        input_dim = 2
-        dataset_len = 16
-        batch_size = 2
-        max_epochs = 1
-
-        auto_unit = DummyAutoUnit(
-            module=my_module,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-        )
-
-        expected_opt_steps_per_epoch = math.ceil(
-            dataset_len / batch_size / gradient_accumulation_steps
-        )
-
-        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
-        state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)
-        train(state, auto_unit)
-        self.assertEqual(
-            auto_unit.num_optimizer_steps_completed, expected_opt_steps_per_epoch
-        )
-        self.assertIn(
-            "_num_optimizer_steps_completed", auto_unit.tracked_misc_statefuls()
-        )
-
     def test_stochastic_weight_averaging_basic(self) -> None:
         """
         Basic stochastic weight averaging tests
diff --git a/tests/framework/test_utils.py b/tests/framework/test_utils.py
index 22ddd0e49e..911289b316 100644
--- a/tests/framework/test_utils.py
+++ b/tests/framework/test_utils.py
@@ -48,7 +48,6 @@
     _set_module_training_mode,
     _step_requires_iterator,
     get_current_progress,
-    StatefulInt,
 )
 from torchtnt.utils.env import init_from_env
 from torchtnt.utils.lr_scheduler import TLRScheduler
@@ -422,15 +421,6 @@ def _construct_optimizers() -> None:
         tc.assertTrue(isinstance(result["optim2"], torch.optim.Optimizer))
         tc.assertTrue(isinstance(result["lr_scheduler"], TLRScheduler))
 
-    def test_stateful_int(self) -> None:
-        v = StatefulInt(0)
-        v += 10
-        v -= 2
-        self.assertEqual(v.val, 8)
-        self.assertEqual(v.state_dict(), {"value": 8})
-        v.load_state_dict({"value": -4})
-        self.assertEqual(v.val, -4)
-
 
 
 Batch = Tuple[torch.tensor, torch.tensor]
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index ecddbef353..6ae5875579 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -42,7 +42,6 @@
     _get_timing_context,
     _is_fsdp_module,
     get_current_progress,
-    StatefulInt,
 )
 from torchtnt.utils import (
     init_from_env,
@@ -507,7 +506,6 @@ def __init__(
         )
 
         self.gradient_accumulation_steps = gradient_accumulation_steps
-        self._num_optimizer_steps_completed: StatefulInt = StatefulInt(0)
 
         self.detect_anomaly = detect_anomaly
         self.clip_grad_norm = clip_grad_norm
@@ -758,8 +756,6 @@ def train_step(
         else:
             self.optimizer.step()
 
-        self._num_optimizer_steps_completed += 1
-
         # sets gradients to zero
         with _get_timing_context(
             state, f"{self.__class__.__name__}.optimizer_zero_grad"
@@ -878,10 +874,6 @@ def on_eval_step_end(
         """
         pass
 
-    @property
-    def num_optimizer_steps_completed(self) -> int:
-        return self._num_optimizer_steps_completed.val
-
 
 def _validate_torchdynamo_available() -> None:
     if not is_torch_version_ge_1_13_1():
diff --git a/torchtnt/framework/utils.py b/torchtnt/framework/utils.py
index 3e1f5eac6d..913451c221 100644
--- a/torchtnt/framework/utils.py
+++ b/torchtnt/framework/utils.py
@@ -32,7 +32,6 @@
 from torchtnt.framework.unit import AppStateMixin
 from torchtnt.utils.lr_scheduler import TLRScheduler
 from torchtnt.utils.progress import Progress
-from typing_extensions import Self
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -262,30 +261,3 @@ def get_current_progress(state: State) -> Progress:
         return none_throws(state.eval_state).progress
     else:
         return none_throws(state.predict_state).progress
-
-
-class StatefulInt:
-    """
-    This wrapper is useful if there are additional values related to training
-    progress that need to be saved during checkpointing.
-    """
-
-    def __init__(self, val: int) -> None:
-        self.val = val
-
-    def state_dict(self) -> Dict[str, Any]:
-        return {"value": self.val}
-
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        self.val = state_dict["value"]
-
-    def __add__(self, other: int) -> Self:
-        self.val += other
-        return self
-
-    def __sub__(self, other: int) -> Self:
-        self.val -= other
-        return self
-
-    def __repr__(self) -> str:
-        return f"StatefulInt({self.val})"
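Note: after this change AutoUnit no longer exposes `num_optimizer_steps_completed`. Callers that relied on it can derive an equivalent count from their own step bookkeeping. The sketch below is a hypothetical helper (not part of torchtnt) that mirrors the arithmetic the removed test used, `math.ceil(dataset_len / batch_size / gradient_accumulation_steps)`, assuming the optimizer also steps on a final, partially filled accumulation window within an epoch.

```python
import math


def optimizer_steps_completed(train_steps_completed: int, gradient_accumulation_steps: int) -> int:
    """Hypothetical helper: number of optimizer.step() calls implied by
    `train_steps_completed` train steps within a single epoch, assuming a final
    partial accumulation window still triggers an optimizer step."""
    return math.ceil(train_steps_completed / gradient_accumulation_steps)


# Example matching the removed test's setup: 16 samples, batch size 2 -> 8 train steps;
# with 4-step accumulation that is 2 optimizer steps per epoch.
assert optimizer_steps_completed(16 // 2, 4) == 2
```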