update docs to reflect new init_state (#259)
Summary:
Please read through our [contribution guide](https://github.com/pytorch/tnt/blob/main/CONTRIBUTING.md) prior to creating your pull request.

Pull Request resolved: #259

Test Plan: Fixes #{issue number}

Reviewed By: ananthsub

Differential Revision: D40672500

Pulled By: edward-io

fbshipit-source-id: a82d49b003057f8557de94674a8d33c3369abc44
edward-io authored and facebook-github-bot committed Oct 25, 2022
1 parent 939260e commit bc64893
Showing 9 changed files with 49 additions and 45 deletions.
7 changes: 4 additions & 3 deletions docs/source/eval.rst
@@ -35,10 +35,11 @@ Evaluate Entry Point

To run your evaluation loop, call :py:func:`~torchtnt.runner.evaluate`.

The :py:func:`~torchtnt.runner.evaluate` entry point takes as arguments one EvalUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks
(described below), and several optional parameters to control run duration of the loop.
The :py:func:`~torchtnt.runner.evaluate` entry point takes a :class:`~torchtnt.runner.EvalUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks.

Below is an example of calling the :py:func:`~torchtnt.runner.evaluate` entry point with the ``EvalUnit`` created above.
The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_eval_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop.

Below is an example of calling the :py:func:`~torchtnt.runner.evaluate` entry point with ``MyEvalUnit`` created above.

.. code-block:: python
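The example itself is collapsed in this diff view. For reference, a minimal sketch of the new calling pattern might look like the following; the MyEvalUnit class, its eval_step signature, and the in-memory dataloader are illustrative assumptions rather than content from this commit, and the imports assume these names are re-exported from torchtnt.runner as the cross-references above suggest.

    from typing import Tuple

    import torch
    from torchtnt.runner import EvalUnit, State, evaluate, init_eval_state

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyEvalUnit(EvalUnit[Batch]):
        # Hypothetical EvalUnit, standing in for the one defined earlier on the docs page.
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module

        def eval_step(self, state: State, data: Batch) -> torch.Tensor:
            # Assumed signature (self, state, data); returns the model outputs.
            inputs, targets = data
            return self.module(inputs)

    # Any iterable of batches works as the dataloader; a small in-memory list is used here.
    eval_dataloader = [(torch.rand(8, 10), torch.randint(2, (8,))) for _ in range(4)]

    # New pattern in this commit: build a State with init_eval_state, then pass it to evaluate().
    state = init_eval_state(dataloader=eval_dataloader, max_steps_per_epoch=None)
    evaluate(state, MyEvalUnit(module=torch.nn.Linear(10, 2)))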
6 changes: 3 additions & 3 deletions docs/source/fit.rst
@@ -53,9 +53,9 @@ Fit Entry Point

To run your fit loop, call the fit loop entry point: :py:func:`~torchtnt.runner.fit`.

The `:py:func:`~torchtnt.runner.fit` entry point takes as arguments one TrainUnit/EvalUnit, one iterable containing your training data and one iterable containing your eval
data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks(described below), and several optional parameters to control
run duration of the loop.
The :py:func:`~torchtnt.runner.fit` entry point takes an object subclassing both :class:`~torchtnt.runner.TrainUnit` and :class:`~torchtnt.runner.EvalUnit`, a :class:`~torchtnt.runner.state.State` object, and an optional list of callbacks.

The :class:`~torchtnt.runner.state.State` object can be initialized with :func:`~torchtnt.runner.init_fit_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop.

Below is an example of calling the :py:func:`~torchtnt.runner.fit` entry point with the TrainUnit/EvalUnit created above.

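As with the other entry points, the example block is collapsed here. Below is a sketch of the new State-based fit call, under the assumption that the step-method signatures are (self, state, data) and that the names used are importable from torchtnt.runner; the MyFitUnit class is illustrative, not taken from the diff.

    from typing import Tuple

    import torch
    from torchtnt.runner import EvalUnit, State, TrainUnit, fit, init_fit_state

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyFitUnit(TrainUnit[Batch], EvalUnit[Batch]):
        # Hypothetical unit subclassing both TrainUnit and EvalUnit, as fit() requires.
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)
            self.loss_fn = torch.nn.CrossEntropyLoss()

        def train_step(self, state: State, data: Batch) -> torch.Tensor:
            inputs, targets = data
            loss = self.loss_fn(self.module(inputs), targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return loss

        def eval_step(self, state: State, data: Batch) -> torch.Tensor:
            inputs, targets = data
            return self.loss_fn(self.module(inputs), targets)

    train_dataloader = [(torch.rand(8, 10), torch.randint(2, (8,))) for _ in range(8)]
    eval_dataloader = [(torch.rand(8, 10), torch.randint(2, (8,))) for _ in range(2)]

    # Keyword arguments mirror the init_fit_state signature shown in this diff.
    state = init_fit_state(
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
        max_epochs=2,
        evaluate_every_n_epochs=1,
    )
    fit(state, MyFitUnit(module=torch.nn.Linear(10, 2)))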
7 changes: 4 additions & 3 deletions docs/source/predict.rst
@@ -34,10 +34,11 @@ Predict Entry Point

To run your prediction loop, call the prediction loop entry point: :py:func:`~torchtnt.runner.predict`.

The :py:func:`~torchtnt.runner.predict`. entry point takes as arguments one PredictUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks
(described below), and several optional parameters to control run duration of the loop.
The :py:func:`~torchtnt.runner.predict` entry point takes a :class:`~torchtnt.runner.PredictUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks.

Below is an example of calling the :py:func:`~torchtnt.runner.predict` entry point with the ``PredictUnit`` created above.
The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_predict_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop.

Below is an example of calling the :py:func:`~torchtnt.runner.predict` entry point with ``MyPredictUnit`` created above.

.. code-block:: python
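Again, the example block is collapsed in the diff. A minimal sketch of the new pattern follows; MyPredictUnit and its predict_step signature are assumptions standing in for the class the docs define earlier on the page.

    from typing import List

    import torch
    from torchtnt.runner import PredictUnit, State, init_predict_state, predict

    class MyPredictUnit(PredictUnit[torch.Tensor]):
        # Hypothetical PredictUnit; assumed signature predict_step(self, state, data).
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.outputs: List[torch.Tensor] = []

        def predict_step(self, state: State, data: torch.Tensor) -> torch.Tensor:
            preds = self.module(data)
            self.outputs.append(preds)  # collect predictions for later use
            return preds

    predict_dataloader = [torch.rand(8, 10) for _ in range(4)]

    # Build the State with init_predict_state, then pass it to predict().
    state = init_predict_state(dataloader=predict_dataloader)
    predict(state, MyPredictUnit(module=torch.nn.Linear(10, 2)))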
7 changes: 4 additions & 3 deletions docs/source/train.rst
@@ -47,10 +47,11 @@ Train Entry Point

To run your training loop, call the training loop entry point: :py:func:`~torchtnt.runner.train`.

The :py:func:`~torchtnt.runner.train` entry point takes as arguments one TrainUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks
(described below), and several optional parameters to control run duration of the loop.
The :py:func:`~torchtnt.runner.train` entry point takes a :class:`~torchtnt.runner.TrainUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks.

Below is an example of calling the :py:func:`~torchtnt.runner.train` entry point with the ``TrainUnit`` created above.
The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_train_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop.

Below is an example of calling the :py:func:`~torchtnt.runner.train` entry point with ``MyTrainUnit`` created above.

.. code-block:: python
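The training example is likewise collapsed. The sketch below shows init_train_state with the duration parameters listed in this diff (max_epochs, max_steps, max_steps_per_epoch); MyTrainUnit and its train_step signature are assumptions standing in for the unit defined earlier on the docs page.

    from typing import Tuple

    import torch
    from torchtnt.runner import State, TrainUnit, init_train_state, train

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyTrainUnit(TrainUnit[Batch]):
        # Hypothetical TrainUnit with a basic SGD step.
        def __init__(self, module: torch.nn.Module) -> None:
            super().__init__()
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)

        def train_step(self, state: State, data: Batch) -> torch.Tensor:
            inputs, targets = data
            loss = torch.nn.functional.cross_entropy(self.module(inputs), targets)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            return loss

    train_dataloader = [(torch.rand(8, 10), torch.randint(2, (8,))) for _ in range(8)]

    # max_epochs / max_steps / max_steps_per_epoch control run duration, as described above.
    state = init_train_state(dataloader=train_dataloader, max_epochs=2, max_steps_per_epoch=4)
    train(state, MyTrainUnit(module=torch.nn.Linear(10, 2)))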
16 changes: 8 additions & 8 deletions torchtnt/runner/auto_unit.py
@@ -27,8 +27,8 @@ class AutoTrainUnit(TrainUnit[TTrainData], ABC):
"""
The AutoTrainUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization
handled for them. The AutoTrainUnit subclasses TrainUnit, and runs the train_step for the user, specifically: forward pass, loss computation,
backward pass, and optimizer step. To benefit from the AutoTrainUnit, the user must subclass it and implement the `compute_loss` method, and
optionally the `update_metrics` and `log_metrics` methods. Then use with the `train` or `fit` entry point as normal.
backward pass, and optimizer step. To benefit from the AutoTrainUnit, the user must subclass it and implement the ``compute_loss`` method, and
optionally the ``update_metrics`` and ``log_metrics`` methods. Then use with the ``train`` or ``fit`` entry point as normal.
For more advanced customization, the basic TrainUnit interface may be a better fit.
@@ -120,8 +120,8 @@ def compute_loss(self, state: State, data: TTrainData) -> Tuple[torch.Tensor, An
The user should implement this method with their loss computation. This will be called every `train_step`.
Args:
state: a State object which is passed from the `train_step`
data: a batch of data which is passed from the `train_step`
state: a State object which is passed from the ``train_step``
data: a batch of data which is passed from the ``train_step``
Returns:
Tuple containing the loss and the output of the model
@@ -135,8 +135,8 @@ def update_metrics(
The user should implement this method with code to update metrics. This will be called every `train_step`.
Args:
state: a State object which is passed from the `train_step`
data: a batch of data which is passed from the `train_step`
state: a State object which is passed from the ``train_step``
data: a batch of data which is passed from the ``train_step``
outputs: the outputs of the model forward pass
"""
pass
@@ -149,9 +149,9 @@ def log_metrics(
and how many parameter updates have been run on the model.
Args:
state: a State object which is passed from the `train_step`
state: a State object which is passed from the ``train_step``
step: how many steps have been completed (i.e. how many parameter updates have been run on the model)
interval: whether `log_metrics` is called at the end of a step or at the end of an epoch
interval: whether ``log_metrics`` is called at the end of a step or at the end of an epoch
"""
pass

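For readers following the AutoTrainUnit changes above, here is a brief sketch of the subclassing pattern the docstring describes. Only the compute_loss signature is grounded in this diff; the constructor arguments, the self.module attribute, and the generic parameter are assumptions for illustration.

    from typing import Any, Tuple

    import torch
    from torchtnt.runner import State, init_train_state, train
    from torchtnt.runner.auto_unit import AutoTrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]

    class MyAutoUnit(AutoTrainUnit[Batch]):
        # Only compute_loss is required; forward, backward, and optimizer step are run by AutoTrainUnit.
        def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
            inputs, targets = data
            outputs = self.module(inputs)  # assumes the wrapped module is exposed as self.module
            loss = torch.nn.functional.cross_entropy(outputs, targets)
            return loss, outputs

    module = torch.nn.Linear(10, 2)
    # Constructor arguments are illustrative; the AutoTrainUnit __init__ signature is not part of this diff.
    unit = MyAutoUnit(module=module, optimizer=torch.optim.SGD(module.parameters(), lr=0.01))

    dataloader = [(torch.rand(8, 10), torch.randint(2, (8,))) for _ in range(4)]
    train(init_train_state(dataloader=dataloader, max_epochs=1), unit)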
8 changes: 4 additions & 4 deletions torchtnt/runner/evaluate.py
@@ -31,7 +31,7 @@ def init_eval_state(
max_steps_per_epoch: Optional[int] = None,
) -> State:
"""
Helper function that initializes a state object for evaluation.
Helper function that initializes a :class:`~torchtnt.runner.state.State` object for evaluation.
Args:
dataloader: dataloader to be used during evaluation.
@@ -57,11 +57,11 @@ def evaluate(
callbacks: Optional[List[Callback]] = None,
) -> None:
"""
The `evaluate` entry point takes in a State and EvalUnit and runs the evaluation loop over the data.
The ``evaluate`` entry point takes in a :class:`~torchtnt.runner.State` and :class:`~torchtnt.runner.unit.EvalUnit` and runs the evaluation loop over the data.
Args:
state: a State object containing metadata about the evaluation run.
eval_unit: an instance of EvalUnit which implements `eval_step`.
state: a :class:`~torchtnt.runner.State` object containing metadata about the evaluation run.
eval_unit: an instance of :class:`~torchtnt.runner.EvalUnit` which implements `eval_step`.
callbacks: an optional list of callbacks.
"""
log_api_usage("evaluate")
14 changes: 8 additions & 6 deletions torchtnt/runner/fit.py
@@ -29,13 +29,13 @@ def init_fit_state(
evaluate_every_n_epochs: Optional[int] = 1,
) -> State:
"""
Helper function that initializes a state object for fitting.
Helper function that initializes a :class:`~torchtnt.runner.State` object for fitting.
Args:
train_dataloader: dataloader to be used during training.
eval_dataloader: dataloader to be used during evaluation.
max_epochs: the max number of epochs to run for training. `None` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run for training. `None` means no limit (infinite training) unless stopped by max_epochs.
max_epochs: the max number of epochs to run for training. ``None`` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run for training. ``None`` means no limit (infinite training) unless stopped by max_epochs.
max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted.
evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps.
evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs.
@@ -65,11 +65,13 @@ def fit(
state: State, unit: TTrainUnit, *, callbacks: Optional[List[Callback]] = None
) -> None:
"""
The `fit` entry point interleaves the training and evaluation loops, taking in an instance of TrainUnit/EvalUnit as well as train and eval dataloaders.
The ``fit`` entry point interleaves training and evaluation loops.
Args:
state: a State object containing metadata about the fitting run.
unit: an instance of both TrainUnit EvalUnit which implements both `train_step` and `eval_step`.
state: a :class:`~torchtnt.runner.State` object containing metadata about the fitting run.
:func:`~torchtnt.runner.init_fit_state` can be used to initialize a state object.
unit: an instance that subclasses both :class:`~torchtnt.runner.unit.TrainUnit` and :class:`~torchtnt.runner.unit.EvalUnit`,
implementing :meth:`~torchtnt.runner.TrainUnit.train_step` and :meth:`~torchtnt.runner.EvalUnit.eval_step`.
callbacks: an optional list of callbacks.
"""
log_api_usage("fit")
8 changes: 4 additions & 4 deletions torchtnt/runner/predict.py
@@ -31,7 +31,7 @@ def init_predict_state(
max_steps_per_epoch: Optional[int] = None,
) -> State:
"""
Helper function that initializes a state object for prediction.
Helper function that initializes a :class:`~torchtnt.runner.State` object for prediction.
Args:
dataloader: dataloader to be used during prediction.
@@ -57,11 +57,11 @@ def predict(
callbacks: Optional[List[Callback]] = None,
) -> None:
"""
The `predict` entry point takes in a State and PredictUnit and runs the prediction loop over the data.
The ``predict`` entry point takes in a :class:`~torchtnt.runner.State` and :class:`~torchtnt.runner.PredictUnit` and runs the prediction loop over the data.
Args:
state: a State object containing metadata about the prediction run.
predict_unit: an instance of PredictUnit which implements `predict_step`.
state: a State object containing metadata about the prediction run. This can be initialized using :func:`~torchtnt.runner.init_predict_state`.
predict_unit: an instance of :class:`~torchtnt.runner.PredictUnit` which implements `predict_step`.
callbacks: an optional list of callbacks.
"""
log_api_usage("predict")
21 changes: 10 additions & 11 deletions torchtnt/runner/train.py
@@ -35,12 +35,12 @@ def init_train_state(
max_steps_per_epoch: Optional[int] = None,
) -> State:
"""
Helper function that initializes a state object for training.
Helper function that initializes a :class:`~torchtnt.runner.State` object for training.
Args:
dataloader: dataloader to be used during training.
max_epochs: the max number of epochs to run. `None` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run. `None` means no limit (infinite training) unless stopped by max_epochs.
max_epochs: the max number of epochs to run. ``None`` means no limit (infinite training) unless stopped by max_steps.
max_steps: the max number of steps to run. ``None`` means no limit (infinite training) unless stopped by max_epochs.
max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted.
Returns:
@@ -66,11 +66,11 @@ def train(
callbacks: Optional[List[Callback]] = None,
) -> None:
"""
The `train` entry point takes in a State and a TrainUnit and runs the training loop.
state / train_unit / callbacks are expected to introduce side effects
The ``train`` entry point takes in a :class:`~torchtnt.runner.State` object and a :class:`~torchtnt.runner.TrainUnit` object and runs the training loop.
Args:
state: a State object containing metadata about the training run.
train_unit: an instance of TrainUnit which implements `train_step`.
state: a :class:`~torchtnt.runner.State` object containing metadata about the training run.
train_unit: an instance of :class:`~torchtnt.runner.TrainUnit` which implements `train_step`.
callbacks: an optional list of callbacks.
"""
log_api_usage("train")
@@ -137,12 +137,11 @@ def train_epoch(
The `train_epoch` entry point takes in a State and a TrainUnit and runs one epoch (one pass through the dataloader).
This entry point can be used for interleaving training with another entry point (evaluate or predict).
Note: this does not call the `on_train_start` or `on_train_end` methods on the unit or callbacks.
Note: this does not call the ``on_train_start`` or ``on_train_end`` methods on the unit or callbacks.
state / train_unit / callbacks are expected to introduce side effects.
Args:
state: a State object containing metadata about the training run.
train_unit: an instance of TrainUnit which implements `train_step`.
state: a :class:`~torchtnt.runner.State` object containing metadata about the training run.
train_unit: an instance of :class:`~torchtnt.runner.TrainUnit` which implements `train_step`.
callbacks: an optional list of callbacks.
"""
callbacks = callbacks or []
