Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an indicator to state for the currently active phase #245

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tests/runner/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import nn
from torchtnt.runner._test_utils import DummyEvalUnit, generate_random_dataloader
from torchtnt.runner.evaluate import evaluate, init_eval_state
from torchtnt.runner.state import State
from torchtnt.runner.state import EntryPoint, State
from torchtnt.runner.unit import EvalUnit


Expand All @@ -37,6 +37,7 @@ def test_evaluate(self) -> None:
self.assertEqual(state.eval_state.progress.num_epochs_completed, 1)
self.assertEqual(state.eval_state.progress.num_steps_completed_in_epoch, 0)
self.assertEqual(state.eval_state.progress.num_steps_completed, expected_steps)
self.assertEqual(state.entry_point, EntryPoint.EVALUATE)

# step_output should be reset to None
self.assertEqual(state.eval_state.step_output, None)
Expand Down Expand Up @@ -66,6 +67,7 @@ def test_evaluate_max_steps_per_epoch(self) -> None:
self.assertEqual(
state.eval_state.progress.num_steps_completed, max_steps_per_epoch
)
self.assertEqual(state.entry_point, EntryPoint.EVALUATE)

# step_output should be reset to None
self.assertEqual(state.eval_state.step_output, None)
Expand Down
4 changes: 3 additions & 1 deletion tests/runner/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchtnt.runner._test_utils import DummyFitUnit, generate_random_dataloader
from torchtnt.runner.callback import Callback
from torchtnt.runner.fit import fit, init_fit_state
from torchtnt.runner.state import State
from torchtnt.runner.state import EntryPoint, State
from torchtnt.runner.unit import EvalUnit, TrainUnit


Expand Down Expand Up @@ -65,6 +65,7 @@ def test_fit_evaluate_every_n_epochs(self) -> None:
state.eval_state.progress.num_steps_completed,
max_epochs * expected_eval_steps_per_epoch,
)
self.assertEqual(state.entry_point, EntryPoint.FIT)

# step_output should be reset to None
self.assertEqual(state.eval_state.step_output, None)
Expand Down Expand Up @@ -123,6 +124,7 @@ def test_fit_evaluate_every_n_steps(self) -> None:
state.eval_state.progress.num_steps_completed,
expected_num_evaluate_calls * expected_eval_steps_per_epoch,
)
self.assertEqual(state.entry_point, EntryPoint.FIT)

# step_output should be reset to None
self.assertEqual(state.eval_state.step_output, None)
Expand Down
3 changes: 2 additions & 1 deletion tests/runner/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchtnt.runner._test_utils import DummyPredictUnit, generate_random_dataloader

from torchtnt.runner.predict import init_predict_state, predict
from torchtnt.runner.state import State
from torchtnt.runner.state import EntryPoint, State
from torchtnt.runner.unit import PredictUnit


Expand All @@ -41,6 +41,7 @@ def test_predict(self) -> None:
self.assertEqual(
state.predict_state.progress.num_steps_completed, expected_steps
)
self.assertEqual(state.entry_point, EntryPoint.PREDICT)

# step_output should be reset to None
self.assertEqual(state.predict_state.step_output, None)
Expand Down
5 changes: 4 additions & 1 deletion tests/runner/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torch import nn

from torchtnt.runner._test_utils import DummyTrainUnit, generate_random_dataloader
from torchtnt.runner.state import State
from torchtnt.runner.state import EntryPoint, State
from torchtnt.runner.train import init_train_state, train, train_epoch
from torchtnt.runner.unit import TrainUnit

Expand Down Expand Up @@ -47,6 +47,7 @@ def test_train(self) -> None:
self.assertEqual(state.train_state.step_output, None)

self.assertEqual(my_unit.module.training, initial_training_mode)
self.assertEqual(state.entry_point, EntryPoint.TRAIN)

