2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -213,6 +213,7 @@ conversion utilities for the following models:
:maxdepth: 2
:caption: Main Classes

main_classes/callback
main_classes/configuration
main_classes/logging
main_classes/model
@@ -270,3 +271,4 @@ conversion utilities for the following models:
internal/modeling_utils
internal/pipelines_utils
internal/tokenization_utils
internal/trainer_utils
21 changes: 21 additions & 0 deletions docs/source/internal/trainer_utils.rst
@@ -0,0 +1,21 @@
Utilities for Trainer
-----------------------------------------------------------------------------------------------------------------------

This page lists all the utility functions used by :class:`~transformers.Trainer`.

Most of these are only useful if you are studying the code of the Trainer in the library.

Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.EvalPrediction

.. autofunction:: transformers.set_seed

.. autofunction:: transformers.torch_distributed_zero_first
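
As a quick illustration, here is a sketch of how :obj:`torch_distributed_zero_first` is typically used in distributed
training (``local_rank`` would come from the training arguments; ``load_and_cache_dataset`` is a hypothetical helper):

.. code-block:: python

    from transformers import torch_distributed_zero_first

    # Let the process with local_rank 0 build the dataset cache first; the
    # other distributed processes wait here, then read from the warm cache.
    with torch_distributed_zero_first(local_rank):
        dataset = load_and_cache_dataset()  # hypothetical helper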


Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.trainer_callback.CallbackHandler
68 changes: 68 additions & 0 deletions docs/source/main_classes/callback.rst
@@ -0,0 +1,68 @@
Callbacks
-----------------------------------------------------------------------------------------------------------------------

Callbacks are objects that can customize the behavior of the training loop in the PyTorch
:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow). They can inspect the training loop
state (for progress reporting, logging on TensorBoard or other ML platforms...) and make decisions (like early
stopping).

Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
cannot change anything in the training loop. For customizations that require changes in the training loop, you should
subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).

By default a :class:`~transformers.Trainer` will use the following callbacks:

- :class:`~transformers.DefaultFlowCallback` which handles the default behavior for logging, saving and evaluation.
- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProgressCallback` to display progress and print the
  logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
  it's the second one).
- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
or tensorboardX).
- :class:`~transformers.integrations.WandbCallback` if `wandb <https://www.wandb.com/>`__ is installed.
- :class:`~transformers.integrations.CometCallback` if `comet_ml <https://www.comet.ml/site/>`__ is installed.

The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
:class:`~transformers.TrainerControl`.
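
For example, here is a minimal sketch of a custom callback; the event method and its arguments follow the
:class:`~transformers.TrainerCallback` API documented below, but the class itself is hypothetical:

.. code-block:: python

    from transformers import TrainerCallback

    class LogPrinterCallback(TrainerCallback):
        "A hypothetical callback that prints the logs of the main process."

        def on_log(self, args, state, control, logs=None, **kwargs):
            if state.is_local_process_zero:
                print(logs)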


Available Callbacks
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here is the list of the available :class:`~transformers.TrainerCallback` in the library:

.. autoclass:: transformers.integrations.CometCallback
:members: setup

.. autoclass:: transformers.DefaultFlowCallback

.. autoclass:: transformers.PrinterCallback

.. autoclass:: transformers.ProgressCallback

.. autoclass:: transformers.integrations.TensorBoardCallback

.. autoclass:: transformers.integrations.WandbCallback
:members: setup


TrainerCallback
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainerCallback
:members:


TrainerState
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainerState
:members:


TrainerControl
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainerControl
:members:
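
To make the control flags concrete, here is a sketch of a hypothetical callback that uses
:class:`~transformers.TrainerControl` to interrupt training after a fixed number of optimization steps:

.. code-block:: python

    from transformers import TrainerCallback

    class StopAfterStepsCallback(TrainerCallback):
        "A hypothetical callback that stops training after ``max_steps`` steps."

        def __init__(self, max_steps=1000):
            self.max_steps = max_steps

        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step >= self.max_steps:
                control.should_training_stop = True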
17 changes: 7 additions & 10 deletions docs/source/main_classes/trainer.rst
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaluation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
- **setup_wandb** -- Sets up wandb (see `here <https://docs.wandb.com/huggingface>`__ for more information).
- **create_optimizer_and_scheduler** -- Sets up the optimizer and learning rate scheduler if they were not passed at
  init.
- **compute_loss** -- Computes the loss on a batch of training inputs.
@@ -40,36 +39,34 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a custom loss:
logits = outputs[0]
return my_custom_loss(logits, labels)

Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
:doc:`callbacks <callback>` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
other ML platforms...) and make decisions (like early stopping).
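
As a sketch, registering such a callback could look like this (assuming the PyTorch :class:`~transformers.Trainer`
accepts a ``callbacks`` argument at init; ``MyCallback`` is a hypothetical :class:`~transformers.TrainerCallback`
subclass):

.. code-block:: python

    from transformers import Trainer

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        callbacks=[MyCallback()],  # hypothetical TrainerCallback subclass
    )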


Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.Trainer
:members:


TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainer
:members:


TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TrainingArguments
:members:


TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFTrainingArguments
:members:

Utilities
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.EvalPrediction

.. autofunction:: transformers.set_seed

.. autofunction:: transformers.torch_distributed_zero_first
2 changes: 1 addition & 1 deletion examples/seq2seq/seq2seq_trainer.py
@@ -9,7 +9,7 @@
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
from transformers.trainer import get_tpu_sampler
from transformers.trainer_pt_utils import get_tpu_sampler


try:
3 changes: 2 additions & 1 deletion examples/seq2seq/test_finetune_trainer.py
@@ -4,7 +4,8 @@
from unittest.mock import patch

from transformers.testing_utils import slow
from transformers.trainer_utils import TrainerState, set_seed
from transformers.trainer_callback import TrainerState
from transformers.trainer_utils import set_seed

from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
13 changes: 11 additions & 2 deletions src/transformers/__init__.py
@@ -205,7 +205,15 @@
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer

# Trainer
from .trainer_utils import EvalPrediction, TrainerState, set_seed
from .trainer_callback import (
DefaultFlowCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,
TrainerControl,
TrainerState,
)
from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -528,7 +536,8 @@
from .tokenization_marian import MarianTokenizer

# Trainer
from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
else:
from .utils.dummy_pt_objects import *
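
The net effect of this reshuffle: ``TrainerState`` now lives in ``trainer_callback`` and
``torch_distributed_zero_first`` in ``trainer_pt_utils``, while the top-level names stay importable. A sketch of
imports that should resolve after this change (assuming a version of transformers that contains it):

    # Top-level names re-exported by transformers/__init__.py
    from transformers import (
        EvalPrediction,
        Trainer,
        TrainerCallback,
        TrainerControl,
        TrainerState,
        set_seed,
    )

    # New home of the moved utility
    from transformers.trainer_pt_utils import torch_distributed_zero_first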
