diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index ad90a045e5..a7b760b0c2 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -10,7 +10,7 @@ import torch from torchtnt.runner.callback import Callback -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TEvalData, TEvalUnit from torchtnt.runner.utils import ( _is_epoch_done, @@ -89,6 +89,8 @@ def _evaluate_impl( eval_state = state.eval_state if not eval_state: raise RuntimeError("Expected eval_state to be initialized!") + + state._active_phase = ActivePhase.EVALUATE logger.info( f"Started evaluate with max_steps_per_epoch={eval_state.max_steps_per_epoch}" ) diff --git a/torchtnt/runner/predict.py b/torchtnt/runner/predict.py index 9c6eeb1f0c..6ad2ad44c5 100644 --- a/torchtnt/runner/predict.py +++ b/torchtnt/runner/predict.py @@ -10,7 +10,7 @@ import torch from torchtnt.runner.callback import Callback -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TPredictData, TPredictUnit from torchtnt.runner.utils import ( _is_epoch_done, @@ -89,6 +89,8 @@ def _predict_impl( predict_state = state.predict_state if not predict_state: raise RuntimeError("Expected predict_state to be initialized!") + + state._active_phase = ActivePhase.PREDICT logger.info( f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}" ) diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py index 5378251c29..366a828481 100644 --- a/torchtnt/runner/state.py +++ b/torchtnt/runner/state.py @@ -27,12 +27,39 @@ def _check_loop_condition(name: str, val: Optional[int]) -> None: class EntryPoint(Enum): + """ + Enum for the user-facing functions offered by the TorchTNT runner module. 
+ - :py:func:`~torchtnt.runner.fit` + - :py:func:`~torchtnt.runner.train` + - :py:func:`~torchtnt.runner.evaluate` + - :py:func:`~torchtnt.runner.predict` + """ + FIT = auto() TRAIN = auto() EVALUATE = auto() PREDICT = auto() +class ActivePhase(Enum): + """Enum for the currently active phase. + + This class complements :class:`EntryPoint` by specifying the active phase for each function. + A single :class:`EntryPoint` may pass through more than one phase while it runs; the possible phases for each entry point are: + - ``EntryPoint.FIT`` - ``ActivePhase.{TRAIN,EVALUATE}`` + - ``EntryPoint.TRAIN`` - ``ActivePhase.TRAIN`` + - ``EntryPoint.EVALUATE`` - ``ActivePhase.EVALUATE`` + - ``EntryPoint.PREDICT`` - ``ActivePhase.PREDICT`` + + This can be used within hooks such as :meth:`~torchtnt.runner.unit._OnExceptionMixin.on_exception` + to determine whether the hook is being called during training, evaluation, or prediction. + """ + + TRAIN = auto() + EVALUATE = auto() + PREDICT = auto() + + class PhaseState: """State for each phase (train, eval, predict)""" @@ -116,11 +143,16 @@ def __init__( self._eval_state = eval_state self._predict_state = predict_state self._should_stop: bool = False + self._active_phase: ActivePhase = ActivePhase.TRAIN @property def entry_point(self) -> EntryPoint: return self._entry_point + @property + def active_phase(self) -> ActivePhase: + return self._active_phase + @property def timer(self) -> Timer: return self._timer diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index f285412f99..0434d305ca 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -10,7 +10,7 @@ import torch from torchtnt.runner.callback import Callback from torchtnt.runner.evaluate import _evaluate_impl -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TTrainData, TTrainUnit from torchtnt.runner.utils import ( _is_done, @@ -100,6 +100,7 @@ def _train_impl( 
logger.info( f"Started train with max_epochs={train_state.max_epochs}, max_steps={train_state.max_steps}, max_steps_per_epoch={train_state.max_steps_per_epoch}" ) + state._active_phase = ActivePhase.TRAIN # Set all modules to train() mode # access modules made available through _AppStateMixin @@ -166,6 +167,7 @@ def _train_epoch_impl( callbacks: List[Callback], ) -> None: logger.info("Started train epoch") + state._active_phase = ActivePhase.TRAIN # Set all modules to train() mode # access modules made available through _AppStateMixin