Skip to content

Commit

Permalink
add back timer to state
Browse files Browse the repository at this point in the history
Differential Revision: D43852641

fbshipit-source-id: 6576246b12a9f8255c851b96509652c6dff9c9af
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Mar 6, 2023
1 parent 5ecac88 commit dc625fb
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions torchtnt/framework/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Iterable, Optional

from torchtnt.framework.progress import Progress
from torchtnt.utils import Timer

_logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -147,11 +148,13 @@ def __init__(
self,
*,
entry_point: EntryPoint,
timer: Optional[Timer] = None,
train_state: Optional[PhaseState] = None,
eval_state: Optional[PhaseState] = None,
predict_state: Optional[PhaseState] = None,
) -> None:
self._entry_point = entry_point
self._timer: Timer = timer or Timer()
self._train_state = train_state
self._eval_state = eval_state
self._predict_state = predict_state
Expand All @@ -168,6 +171,11 @@ def active_phase(self) -> ActivePhase:
"""Current active phase of the loop. (One of TRAIN, EVALUATE, PREDICT)."""
return self._active_phase

@property
def timer(self) -> Timer:
"""A :class:`~torchtnt.framework.Timer` object which records latencies of key events during loop execution."""
return self._timer

@property
def train_state(self) -> Optional[PhaseState]:
"""A :class:`~torchtnt.framework.PhaseState` object which contains meta information about the train phase."""
Expand Down

0 comments on commit dc625fb

Please sign in to comment.