Remove timer from framework (#347)
Summary: Pull Request resolved: #347

Reviewed By: daniellepintz

Differential Revision: D43785640

fbshipit-source-id: f979165c63c1fa75b3ba447380ccbb23859be99e
JKSenthil authored and facebook-github-bot committed Mar 4, 2023
1 parent f2f2063 commit 5ecac88
Showing 12 changed files with 28 additions and 125 deletions.
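
With this commit, State no longer constructs or exposes a Timer, and the train/eval/predict loops no longer record hook latencies themselves. Callers who want the old report can own a timer directly. Below is a minimal sketch under that assumption, reusing the `Timer` and `get_timer_summary` utilities that this diff removes from the framework paths; the `state`/`my_unit` names are placeholders from the examples below, and wrapping `fit` this way is illustrative, not an API this commit prescribes.

# Sketch: user-owned timing after this commit; not framework behavior.
from torchtnt.framework import fit
from torchtnt.utils.timer import Timer, get_timer_summary

timer = Timer()  # previously created implicitly inside State

with timer.time("fit"):  # time any region of interest yourself
    fit(state, my_unit)  # placeholder state/unit, as in the examples below

# Prints the same style of report the examples used to get from
# print(get_timer_summary(state.timer)).
print(get_timer_summary(timer))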
3 changes: 1 addition & 2 deletions examples/auto_unit_example.py
@@ -19,7 +19,7 @@
 from torcheval.metrics import BinaryAccuracy
 from torchtnt.framework import AutoUnit, fit, init_fit_state, State
 from torchtnt.framework.state import ActivePhase
-from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
+from torchtnt.utils import init_from_env, seed, TLRScheduler
 from torchtnt.utils.loggers import TensorBoardLogger
 from typing_extensions import Literal
 
@@ -168,7 +168,6 @@ def main(args: Namespace) -> None:
     )
 
     fit(state, my_unit)
-    print(get_timer_summary(state.timer))
 
 
 def get_args() -> Namespace:
3 changes: 1 addition & 2 deletions examples/mingpt/main.py
@@ -23,7 +23,7 @@
     OptimizerConfig,
 )
 from torchtnt.framework import AutoUnit, fit, init_fit_state, State
-from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
+from torchtnt.utils import init_from_env, seed, TLRScheduler
 from torchtnt.utils.loggers import TensorBoardLogger
 
 _logger: logging.Logger = logging.getLogger(__name__)
@@ -131,7 +131,6 @@ def main(args: Namespace) -> None:
     )
 
     fit(state, my_unit)
-    print(get_timer_summary(state.timer))
 
 
 def get_args() -> Namespace:
2 changes: 0 additions & 2 deletions examples/mnist/main.py
@@ -21,7 +21,6 @@
 from torchtnt.framework import AutoUnit, fit, init_fit_state, State
 from torchtnt.utils import copy_data_to_device, init_from_env, seed, TLRScheduler
 from torchtnt.utils.loggers import TensorBoardLogger
-from torchtnt.utils.timer import get_timer_summary
 from torchvision import datasets, transforms
 
 Batch = Tuple[torch.Tensor, torch.Tensor]
@@ -181,7 +180,6 @@ def main(argv: List[str]) -> None:
     )
 
     fit(state, my_unit)
-    print(get_timer_summary(state.timer))
 
     if args.save_model:
         torch.save(module.state_dict(), "mnist_cnn.pt")
9 changes: 1 addition & 8 deletions examples/torchdata_train_example.py
@@ -20,13 +20,7 @@
 from torchdata.datapipes.iter import IterableWrapper
 from torcheval.metrics import BinaryAccuracy
 from torchtnt.framework import init_train_state, State, train, TrainUnit
-from torchtnt.utils import (
-    copy_data_to_device,
-    get_timer_summary,
-    init_from_env,
-    seed,
-    TLRScheduler,
-)
+from torchtnt.utils import copy_data_to_device, init_from_env, seed, TLRScheduler
 
 from torchtnt.utils.loggers import TensorBoardLogger
 
@@ -177,7 +171,6 @@ def main(argv: List[str]) -> None:
     )
 
     train(train_state, train_unit)
-    print(get_timer_summary(train_state.timer))
 
 
 def get_args(argv: List[str]) -> Namespace:
9 changes: 1 addition & 8 deletions examples/train_unit_example.py
@@ -17,13 +17,7 @@
 from torch.utils.data.dataset import Dataset, TensorDataset
 from torcheval.metrics import BinaryAccuracy
 from torchtnt.framework import init_train_state, State, train, TrainUnit
-from torchtnt.utils import (
-    copy_data_to_device,
-    get_timer_summary,
-    init_from_env,
-    seed,
-    TLRScheduler,
-)
+from torchtnt.utils import copy_data_to_device, init_from_env, seed, TLRScheduler
 
 from torchtnt.utils.loggers import TensorBoardLogger
 
@@ -155,7 +149,6 @@ def main(argv: List[str]) -> None:
     )
 
     train(state, my_unit)
-    print(get_timer_summary(state.timer))
 
 
 def get_args(argv: List[str]) -> Namespace:
28 changes: 0 additions & 28 deletions tests/framework/test_utils.py
@@ -32,7 +32,6 @@
     _step_requires_iterator,
     StatefulInt,
 )
-from torchtnt.utils import Timer
 from torchtnt.utils.test_utils import get_pet_launch_config
 
 
@@ -117,10 +116,8 @@ def test_run_callback_fn_hooks(self) -> None:
         """
         callback = DummyCallback("train")
         train_unit = MagicMock()
-        timer = Timer()
         dummy_train_state = State(
             entry_point=EntryPoint.TRAIN,
-            timer=timer,
             train_state=None,
         )
         self.assertEqual(callback.dummy_data, "train")
@@ -133,55 +130,30 @@ def test_run_callback_fn_hooks(self) -> None:
             ValueError("test"),
         )
         self.assertEqual(callback.dummy_data, "on_exception")
-        self.assertTrue(
-            "callback.DummyCallback.on_exception" in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn([callback], "on_train_start", dummy_train_state, train_unit)
         self.assertEqual(callback.dummy_data, "on_train_start")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_start" in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn(
             [callback], "on_train_epoch_start", dummy_train_state, train_unit
         )
         self.assertEqual(callback.dummy_data, "on_train_epoch_start")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_epoch_start"
-            in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn(
             [callback], "on_train_step_start", dummy_train_state, train_unit
         )
         self.assertEqual(callback.dummy_data, "on_train_step_start")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_step_start"
-            in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn([callback], "on_train_step_end", dummy_train_state, train_unit)
         self.assertEqual(callback.dummy_data, "on_train_step_end")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_step_end"
-            in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn(
             [callback], "on_train_epoch_end", dummy_train_state, train_unit
         )
         self.assertEqual(callback.dummy_data, "on_train_epoch_end")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_epoch_end"
-            in timer.recorded_durations.keys()
-        )
 
         _run_callback_fn([callback], "on_train_end", dummy_train_state, train_unit)
         self.assertEqual(callback.dummy_data, "on_train_end")
-        self.assertTrue(
-            "callback.DummyCallback.on_train_end" in timer.recorded_durations.keys()
-        )
 
     def test_run_callback_fn_exception(self) -> None:
         """
22 changes: 6 additions & 16 deletions torchtnt/framework/evaluate.py
@@ -21,7 +21,6 @@
     _step_requires_iterator,
     log_api_usage,
 )
-from torchtnt.utils.timer import get_timer_summary
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -95,7 +94,6 @@ def evaluate(
         state._entry_point = EntryPoint.EVALUATE
         _evaluate_impl(state, eval_unit, callbacks)
         logger.info("Finished evaluation")
-        logger.debug(get_timer_summary(state.timer))
     except Exception as e:
         # TODO: log for diagnostics
         logger.info(e)
@@ -123,17 +121,13 @@ def _evaluate_impl(
     tracked_modules = eval_unit.tracked_modules()
     prior_module_train_states = _set_module_training_mode(tracked_modules, False)
 
-    with state.timer.time(f"eval.{eval_unit.__class__.__name__}.on_eval_start"):
-        eval_unit.on_eval_start(state)
+    eval_unit.on_eval_start(state)
     _run_callback_fn(callbacks, "on_eval_start", state, eval_unit)
 
     # Conditionally run this to avoid running this multiple times
    # in the case of resuming from a checkpoint mid-epoch
     if eval_state.progress.num_steps_completed_in_epoch == 0:
-        with state.timer.time(
-            f"eval.{eval_unit.__class__.__name__}.on_eval_epoch_start"
-        ):
-            eval_unit.on_eval_epoch_start(state)
+        eval_unit.on_eval_epoch_start(state)
         _run_callback_fn(callbacks, "on_eval_epoch_start", state, eval_unit)
 
     data_iter = iter(eval_state.dataloader)
@@ -151,11 +145,9 @@
         try:
             if not pass_data_iter_to_step:
                 # get the next batch from the data iterator
-                with state.timer.time("eval.data_iter_next"):
-                    step_input = next(data_iter)
+                step_input = next(data_iter)
             _run_callback_fn(callbacks, "on_eval_step_start", state, eval_unit)
-            with state.timer.time(f"eval.{eval_unit.__class__.__name__}.eval_step"):
-                eval_state._step_output = eval_unit.eval_step(state, step_input)
+            eval_state._step_output = eval_unit.eval_step(state, step_input)
 
             eval_state.progress.increment_step()
             _run_callback_fn(callbacks, "on_eval_step_end", state, eval_unit)
@@ -174,12 +166,10 @@
     # set progress counters for the next epoch
     eval_state.progress.increment_epoch()
 
-    with state.timer.time(f"eval.{eval_unit.__class__.__name__}.on_eval_epoch_end"):
-        eval_unit.on_eval_epoch_end(state)
+    eval_unit.on_eval_epoch_end(state)
     _run_callback_fn(callbacks, "on_eval_epoch_end", state, eval_unit)
 
-    with state.timer.time(f"eval.{eval_unit.__class__.__name__}.on_eval_end"):
-        eval_unit.on_eval_end(state)
+    eval_unit.on_eval_end(state)
     _run_callback_fn(callbacks, "on_eval_end", state, eval_unit)
 
     # Reset training mode for modules at the end of the epoch
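
Since `_evaluate_impl` now calls `eval_step` directly instead of wrapping it in `state.timer.time(...)`, a unit that still wants that measurement can carry its own timer. A rough sketch, assuming `EvalUnit` is exported from `torchtnt.framework` alongside the `TrainUnit` used in the examples above and is generic over the batch type:

from typing import Tuple

import torch
from torchtnt.framework import EvalUnit, State
from torchtnt.utils.timer import Timer

Batch = Tuple[torch.Tensor, torch.Tensor]  # batch type, as in the mnist example


class TimedEvalUnit(EvalUnit[Batch]):
    def __init__(self) -> None:
        super().__init__()
        self.timer = Timer()  # unit-owned; State no longer carries one

    def eval_step(self, state: State, data: Batch) -> None:
        # Replaces the removed framework-level "eval.<cls>.eval_step" timing.
        with self.timer.time("eval_step"):
            inputs, targets = data
            # ... evaluation logic elided ...

After `evaluate(...)` returns, the unit's `timer.recorded_durations` (the same mapping the test assertions above inspect) or `get_timer_summary(unit.timer)` yields the report.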
8 changes: 2 additions & 6 deletions torchtnt/framework/fit.py
@@ -19,7 +19,6 @@
     TTrainUnit,
 )
 from torchtnt.framework.utils import _is_done, _run_callback_fn, log_api_usage
-from torchtnt.utils.timer import get_timer_summary
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -112,7 +111,6 @@ def fit(
     try:
         state._entry_point = EntryPoint.FIT
         _fit_impl(state, unit, callbacks)
-        logger.debug(get_timer_summary(state.timer))
     except Exception as e:
         # TODO: log for diagnostics
         logger.info(e)
@@ -144,8 +142,7 @@ def _fit_impl(
         f"evaluate_every_n_epochs={eval_state.evaluate_every_n_epochs} "
     )
 
-    with state.timer.time(f"train.{unit.__class__.__name__}.on_train_start"):
-        unit.on_train_start(state)
+    unit.on_train_start(state)
     _run_callback_fn(callbacks, "on_train_start", state, unit)
 
     while not (
@@ -154,6 +151,5 @@
     ):
         _train_epoch_impl(state, unit, callbacks)
 
-    with state.timer.time(f"train.{unit.__class__.__name__}.on_train_end"):
-        unit.on_train_end(state)
+    unit.on_train_end(state)
     _run_callback_fn(callbacks, "on_train_end", state, unit)
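
`_fit_impl` likewise no longer times `on_train_start`/`on_train_end`. One way to recover per-hook timings is a callback, since `_run_callback_fn` still dispatches every hook. A sketch, assuming `Callback` is importable from `torchtnt.framework` and that hooks receive `(state, unit)`, as the `_run_callback_fn` calls in the test file above suggest:

import time
from typing import Any, List

from torchtnt.framework import Callback, State


class EpochTimingCallback(Callback):
    # Sketch: wall-clock seconds per train epoch, replacing the removed
    # framework-internal timing.
    def __init__(self) -> None:
        self.epoch_durations: List[float] = []
        self._epoch_start: float = 0.0

    def on_train_epoch_start(self, state: State, unit: Any) -> None:
        self._epoch_start = time.monotonic()

    def on_train_epoch_end(self, state: State, unit: Any) -> None:
        self.epoch_durations.append(time.monotonic() - self._epoch_start)

It would be passed as `fit(state, my_unit, callbacks=[EpochTimingCallback()])`, on the assumption that `fit` forwards `callbacks` to `_fit_impl` as shown above.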
30 changes: 6 additions & 24 deletions torchtnt/framework/predict.py
@@ -21,7 +21,6 @@
     _step_requires_iterator,
     log_api_usage,
 )
-from torchtnt.utils.timer import get_timer_summary
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -96,7 +95,6 @@ def predict(
         state._entry_point = EntryPoint.PREDICT
         _predict_impl(state, predict_unit, callbacks)
         logger.info("Finished predict")
-        logger.debug(get_timer_summary(state.timer))
     except Exception as e:
         # TODO: log for diagnostics
         logger.info(e)
@@ -124,19 +122,13 @@ def _predict_impl(
     tracked_modules = predict_unit.tracked_modules()
     prior_module_train_states = _set_module_training_mode(tracked_modules, False)
 
-    with state.timer.time(
-        f"predict.{predict_unit.__class__.__name__}.on_predict_start"
-    ):
-        predict_unit.on_predict_start(state)
+    predict_unit.on_predict_start(state)
     _run_callback_fn(callbacks, "on_predict_start", state, predict_unit)
 
     # Conditionally run this to avoid running this multiple times
     # in the case of resuming from a checkpoint mid-epoch
     if predict_state.progress.num_steps_completed_in_epoch == 0:
-        with state.timer.time(
-            f"predict.{predict_unit.__class__.__name__}.on_predict_epoch_start"
-        ):
-            predict_unit.on_predict_epoch_start(state)
+        predict_unit.on_predict_epoch_start(state)
         _run_callback_fn(callbacks, "on_predict_epoch_start", state, predict_unit)
 
     data_iter = iter(predict_state.dataloader)
@@ -156,16 +148,10 @@
         try:
             if not pass_data_iter_to_step:
                 # get the next batch from the data iterator
-                with state.timer.time("predict.data_iter_next"):
-                    step_input = next(data_iter)
+                step_input = next(data_iter)
 
             _run_callback_fn(callbacks, "on_predict_step_start", state, predict_unit)
-            with state.timer.time(
-                f"predict.{predict_unit.__class__.__name__}.predict_step"
-            ):
-                predict_state._step_output = predict_unit.predict_step(
-                    state, step_input
-                )
+            predict_state._step_output = predict_unit.predict_step(state, step_input)
             predict_state.progress.increment_step()
             _run_callback_fn(callbacks, "on_predict_step_end", state, predict_unit)
 
@@ -185,14 +171,10 @@
     # set progress counters for the next epoch
     predict_state.progress.increment_epoch()
 
-    with state.timer.time(
-        f"predict.{predict_unit.__class__.__name__}.on_predict_epoch_end"
-    ):
-        predict_unit.on_predict_epoch_end(state)
+    predict_unit.on_predict_epoch_end(state)
     _run_callback_fn(callbacks, "on_predict_epoch_end", state, predict_unit)
 
-    with state.timer.time(f"predict.{predict_unit.__class__.__name__}.on_predict_end"):
-        predict_unit.on_predict_end(state)
+    predict_unit.on_predict_end(state)
     _run_callback_fn(callbacks, "on_predict_end", state, predict_unit)
 
     # Reset training mode for modules at the end of the epoch
8 changes: 0 additions & 8 deletions torchtnt/framework/state.py
@@ -14,7 +14,6 @@
 from typing import Any, Iterable, Optional
 
 from torchtnt.framework.progress import Progress
-from torchtnt.utils.timer import Timer
 
 _logger: logging.Logger = logging.getLogger(__name__)
 
@@ -148,13 +147,11 @@ def __init__(
         self,
         *,
         entry_point: EntryPoint,
-        timer: Optional[Timer] = None,
         train_state: Optional[PhaseState] = None,
         eval_state: Optional[PhaseState] = None,
         predict_state: Optional[PhaseState] = None,
     ) -> None:
         self._entry_point = entry_point
-        self._timer: Timer = timer or Timer()
         self._train_state = train_state
         self._eval_state = eval_state
         self._predict_state = predict_state
@@ -171,11 +168,6 @@ def active_phase(self) -> ActivePhase:
         """Current active phase of the loop. (One of TRAIN, EVALUATE, PREDICT)."""
         return self._active_phase
 
-    @property
-    def timer(self) -> Timer:
-        """A :class:`~torchtnt.framework.Timer` object which records latencies of key events during loop execution."""
-        return self._timer
-
     @property
     def train_state(self) -> Optional[PhaseState]:
         """A :class:`~torchtnt.framework.PhaseState` object which contains meta information about the train phase."""
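
For code constructing `State` directly, the `timer` keyword is simply gone. A before/after sketch, assuming `EntryPoint` is importable from `torchtnt.framework.state` as `State` is:

from torchtnt.framework.state import EntryPoint, State

# Before this commit: State(entry_point=EntryPoint.TRAIN, timer=Timer(), ...)
# After: State neither accepts nor owns a timer.
state = State(
    entry_point=EntryPoint.TRAIN,
    train_state=None,  # mirrors the updated test_run_callback_fn_hooks above
)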
(Diffs for the remaining 2 changed files are not shown here.)
