diff --git a/tests/runner/test_evaluate.py b/tests/runner/test_evaluate.py
index be63ae1e48..8458d02400 100644
--- a/tests/runner/test_evaluate.py
+++ b/tests/runner/test_evaluate.py
@@ -13,7 +13,7 @@
 from torch import nn
 from torchtnt.runner._test_utils import DummyEvalUnit, generate_random_dataloader
 from torchtnt.runner.evaluate import evaluate, init_eval_state
-from torchtnt.runner.state import State
+from torchtnt.runner.state import EntryPoint, State
 from torchtnt.runner.unit import EvalUnit
 
 
@@ -37,6 +37,7 @@ def test_evaluate(self) -> None:
         self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
         self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
         self.assertEqual(state.eval_state.progress.num_steps_completed, expected_steps)
+        self.assertEqual(state.entry_point, EntryPoint.EVALUATE)
 
         # step_output should be reset to None
         self.assertEqual(state.eval_state.step_output, None)
@@ -66,6 +67,7 @@ def test_evaluate_max_steps_per_epoch(self) -> None:
         self.assertEqual(
             state.eval_state.progress.num_steps_completed, max_steps_per_epoch
         )
+        self.assertEqual(state.entry_point, EntryPoint.EVALUATE)
 
         # step_output should be reset to None
         self.assertEqual(state.eval_state.step_output, None)
diff --git a/tests/runner/test_fit.py b/tests/runner/test_fit.py
index 445de92b26..4b63cbb788 100644
--- a/tests/runner/test_fit.py
+++ b/tests/runner/test_fit.py
@@ -14,7 +14,7 @@
 from torchtnt.runner._test_utils import DummyFitUnit, generate_random_dataloader
 from torchtnt.runner.callback import Callback
 from torchtnt.runner.fit import fit, init_fit_state
-from torchtnt.runner.state import State
+from torchtnt.runner.state import EntryPoint, State
 from torchtnt.runner.unit import EvalUnit, TrainUnit
 
 
@@ -65,6 +65,7 @@ def test_fit_evaluate_every_n_epochs(self) -> None:
         self.assertEqual(
             state.eval_state.progress.num_steps_completed,
             max_epochs * expected_eval_steps_per_epoch,
         )
+        self.assertEqual(state.entry_point, EntryPoint.FIT)
         # step_output should be reset to None
         self.assertEqual(state.eval_state.step_output, None)
@@ -123,6 +124,7 @@ def test_fit_evaluate_every_n_steps(self) -> None:
         self.assertEqual(
             state.eval_state.progress.num_steps_completed,
             expected_num_evaluate_calls * expected_eval_steps_per_epoch,
         )
+        self.assertEqual(state.entry_point, EntryPoint.FIT)
         # step_output should be reset to None
         self.assertEqual(state.eval_state.step_output, None)
diff --git a/tests/runner/test_predict.py b/tests/runner/test_predict.py
index 9c85e9b7b1..b96d618163 100644
--- a/tests/runner/test_predict.py
+++ b/tests/runner/test_predict.py
@@ -15,7 +15,7 @@
 
 from torchtnt.runner._test_utils import DummyPredictUnit, generate_random_dataloader
 from torchtnt.runner.predict import init_predict_state, predict
-from torchtnt.runner.state import State
+from torchtnt.runner.state import EntryPoint, State
 from torchtnt.runner.unit import PredictUnit
 
 
@@ -41,6 +41,7 @@ def test_predict(self) -> None:
         self.assertEqual(
             state.predict_state.progress.num_steps_completed, expected_steps
         )
+        self.assertEqual(state.entry_point, EntryPoint.PREDICT)
 
         # step_output should be reset to None
         self.assertEqual(state.predict_state.step_output, None)
diff --git a/tests/runner/test_train.py b/tests/runner/test_train.py
index bd558e5089..f37f9c51b3 100644
--- a/tests/runner/test_train.py
+++ b/tests/runner/test_train.py
@@ -13,7 +13,7 @@
 
 from torch import nn
 from torchtnt.runner._test_utils import DummyTrainUnit, generate_random_dataloader
-from torchtnt.runner.state import State
+from torchtnt.runner.state import EntryPoint, State
 from torchtnt.runner.train import init_train_state, train, train_epoch
 from torchtnt.runner.unit import TrainUnit
 
@@ -47,6 +47,7 @@ def test_train(self) -> None:
 
         self.assertEqual(state.train_state.step_output, None)
         self.assertEqual(my_unit.module.training, initial_training_mode)
+        self.assertEqual(state.entry_point, EntryPoint.TRAIN)
 
     def test_train_max_steps_per_epoch(self) -> None:
         """
@@ -77,6 +78,7 @@ def test_train_max_steps_per_epoch(self) -> None:
         self.assertEqual(
             state.train_state.progress.num_steps_completed,
             max_epochs * max_steps_per_epoch,
         )
+        self.assertEqual(state.entry_point, EntryPoint.TRAIN)
         # step_output should be reset to None
         self.assertEqual(state.train_state.step_output, None)
@@ -107,6 +109,7 @@ def test_train_epoch(self) -> None:
         self.assertEqual(
             state.train_state.progress.num_steps_completed,
             expected_steps_per_epoch,
         )
+        self.assertEqual(state.entry_point, EntryPoint.TRAIN)
         # step_output should be reset to None
         self.assertEqual(state.train_state.step_output, None)
diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py
index abf830e47e..a7b760b0c2 100644
--- a/torchtnt/runner/evaluate.py
+++ b/torchtnt/runner/evaluate.py
@@ -10,7 +10,7 @@
 
 import torch
 from torchtnt.runner.callback import Callback
-from torchtnt.runner.state import EntryPoint, PhaseState, State
+from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.runner.unit import TEvalData, TEvalUnit
 from torchtnt.runner.utils import (
     _is_epoch_done,
@@ -67,6 +67,7 @@ def evaluate(
     log_api_usage("evaluate")
    callbacks = callbacks or []
     try:
+        state._entry_point = EntryPoint.EVALUATE
         _evaluate_impl(state, eval_unit, callbacks)
         logger.info("Finished evaluation")
         logger.debug(get_timer_summary(state.timer))
@@ -88,6 +89,8 @@ def _evaluate_impl(
     eval_state = state.eval_state
     if not eval_state:
         raise RuntimeError("Expected eval_state to be initialized!")
+
+    state._active_phase = ActivePhase.EVALUATE
     logger.info(
         f"Started evaluate with max_steps_per_epoch={eval_state.max_steps_per_epoch}"
     )
diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py
index c8b7c4f64d..d584d9f019 100644
--- a/torchtnt/runner/fit.py
+++ b/torchtnt/runner/fit.py
@@ -76,6 +76,7 @@ def fit(
 
     callbacks = callbacks or []
     try:
+        state._entry_point = EntryPoint.FIT
         _fit_impl(state, unit, callbacks)
         logger.debug(get_timer_summary(state.timer))
     except Exception as e:
diff --git a/torchtnt/runner/predict.py b/torchtnt/runner/predict.py
index 96f9206ce5..6ad2ad44c5 100644
--- a/torchtnt/runner/predict.py
+++ b/torchtnt/runner/predict.py
@@ -10,7 +10,7 @@
 
 import torch
 from torchtnt.runner.callback import Callback
-from torchtnt.runner.state import EntryPoint, PhaseState, State
+from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.runner.unit import TPredictData, TPredictUnit
 from torchtnt.runner.utils import (
     _is_epoch_done,
@@ -67,6 +67,7 @@ def predict(
     log_api_usage("predict")
     callbacks = callbacks or []
     try:
+        state._entry_point = EntryPoint.PREDICT
         _predict_impl(state, predict_unit, callbacks)
         logger.info("Finished predict")
         logger.debug(get_timer_summary(state.timer))
@@ -88,6 +89,8 @@ def _predict_impl(
     predict_state = state.predict_state
     if not predict_state:
         raise RuntimeError("Expected predict_state to be initialized!")
+
+    state._active_phase = ActivePhase.PREDICT
     logger.info(
         f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"
     )
diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py
index 5378251c29..366a828481 100644
--- a/torchtnt/runner/state.py
+++ b/torchtnt/runner/state.py
@@ -27,12 +27,39 @@ def _check_loop_condition(name: str, val: Optional[int]) -> None:
 
 
 class EntryPoint(Enum):
+    """
+    Enum for the user-facing functions offered by the TorchTNT runner module:
+    - :py:func:`~torchtnt.runner.fit`
+    - :py:func:`~torchtnt.runner.train`
+    - :py:func:`~torchtnt.runner.evaluate`
+    - :py:func:`~torchtnt.runner.predict`
+    """
+
     FIT = auto()
     TRAIN = auto()
     EVALUATE = auto()
     PREDICT = auto()
 
 
+class ActivePhase(Enum):
+    """Enum for the currently active phase.
+
+    This class complements :class:`EntryPoint` by specifying which phase is active within each entry point.
+    More than one phase may run while a single :class:`EntryPoint` is active:
+    - ``EntryPoint.FIT`` - ``ActivePhase.{TRAIN,EVALUATE}``
+    - ``EntryPoint.TRAIN`` - ``ActivePhase.TRAIN``
+    - ``EntryPoint.EVALUATE`` - ``ActivePhase.EVALUATE``
+    - ``EntryPoint.PREDICT`` - ``ActivePhase.PREDICT``
+
+    This can be used within hooks such as :meth:`~torchtnt.runner.unit._OnExceptionMixin.on_exception`
+    to determine whether the hook was called during training, evaluation, or prediction.
+    """
+
+    TRAIN = auto()
+    EVALUATE = auto()
+    PREDICT = auto()
+
+
 class PhaseState:
     """State for each phase (train, eval, predict)"""
 
@@ -116,11 +143,16 @@ def __init__(
         self._eval_state = eval_state
         self._predict_state = predict_state
         self._should_stop: bool = False
+        self._active_phase: ActivePhase = ActivePhase.TRAIN
 
     @property
     def entry_point(self) -> EntryPoint:
         return self._entry_point
 
+    @property
+    def active_phase(self) -> ActivePhase:
+        return self._active_phase
+
     @property
     def timer(self) -> Timer:
         return self._timer
diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py
index f5879b3a27..0434d305ca 100644
--- a/torchtnt/runner/train.py
+++ b/torchtnt/runner/train.py
@@ -10,7 +10,7 @@
 import torch
 from torchtnt.runner.callback import Callback
 from torchtnt.runner.evaluate import _evaluate_impl
-from torchtnt.runner.state import EntryPoint, PhaseState, State
+from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
 from torchtnt.runner.unit import TTrainData, TTrainUnit
 from torchtnt.runner.utils import (
     _is_done,
@@ -76,6 +76,7 @@ def train(
     log_api_usage("train")
     callbacks = callbacks or []
     try:
+        state._entry_point = EntryPoint.TRAIN
         _train_impl(state, train_unit, callbacks)
         logger.info("Finished train")
         logger.debug(get_timer_summary(state.timer))
@@ -99,6 +100,7 @@ def _train_impl(
     logger.info(
         f"Started train with max_epochs={train_state.max_epochs}, max_steps={train_state.max_steps}, max_steps_per_epoch={train_state.max_steps_per_epoch}"
     )
+    state._active_phase = ActivePhase.TRAIN
 
     # Set all modules to train() mode
     # access modules made available through _AppStateMixin
@@ -140,6 +142,7 @@ def train_epoch(
         raise RuntimeError(
             f"Expected state.train_state.max_epochs to be 1, but received {train_state.max_epochs}."
         )
+    state._entry_point = EntryPoint.TRAIN
     logger.info(
         f"Started train_epoch with max_steps_per_epoch={train_state.max_steps_per_epoch}"
    )
@@ -164,6 +167,7 @@ def _train_epoch_impl(
     callbacks: List[Callback],
 ) -> None:
     logger.info("Started train epoch")
+    state._active_phase = ActivePhase.TRAIN
 
     # Set all modules to train() mode
     # access modules made available through _AppStateMixin
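
Usage sketch (not part of the diff above): the snippet shows how downstream code might read the new state.entry_point and state.active_phase properties. Only the State properties and the EntryPoint/ActivePhase enums come from this change; the helper and hook function below are hypothetical, and the assumption that an on_exception-style hook receives the State and the raised exception may not match the real signatures in torchtnt.runner.unit.

from torchtnt.runner.state import ActivePhase, EntryPoint, State


def describe_run(state: State) -> str:
    # Hypothetical helper: label the current run, e.g. "fit/evaluate" or "train/train".
    return f"{state.entry_point.name.lower()}/{state.active_phase.name.lower()}"


def on_exception_example(state: State, exc: BaseException) -> None:
    # Hypothetical hook body: branch on the phase that was active when the error was raised.
    if state.entry_point == EntryPoint.FIT and state.active_phase == ActivePhase.EVALUATE:
        print(f"fit() failed during its evaluation phase: {describe_run(state)}")
    elif state.active_phase == ActivePhase.TRAIN:
        print(f"failure during training: {describe_run(state)}")
    raise exc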