diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index 063760e9eb..0b4697a165 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -10,7 +10,7 @@ import torch from torchtnt.runner.callback import Callback -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TEvalData, TEvalUnit from torchtnt.runner.utils import ( _is_epoch_done, @@ -91,6 +91,7 @@ def _evaluate_impl( # always set entry point state._entry_point = EntryPoint.EVALUATE + state._active_phase = ActivePhase.EVALUATE logger.info( f"Started evaluate with max_steps_per_epoch={eval_state.max_steps_per_epoch}" diff --git a/torchtnt/runner/predict.py b/torchtnt/runner/predict.py index 6fc626b2ad..eba088d3e7 100644 --- a/torchtnt/runner/predict.py +++ b/torchtnt/runner/predict.py @@ -10,7 +10,7 @@ import torch from torchtnt.runner.callback import Callback -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TPredictData, TPredictUnit from torchtnt.runner.utils import ( _is_epoch_done, @@ -91,6 +91,7 @@ def _predict_impl( # always set entry point state._entry_point = EntryPoint.PREDICT + state._active_phase = ActivePhase.PREDICT logger.info( f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}" diff --git a/torchtnt/runner/state.py b/torchtnt/runner/state.py index 5378251c29..b17ab017aa 100644 --- a/torchtnt/runner/state.py +++ b/torchtnt/runner/state.py @@ -27,12 +27,39 @@ def _check_loop_condition(name: str, val: Optional[int]) -> None: class EntryPoint(Enum): + """ + Enum for the user-facing functions offered by the TorchTNT runner module. 
+ - :py:func:`~torchtnt.runner.train` + - :py:func:`~torchtnt.runner.evaluate` + - :py:func:`~torchtnt.runner.predict` + - :py:func:`~torchtnt.runner.fit` + """ + FIT = auto() TRAIN = auto() EVALUATE = auto() PREDICT = auto() +class ActivePhase(Enum): + """Enum for the currently active phase. + + This class complements :class:`EntryPoint` by specifying the active phase for each function. + More than one phase value can be set while a :class:`EntryPoint` is running: + - ``EntryPoint.FIT`` - ``ActivePhase.{TRAIN,EVALUATE}`` + - ``EntryPoint.TRAIN`` - ``ActivePhase.TRAIN`` + - ``EntryPoint.EVALUATE`` - ``ActivePhase.EVALUATE`` + - ``EntryPoint.PREDICT`` - ``ActivePhase.PREDICT`` + + This can be used within hooks such as :meth:`~torchtnt.runner.unit._OnExceptionMixin.on_exception` + to determine within which of training, evaluation, or prediction the hook is being called. + """ + + TRAIN = auto() + EVALUATE = auto() + PREDICT = auto() + + class PhaseState: """State for each phase (train, eval, predict)""" @@ -116,11 +143,16 @@ def __init__( self._eval_state = eval_state self._predict_state = predict_state self._should_stop: bool = False + self._active_phase: ActivePhase = ActivePhase.TRAIN @property def entry_point(self) -> EntryPoint: return self._entry_point + @property + def active_phase(self) -> ActivePhase: + return self._active_phase + @property def timer(self) -> Timer: return self._timer diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index ef220a471a..67f7d74fc4 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -5,12 +5,11 @@ # LICENSE file in the root directory of this source tree. 
import logging -from typing import cast, Iterable, List, Optional - +from typing import Iterable, List, Optional import torch from torchtnt.runner.callback import Callback from torchtnt.runner.evaluate import _evaluate_impl -from torchtnt.runner.state import EntryPoint, PhaseState, State +from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State from torchtnt.runner.unit import TTrainData, TTrainUnit from torchtnt.runner.utils import ( _is_done, @@ -102,6 +101,7 @@ def _train_impl( # always set entry point state._entry_point = EntryPoint.TRAIN + state._active_phase = ActivePhase.TRAIN # Set all modules to train() mode # access modules made available through _AppStateMixin @@ -169,6 +169,7 @@ def _train_epoch_impl( callbacks: List[Callback], ) -> None: logger.info("Started train epoch") + state._active_phase = ActivePhase.TRAIN # Set all modules to train() mode # access modules made available through _AppStateMixin