Remove tracking num_optimizer_steps_completed (#437)
Summary:
Pull Request resolved: #437

The num_optimizer_steps_completed counter isn't used anywhere now, so remove it along with the supporting StatefulInt utility.

Reviewed By: daniellepintz

Differential Revision: D46612359

fbshipit-source-id: 97ad1ebd2425fd3b0124972837184cd977c23f23
ananthsub authored and facebook-github-bot committed Jun 30, 2023
1 parent 7219d15 commit 1297517
Showing 4 changed files with 0 additions and 86 deletions.
40 changes: 0 additions & 40 deletions tests/framework/test_auto_unit.py
@@ -5,7 +5,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest
from typing import Any, Tuple
from unittest.mock import MagicMock, patch
@@ -19,7 +18,6 @@
DYNAMO_AVAIL = True
import torch._dynamo

from parameterized import parameterized
from torch.distributed import GradBucket, launcher
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -168,44 +166,6 @@ def test_mixed_precision_invalid_str(self) -> None:
precision="foo",
)

@parameterized.expand(
[
[1],
[2],
[4],
[5],
]
)
def test_num_optimizer_steps_completed(self, gradient_accumulation_steps) -> None:
"""
Test the num_optimizer_steps_completed property of AutoUnit
"""
my_module = torch.nn.Linear(2, 2)

input_dim = 2
dataset_len = 16
batch_size = 2
max_epochs = 1

auto_unit = DummyAutoUnit(
module=my_module,
gradient_accumulation_steps=gradient_accumulation_steps,
)

expected_opt_steps_per_epoch = math.ceil(
dataset_len / batch_size / gradient_accumulation_steps
)

train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)
train(state, auto_unit)
self.assertEqual(
auto_unit.num_optimizer_steps_completed, expected_opt_steps_per_epoch
)
self.assertIn(
"_num_optimizer_steps_completed", auto_unit.tracked_misc_statefuls()
)

def test_stochastic_weight_averaging_basic(self) -> None:
"""
Basic stochastic weight averaging tests
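For reference, the arithmetic the deleted test checked is easy to verify by hand: with dataset_len=16 and batch_size=2 there are 8 batches per epoch, and the optimizer steps once every gradient_accumulation_steps batches, with the final partial group rounded up. A standalone sketch (not part of the repository) reproducing the expected counts:

import math

dataset_len, batch_size = 16, 2
for gradient_accumulation_steps in (1, 2, 4, 5):
    # same formula the deleted test used
    expected = math.ceil(dataset_len / batch_size / gradient_accumulation_steps)
    print(gradient_accumulation_steps, expected)  # prints 1 8, 2 4, 4 2, 5 2
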
10 changes: 0 additions & 10 deletions tests/framework/test_utils.py
@@ -48,7 +48,6 @@
_set_module_training_mode,
_step_requires_iterator,
get_current_progress,
StatefulInt,
)
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
@@ -422,15 +421,6 @@ def _construct_optimizers() -> None:
tc.assertTrue(isinstance(result["optim2"], torch.optim.Optimizer))
tc.assertTrue(isinstance(result["lr_scheduler"], TLRScheduler))

def test_stateful_int(self) -> None:
v = StatefulInt(0)
v += 10
v -= 2
self.assertEqual(v.val, 8)
self.assertEqual(v.state_dict(), {"value": 8})
v.load_state_dict({"value": -4})
self.assertEqual(v.val, -4)


Batch = Tuple[torch.tensor, torch.tensor]

8 changes: 0 additions & 8 deletions torchtnt/framework/auto_unit.py
@@ -42,7 +42,6 @@
_get_timing_context,
_is_fsdp_module,
get_current_progress,
StatefulInt,
)
from torchtnt.utils import (
init_from_env,
@@ -507,7 +506,6 @@ def __init__(
)

self.gradient_accumulation_steps = gradient_accumulation_steps
self._num_optimizer_steps_completed: StatefulInt = StatefulInt(0)

self.detect_anomaly = detect_anomaly
self.clip_grad_norm = clip_grad_norm
@@ -758,8 +756,6 @@ def train_step(
else:
self.optimizer.step()

self._num_optimizer_steps_completed += 1

# sets gradients to zero
with _get_timing_context(
state, f"{self.__class__.__name__}.optimizer_zero_grad"
@@ -878,10 +874,6 @@ def on_eval_step_end(
"""
pass

@property
def num_optimizer_steps_completed(self) -> int:
return self._num_optimizer_steps_completed.val


def _validate_torchdynamo_available() -> None:
if not is_torch_version_ge_1_13_1():
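For anyone who relied on the removed property, the counting pattern that the deleted lines implemented is easy to reproduce outside the framework. Below is a hypothetical, plain-PyTorch sketch (not the TorchTNT API): a counter is bumped once per real optimizer.step(), i.e. once every gradient_accumulation_steps batches, which is what the removed StatefulInt tracked.

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 4
num_optimizer_steps_completed = 0  # the value the removed property exposed

# 16 random batches of shape (8, 2), purely illustrative
batches = [(torch.randn(8, 2), torch.randn(8, 2)) for _ in range(16)]
for i, (inputs, targets) in enumerate(batches):
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    (loss / gradient_accumulation_steps).backward()
    if (i + 1) % gradient_accumulation_steps == 0 or i + 1 == len(batches):
        optimizer.step()
        optimizer.zero_grad()
        num_optimizer_steps_completed += 1

print(num_optimizer_steps_completed)  # 4, i.e. ceil(16 / 4) optimizer steps
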
28 changes: 0 additions & 28 deletions torchtnt/framework/utils.py
@@ -32,7 +32,6 @@
from torchtnt.framework.unit import AppStateMixin
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.progress import Progress
from typing_extensions import Self

_logger: logging.Logger = logging.getLogger(__name__)

@@ -262,30 +261,3 @@ def get_current_progress(state: State) -> Progress:
return none_throws(state.eval_state).progress
else:
return none_throws(state.predict_state).progress


class StatefulInt:
"""
This wrapper is useful if there are additional values related to training
progress that need to be saved during checkpointing.
"""

def __init__(self, val: int) -> None:
self.val = val

def state_dict(self) -> Dict[str, Any]:
return {"value": self.val}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.val = state_dict["value"]

def __add__(self, other: int) -> Self:
self.val += other
return self

def __sub__(self, other: int) -> Self:
self.val -= other
return self

def __repr__(self) -> str:
return f"StatefulInt({self.val})"
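One detail worth noting about the removed class: __add__ and __sub__ mutate the instance in place and return self rather than constructing a new StatefulInt, so += works as expected but a plain w = v + 1 aliases v. A minimal sketch of that behavior, with the class copied from the removed code above (return annotations simplified so the snippet has no typing_extensions dependency):

from typing import Any, Dict

class StatefulInt:
    def __init__(self, val: int) -> None:
        self.val = val

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.val}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.val = state_dict["value"]

    def __add__(self, other: int) -> "StatefulInt":
        self.val += other
        return self

    def __sub__(self, other: int) -> "StatefulInt":
        self.val -= other
        return self

v = StatefulInt(3)
w = v + 1                  # __add__ mutates v in place and returns it
assert w is v and v.val == 4
v -= 2                     # in-place, as in the deleted test_stateful_int
assert v.state_dict() == {"value": 2}
v.load_state_dict({"value": -4})
assert v.val == -4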
