Skip to content

Commit

Permalink
Add an indicator to state for the currently active phase
Browse files Browse the repository at this point in the history
Summary: We have hooks such as `on_exception` which are common to train/eval/predict. In these instances, especially when using `fit` where both `train` and `eval` are called, it's difficult to know which phase was running without an explicit indicator. This diff adds that indicator, which is only meaningful while the loop is being executed.

Differential Revision: D40527985

fbshipit-source-id: 2772df147b28b424df8f9df933f1d7cbcc0996dd
  • Loading branch information
ananthsub authored and facebook-github-bot committed Oct 20, 2022
1 parent 65c93ed commit 61be0ab
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 3 deletions.
4 changes: 3 additions & 1 deletion torchtnt/runner/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback

from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TEvalData, TEvalUnit
from torchtnt.runner.utils import (
_is_epoch_done,
Expand Down Expand Up @@ -89,6 +89,8 @@ def _evaluate_impl(
eval_state = state.eval_state
if not eval_state:
raise RuntimeError("Expected eval_state to be initialized!")

state._active_phase = ActivePhase.EVALUATE
logger.info(
f"Started evaluate with max_steps_per_epoch={eval_state.max_steps_per_epoch}"
)
Expand Down
4 changes: 3 additions & 1 deletion torchtnt/runner/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback

from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TPredictData, TPredictUnit
from torchtnt.runner.utils import (
_is_epoch_done,
Expand Down Expand Up @@ -89,6 +89,8 @@ def _predict_impl(
predict_state = state.predict_state
if not predict_state:
raise RuntimeError("Expected predict_state to be initialized!")

state._active_phase = ActivePhase.PREDICT
logger.info(
f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"
)
Expand Down
32 changes: 32 additions & 0 deletions torchtnt/runner/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,39 @@ def _check_loop_condition(name: str, val: Optional[int]) -> None:


class EntryPoint(Enum):
    """
    Enum for the user-facing functions offered by the TorchTNT runner module.

    Each member names one of those functions:

    - ``FIT``: :py:func:`~torchtnt.runner.fit`
    - ``TRAIN``: :py:func:`~torchtnt.runner.train`
    - ``EVALUATE``: :py:func:`~torchtnt.runner.evaluate`
    - ``PREDICT``: :py:func:`~torchtnt.runner.predict`
    """

    FIT = auto()
    TRAIN = auto()
    EVALUATE = auto()
    PREDICT = auto()


class ActivePhase(Enum):
    """Enum for the currently active phase.

    This class complements :class:`EntryPoint` by specifying the active phase for each function.
    A single :class:`EntryPoint` may pass through more than one phase while it is running:

    - ``EntryPoint.FIT`` - ``ActivePhase.{TRAIN,EVALUATE}``
    - ``EntryPoint.TRAIN`` - ``ActivePhase.TRAIN``
    - ``EntryPoint.EVALUATE`` - ``ActivePhase.EVALUATE``
    - ``EntryPoint.PREDICT`` - ``ActivePhase.PREDICT``

    This can be used within hooks such as :meth:`~torchtnt.runner.unit._OnExceptionMixin.on_exception`
    to determine whether the hook is being called during training, evaluation, or prediction.
    """

    TRAIN = auto()
    EVALUATE = auto()
    PREDICT = auto()


class PhaseState:
"""State for each phase (train, eval, predict)"""

Expand Down Expand Up @@ -116,11 +143,16 @@ def __init__(
self._eval_state = eval_state
self._predict_state = predict_state
self._should_stop: bool = False
self._active_phase: ActivePhase = ActivePhase.TRAIN

@property
def entry_point(self) -> EntryPoint:
    """Return the :class:`EntryPoint` stored on this state (read-only)."""
    return self._entry_point

@property
def active_phase(self) -> ActivePhase:
    """Return the :class:`ActivePhase` most recently recorded on this state.

    The value is initialised to ``ActivePhase.TRAIN`` and is overwritten by the
    train, evaluate and predict loop implementations as they begin, so it is
    only meaningful while a loop is being executed.
    """
    return self._active_phase

@property
def timer(self) -> Timer:
    """Return the ``Timer`` stored on this state (read-only)."""
    return self._timer
Expand Down
4 changes: 3 additions & 1 deletion torchtnt/runner/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback
from torchtnt.runner.evaluate import _evaluate_impl
from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TTrainData, TTrainUnit
from torchtnt.runner.utils import (
_is_done,
Expand Down Expand Up @@ -100,6 +100,7 @@ def _train_impl(
logger.info(
f"Started train with max_epochs={train_state.max_epochs}, max_steps={train_state.max_steps}, max_steps_per_epoch={train_state.max_steps_per_epoch}"
)
state._active_phase = ActivePhase.TRAIN

# Set all modules to train() mode
# access modules made available through _AppStateMixin
Expand Down Expand Up @@ -166,6 +167,7 @@ def _train_epoch_impl(
callbacks: List[Callback],
) -> None:
logger.info("Started train epoch")
state._active_phase = ActivePhase.TRAIN

# Set all modules to train() mode
# access modules made available through _AppStateMixin
Expand Down

0 comments on commit 61be0ab

Please sign in to comment.