From bc6489324b88995b583a9798d15da74f15d8f9f2 Mon Sep 17 00:00:00 2001 From: "Edward Wang (EcoF)" Date: Tue, 25 Oct 2022 12:31:35 -0700 Subject: [PATCH] update docs to reflect new init_state (#259) Summary: Please read through our [contribution guide](https://github.com/pytorch/tnt/blob/main/CONTRIBUTING.md) prior to creating your pull request. Pull Request resolved: https://github.com/pytorch/tnt/pull/259 Test Plan: Fixes #{issue number} Reviewed By: ananthsub Differential Revision: D40672500 Pulled By: edward-io fbshipit-source-id: a82d49b003057f8557de94674a8d33c3369abc44 --- docs/source/eval.rst | 7 ++++--- docs/source/fit.rst | 6 +++--- docs/source/predict.rst | 7 ++++--- docs/source/train.rst | 7 ++++--- torchtnt/runner/auto_unit.py | 16 ++++++++-------- torchtnt/runner/evaluate.py | 8 ++++---- torchtnt/runner/fit.py | 14 ++++++++------ torchtnt/runner/predict.py | 8 ++++---- torchtnt/runner/train.py | 21 ++++++++++----------- 9 files changed, 49 insertions(+), 45 deletions(-) diff --git a/docs/source/eval.rst b/docs/source/eval.rst index 8733ddd833..b99086d09a 100644 --- a/docs/source/eval.rst +++ b/docs/source/eval.rst @@ -35,10 +35,11 @@ Evaluate Entry Point To run your evaluation loop, call :py:func:`~torchtnt.runner.evaluate`. -The :py:func:`~torchtnt.runner.evaluate` entry point takes as arguments one EvalUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks -(described below), and several optional parameters to control run duration of the loop. +The :py:func:`~torchtnt.runner.evaluate` entry point takes a :class:`~torchtnt.runner.EvalUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks. -Below is an example of calling the :py:func:`~torchtnt.runner.evaluate` entry point with the ``EvalUnit`` created above. 
+The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_eval_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop. + +Below is an example of calling the :py:func:`~torchtnt.runner.evaluate` entry point with ``MyEvalUnit`` created above. .. code-block:: python diff --git a/docs/source/fit.rst b/docs/source/fit.rst index c0757fd7f3..f51d1998bd 100644 --- a/docs/source/fit.rst +++ b/docs/source/fit.rst @@ -53,9 +53,9 @@ Fit Entry Point To run your fit loop, call the fit loop entry point: :py:func:`~torchtnt.runner.fit`. -The `:py:func:`~torchtnt.runner.fit` entry point takes as arguments one TrainUnit/EvalUnit, one iterable containing your training data and one iterable containing your eval -data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks(described below), and several optional parameters to control -run duration of the loop. +The :py:func:`~torchtnt.runner.fit` entry point takes an object subclassing both :class:`~torchtnt.runner.TrainUnit` and :class:`~torchtnt.runner.EvalUnit`, a :class:`~torchtnt.runner.state.State` object, and an optional list of callbacks. + +The :class:`~torchtnt.runner.state.State` object can be initialized with :func:`~torchtnt.runner.init_fit_state`, which takes in a training dataloader and an evaluation dataloader (each can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop. Below is an example of calling the :py:func:`~torchtnt.runner.fit` entry point with the TrainUnit/EvalUnit created above. diff --git a/docs/source/predict.rst b/docs/source/predict.rst index 4eabf1f367..5e7b243c67 100644 --- a/docs/source/predict.rst +++ b/docs/source/predict.rst @@ -34,10 +34,11 @@ Predict Entry Point To run your prediction loop, call the prediction loop entry point: :py:func:`~torchtnt.runner.predict`. 
-The :py:func:`~torchtnt.runner.predict`. entry point takes as arguments one PredictUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks -(described below), and several optional parameters to control run duration of the loop. +The :py:func:`~torchtnt.runner.predict` entry point takes a :class:`~torchtnt.runner.PredictUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks. -Below is an example of calling the :py:func:`~torchtnt.runner.predict` entry point with the ``PredictUnit`` created above. +The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_predict_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop. + +Below is an example of calling the :py:func:`~torchtnt.runner.predict` entry point with ``MyPredictUnit`` created above. .. code-block:: python diff --git a/docs/source/train.rst b/docs/source/train.rst index 8f4de1951f..a0f3934f1e 100644 --- a/docs/source/train.rst +++ b/docs/source/train.rst @@ -47,10 +47,11 @@ Train Entry Point To run your training loop, call the training loop entry point: :py:func:`~torchtnt.runner.train`. -The :py:func:`~torchtnt.runner.train` entry point takes as arguments one TrainUnit, one iterable containing your data (can be *any* iterable, including PyTorch DataLoader, numpy, etc.), an optional list of callbacks -(described below), and several optional parameters to control run duration of the loop. +The :py:func:`~torchtnt.runner.train` entry point takes a :class:`~torchtnt.runner.TrainUnit` object, a :class:`~torchtnt.runner.State` object, and an optional list of callbacks. -Below is an example of calling the :py:func:`~torchtnt.runner.train` entry point with the ``TrainUnit`` created above. 
+The :class:`~torchtnt.runner.State` object can be initialized with :func:`~torchtnt.runner.init_train_state`, which takes in a dataloader (can be *any* iterable, including PyTorch DataLoader, numpy, etc.) and some parameters to control the run duration of the loop. + +Below is an example of calling the :py:func:`~torchtnt.runner.train` entry point with ``MyTrainUnit`` created above. .. code-block:: python diff --git a/torchtnt/runner/auto_unit.py b/torchtnt/runner/auto_unit.py index cae5285901..01a6096346 100644 --- a/torchtnt/runner/auto_unit.py +++ b/torchtnt/runner/auto_unit.py @@ -27,8 +27,8 @@ class AutoTrainUnit(TrainUnit[TTrainData], ABC): """ The AutoTrainUnit is a convenience for users who are training with stochastic gradient descent and would like to have model optimization handled for them. The AutoTrainUnit subclasses TrainUnit, and runs the train_step for the user, specifically: forward pass, loss computation, - backward pass, and optimizer step. To benefit from the AutoTrainUnit, the user must subclass it and implement the `compute_loss` method, and - optionally the `update_metrics` and `log_metrics` methods. Then use with the `train` or `fit` entry point as normal. + backward pass, and optimizer step. To benefit from the AutoTrainUnit, the user must subclass it and implement the ``compute_loss`` method, and + optionally the ``update_metrics`` and ``log_metrics`` methods. Then use with the ``train`` or ``fit`` entry point as normal. For more advanced customization, the basic TrainUnit interface may be a better fit. @@ -120,8 +120,8 @@ def compute_loss(self, state: State, data: TTrainData) -> Tuple[torch.Tensor, An The user should implement this method with their loss computation. This will be called every `train_step`. 
Args: - state: a State object which is passed from the `train_step` - data: a batch of data which is passed from the `train_step` + state: a State object which is passed from the ``train_step`` + data: a batch of data which is passed from the ``train_step`` Returns: Tuple containing the loss and the output of the model @@ -135,8 +135,8 @@ def update_metrics( The user should implement this method with code to update metrics. This will be called every `train_step`. Args: - state: a State object which is passed from the `train_step` - data: a batch of data which is passed from the `train_step` + state: a State object which is passed from the ``train_step`` + data: a batch of data which is passed from the ``train_step`` outputs: the outputs of the model forward pass """ pass @@ -149,9 +149,9 @@ def log_metrics( and how many parameter updates have been run on the model. Args: - state: a State object which is passed from the `train_step` + state: a State object which is passed from the ``train_step`` step: how many steps have been completed (i.e. how many parameter updates have been run on the model) - interval: whether `log_metrics` is called at the end of a step or at the end of an epoch + interval: whether ``log_metrics`` is called at the end of a step or at the end of an epoch """ pass diff --git a/torchtnt/runner/evaluate.py b/torchtnt/runner/evaluate.py index a7b760b0c2..e0296f92d0 100644 --- a/torchtnt/runner/evaluate.py +++ b/torchtnt/runner/evaluate.py @@ -31,7 +31,7 @@ def init_eval_state( max_steps_per_epoch: Optional[int] = None, ) -> State: """ - Helper function that initializes a state object for evaluation. + Helper function that initializes a :class:`~torchtnt.runner.state.State` object for evaluation. Args: dataloader: dataloader to be used during evaluation. @@ -57,11 +57,11 @@ def evaluate( callbacks: Optional[List[Callback]] = None, ) -> None: """ - The `evaluate` entry point takes in a State and EvalUnit and runs the evaluation loop over the data. 
+ The ``evaluate`` entry point takes in a :class:`~torchtnt.runner.State` and :class:`~torchtnt.runner.unit.EvalUnit` and runs the evaluation loop over the data. Args: - state: a State object containing metadata about the evaluation run. - eval_unit: an instance of EvalUnit which implements `eval_step`. + state: a :class:`~torchtnt.runner.State` object containing metadata about the evaluation run. + eval_unit: an instance of :class:`~torchtnt.runner.EvalUnit` which implements `eval_step`. callbacks: an optional list of callbacks. """ log_api_usage("evaluate") diff --git a/torchtnt/runner/fit.py b/torchtnt/runner/fit.py index d584d9f019..3ad66ae91a 100644 --- a/torchtnt/runner/fit.py +++ b/torchtnt/runner/fit.py @@ -29,13 +29,13 @@ def init_fit_state( evaluate_every_n_epochs: Optional[int] = 1, ) -> State: """ - Helper function that initializes a state object for fitting. + Helper function that initializes a :class:`~torchtnt.runner.State` object for fitting. Args: train_dataloader: dataloader to be used during training. eval_dataloader: dataloader to be used during evaluation. - max_epochs: the max number of epochs to run for training. `None` means no limit (infinite training) unless stopped by max_steps. - max_steps: the max number of steps to run for training. `None` means no limit (infinite training) unless stopped by max_epochs. + max_epochs: the max number of epochs to run for training. ``None`` means no limit (infinite training) unless stopped by max_steps. + max_steps: the max number of steps to run for training. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_train_steps_per_epoch: the max number of steps to run per epoch for training. None means train until the dataloader is exhausted. evaluate_every_n_steps: how often to run the evaluation loop in terms of training steps. evaluate_every_n_epochs: how often to run the evaluation loop in terms of training epochs. 
@@ -65,11 +65,13 @@ def fit( state: State, unit: TTrainUnit, *, callbacks: Optional[List[Callback]] = None ) -> None: """ - The `fit` entry point interleaves the training and evaluation loops, taking in an instance of TrainUnit/EvalUnit as well as train and eval dataloaders. + The ``fit`` entry point interleaves training and evaluation loops. Args: - state: a State object containing metadata about the fitting run. - unit: an instance of both TrainUnit EvalUnit which implements both `train_step` and `eval_step`. + state: a :class:`~torchtnt.runner.State` object containing metadata about the fitting run. + :func:`~torchtnt.runner.init_fit_state` can be used to initialize a state object. + unit: an instance that subclasses both :class:`~torchtnt.runner.unit.TrainUnit` and :class:`~torchtnt.runner.unit.EvalUnit`, + implementing :meth:`~torchtnt.runner.TrainUnit.train_step` and :meth:`~torchtnt.runner.EvalUnit.eval_step`. callbacks: an optional list of callbacks. """ log_api_usage("fit") diff --git a/torchtnt/runner/predict.py b/torchtnt/runner/predict.py index 6ad2ad44c5..83f07e6beb 100644 --- a/torchtnt/runner/predict.py +++ b/torchtnt/runner/predict.py @@ -31,7 +31,7 @@ def init_predict_state( max_steps_per_epoch: Optional[int] = None, ) -> State: """ - Helper function that initializes a state object for prediction. + Helper function that initializes a :class:`~torchtnt.runner.State` object for prediction. Args: dataloader: dataloader to be used during prediction. @@ -57,11 +57,11 @@ def predict( callbacks: Optional[List[Callback]] = None, ) -> None: """ - The `predict` entry point takes in a State and PredictUnit and runs the prediction loop over the data. + The ``predict`` entry point takes in a :class:`~torchtnt.runner.State` and :class:`~torchtnt.runner.PredictUnit` and runs the prediction loop over the data. Args: - state: a State object containing metadata about the prediction run. - predict_unit: an instance of PredictUnit which implements `predict_step`. 
+ state: a State object containing metadata about the prediction run. This can be initialized using :func:`~torchtnt.runner.init_predict_state`. + predict_unit: an instance of :class:`~torchtnt.runner.PredictUnit` which implements `predict_step`. callbacks: an optional list of callbacks. """ log_api_usage("predict") diff --git a/torchtnt/runner/train.py b/torchtnt/runner/train.py index 35808e4c9d..ccb3064c71 100644 --- a/torchtnt/runner/train.py +++ b/torchtnt/runner/train.py @@ -35,12 +35,12 @@ def init_train_state( max_steps_per_epoch: Optional[int] = None, ) -> State: """ - Helper function that initializes a state object for training. + Helper function that initializes a :class:`~torchtnt.runner.State` object for training. Args: dataloader: dataloader to be used during training. - max_epochs: the max number of epochs to run. `None` means no limit (infinite training) unless stopped by max_steps. - max_steps: the max number of steps to run. `None` means no limit (infinite training) unless stopped by max_epochs. + max_epochs: the max number of epochs to run. ``None`` means no limit (infinite training) unless stopped by max_steps. + max_steps: the max number of steps to run. ``None`` means no limit (infinite training) unless stopped by max_epochs. max_steps_per_epoch: the max number of steps to run per epoch. None means train until the dataloader is exhausted. Returns: @@ -66,11 +66,11 @@ def train( callbacks: Optional[List[Callback]] = None, ) -> None: """ - The `train` entry point takes in a State and a TrainUnit and runs the training loop. - state / train_unit / callbacks are expected to introduce side effects + The ``train`` entry point takes in a :class:`~torchtnt.runner.State` object and a :class:`~torchtnt.runner.TrainUnit` object and runs the training loop. + Args: - state: a State object containing metadata about the training run. - train_unit: an instance of TrainUnit which implements `train_step`. 
+ state: a :class:`~torchtnt.runner.State` object containing metadata about the training run. + train_unit: an instance of :class:`~torchtnt.runner.TrainUnit` which implements `train_step`. callbacks: an optional list of callbacks. """ log_api_usage("train") @@ -137,12 +137,11 @@ def train_epoch( The `train_epoch` entry point takes in a State and a TrainUnit and runs one epoch (one pass through the dataloader). This entry point can be used for interleaving training with another entry point (evaluate or predict). - Note: this does not call the `on_train_start` or `on_train_end` methods on the unit or callbacks. + Note: this does not call the ``on_train_start`` or ``on_train_end`` methods on the unit or callbacks. - state / train_unit / callbacks are expected to introduce side effects. Args: - state: a State object containing metadata about the training run. - train_unit: an instance of TrainUnit which implements `train_step`. + state: a :class:`~torchtnt.runner.State` object containing metadata about the training run. + train_unit: an instance of :class:`~torchtnt.runner.TrainUnit` which implements `train_step`. callbacks: an optional list of callbacks. """ callbacks = callbacks or []