def test_train_max_steps_per_epoch(self) -> None:
"""
Expand Down Expand Up @@ -77,6 +78,7 @@ def test_train_max_steps_per_epoch(self) -> None:
state.train_state.progress.num_steps_completed,
max_epochs * max_steps_per_epoch,
)
self.assertEqual(state.entry_point, EntryPoint.TRAIN)

# step_output should be reset to None
self.assertEqual(state.train_state.step_output, None)
Expand Down Expand Up @@ -107,6 +109,7 @@ def test_train_epoch(self) -> None:
state.train_state.progress.num_steps_completed,
expected_steps_per_epoch,
)
self.assertEqual(state.entry_point, EntryPoint.TRAIN)

# step_output should be reset to None
self.assertEqual(state.train_state.step_output, None)
Expand Down
5 changes: 4 additions & 1 deletion torchtnt/runner/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback

from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TEvalData, TEvalUnit
from torchtnt.runner.utils import (
_is_epoch_done,
Expand Down Expand Up @@ -67,6 +67,7 @@ def evaluate(
log_api_usage("evaluate")
callbacks = callbacks or []
try:
state._entry_point = EntryPoint.EVALUATE
_evaluate_impl(state, eval_unit, callbacks)
logger.info("Finished evaluation")
logger.debug(get_timer_summary(state.timer))
Expand All @@ -88,6 +89,8 @@ def _evaluate_impl(
eval_state = state.eval_state
if not eval_state:
raise RuntimeError("Expected eval_state to be initialized!")

state._active_phase = ActivePhase.EVALUATE
logger.info(
f"Started evaluate with max_steps_per_epoch={eval_state.max_steps_per_epoch}"
)
Expand Down
1 change: 1 addition & 0 deletions torchtnt/runner/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def fit(
callbacks = callbacks or []

try:
state._entry_point = EntryPoint.FIT
_fit_impl(state, unit, callbacks)
logger.debug(get_timer_summary(state.timer))
except Exception as e:
Expand Down
5 changes: 4 additions & 1 deletion torchtnt/runner/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback

from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TPredictData, TPredictUnit
from torchtnt.runner.utils import (
_is_epoch_done,
Expand Down Expand Up @@ -67,6 +67,7 @@ def predict(
log_api_usage("predict")
callbacks = callbacks or []
try:
state._entry_point = EntryPoint.PREDICT
_predict_impl(state, predict_unit, callbacks)
logger.info("Finished predict")
logger.debug(get_timer_summary(state.timer))
Expand All @@ -88,6 +89,8 @@ def _predict_impl(
predict_state = state.predict_state
if not predict_state:
raise RuntimeError("Expected predict_state to be initialized!")

state._active_phase = ActivePhase.PREDICT
logger.info(
f"Started predict with max_steps_per_epoch={predict_state.max_steps_per_epoch}"
)
Expand Down
32 changes: 32 additions & 0 deletions torchtnt/runner/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,39 @@ def _check_loop_condition(name: str, val: Optional[int]) -> None:


class EntryPoint(Enum):
"""
Enum for the user-facing functions offered by the TorchTNT runner module.
- :py:func:`~torchtnt.runner.fit`
- :py:func:`~torchtnt.runner.train`
- :py:func:`~torchtnt.runner.evaluate`
- :py:func:`~torchtnt.runner.predict`
"""

FIT = auto()
TRAIN = auto()
EVALUATE = auto()
PREDICT = auto()


class ActivePhase(Enum):
"""Enum for the currently active phase.

This class complements :class:`EntryPoint` by specifying the active phase for each function.
More than one phase value can be set while a :class:`EntryPoint` is running:
- ``EntryPoint.FIT`` - ``ActivePhase.{TRAIN,EVALUATE}``
- ``EntryPoint.TRAIN`` - ``ActivePhase.TRAIN``
- ``EntryPoint.EVALUATE`` - ``ActivePhase.EVALUATE``
- ``EntryPoint.PREDICT`` - ``ActivePhase.PREDICT``

This can be used within hooks such as :meth:`~torchtnt.runner.unit._OnExceptionMixin.on_exception`
to determine within which of training, evaluation, or prediction the hook is being called.
"""

TRAIN = auto()
EVALUATE = auto()
PREDICT = auto()


class PhaseState:
"""State for each phase (train, eval, predict)"""

Expand Down Expand Up @@ -116,11 +143,16 @@ def __init__(
self._eval_state = eval_state
self._predict_state = predict_state
self._should_stop: bool = False
self._active_phase: ActivePhase = ActivePhase.TRAIN

@property
def entry_point(self) -> EntryPoint:
return self._entry_point

@property
def active_phase(self) -> ActivePhase:
return self._active_phase

@property
def timer(self) -> Timer:
return self._timer
Expand Down
6 changes: 5 additions & 1 deletion torchtnt/runner/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torchtnt.runner.callback import Callback
from torchtnt.runner.evaluate import _evaluate_impl
from torchtnt.runner.state import EntryPoint, PhaseState, State
from torchtnt.runner.state import ActivePhase, EntryPoint, PhaseState, State
from torchtnt.runner.unit import TTrainData, TTrainUnit
from torchtnt.runner.utils import (
_is_done,
Expand Down Expand Up @@ -76,6 +76,7 @@ def train(
log_api_usage("train")
callbacks = callbacks or []
try:
state._entry_point = EntryPoint.TRAIN
_train_impl(state, train_unit, callbacks)
logger.info("Finished train")
logger.debug(get_timer_summary(state.timer))
Expand All @@ -99,6 +100,7 @@ def _train_impl(
logger.info(
f"Started train with max_epochs={train_state.max_epochs}, max_steps={train_state.max_steps}, max_steps_per_epoch={train_state.max_steps_per_epoch}"
)
state._active_phase = ActivePhase.TRAIN

# Set all modules to train() mode
# access modules made available through _AppStateMixin
Expand Down Expand Up @@ -140,6 +142,7 @@ def train_epoch(
raise RuntimeError(
f"Expected state.train_state.max_epochs to be 1, but received {train_state.max_epochs}."
)
state._entry_point = EntryPoint.TRAIN
logger.info(
f"Started train_epoch with max_steps_per_epoch={train_state.max_steps_per_epoch}"
)
Expand All @@ -164,6 +167,7 @@ def _train_epoch_impl(
callbacks: List[Callback],
) -> None:
logger.info("Started train epoch")
state._active_phase = ActivePhase.TRAIN

# Set all modules to train() mode
# access modules made available through _AppStateMixin
Expand Down