diff --git a/docs/source/index.rst b/docs/source/index.rst
index 0545d46240bc..400c40827568 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -213,6 +213,7 @@ conversion utilities for the following models:
:maxdepth: 2
:caption: Main Classes
+ main_classes/callback
main_classes/configuration
main_classes/logging
main_classes/model
@@ -270,3 +271,4 @@ conversion utilities for the following models:
internal/modeling_utils
internal/pipelines_utils
internal/tokenization_utils
+ internal/trainer_utils
diff --git a/docs/source/internal/trainer_utils.rst b/docs/source/internal/trainer_utils.rst
new file mode 100644
index 000000000000..48e8568b9530
--- /dev/null
+++ b/docs/source/internal/trainer_utils.rst
@@ -0,0 +1,21 @@
+Utilities for Trainer
+-----------------------------------------------------------------------------------------------------------------------
+
+This page lists all the utility functions used by :class:`~transformers.Trainer`.
+
+Most of those are only useful if you are studying the code of the Trainer in the library.
+
+Utilities
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.EvalPrediction
+
+.. autofunction:: transformers.set_seed
+
+.. autofunction:: transformers.torch_distributed_zero_first
+
+
+Callbacks internals
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.trainer_callback.CallbackHandler
diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst
new file mode 100644
index 000000000000..1b31c5464521
--- /dev/null
+++ b/docs/source/main_classes/callback.rst
@@ -0,0 +1,68 @@
+Callbacks
+-----------------------------------------------------------------------------------------------------------------------
+
+Callbacks are objects that can customize the behavior of the training loop in the PyTorch
+:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop
+state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early
+stopping).
+
+Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they
+cannot change anything in the training loop. For customizations that require changes in the training loop, you should
+subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples).
+
+By default a :class:`~transformers.Trainer` will use the following callbacks:
+
+- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation.
+- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the
+ logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise
+ it's the second one).
+- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4
+ or tensorboardX).
+- :class:`~transformers.integrations.WandbCallback` if `wandb `__ is installed.
+- :class:`~transformers.integrations.CometCallback` if `comet_ml `__ is installed.
+
+The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the
+:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that
+Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via
+:class:`~transformers.TrainerControl`.
+
+
+Available Callbacks
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Here is the list of the available :class:`~transformers.TrainerCallback` in the library:
+
+.. autoclass:: transformers.integrations.CometCallback
+ :members: setup
+
+.. autoclass:: transformers.DefaultFlowCallback
+
+.. autoclass:: transformers.PrinterCallback
+
+.. autoclass:: transformers.ProgressCallback
+
+.. autoclass:: transformers.integrations.TensorBoardCallback
+
+.. autoclass:: transformers.integrations.WandbCallback
+ :members: setup
+
+
+TrainerCallback
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TrainerCallback
+ :members:
+
+
+TrainerState
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TrainerState
+ :members:
+
+
+TrainerControl
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.TrainerControl
+ :members:
diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst
index 76cd1f34d75c..07050d1707aa 100644
--- a/docs/source/main_classes/trainer.rst
+++ b/docs/source/main_classes/trainer.rst
@@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override
- **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset.
- **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset.
- **log** -- Logs information on the various objects watching training.
-- **setup_wandb** -- Setups wandb (see `here `__ for more information).
- **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at
init.
- **compute_loss** - Computes the loss on a batch of training inputs.
@@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu
logits = outputs[0]
return my_custom_loss(logits, labels)
+Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use
+:doc:`callbacks ` that can inspect the training loop state (for progress reporting, logging on TensorBoard or
+other ML platforms...) and take decisions (like early stopping).
+
Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -47,29 +50,23 @@ Trainer
.. autoclass:: transformers.Trainer
:members:
+
TFTrainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainer
:members:
+
TrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TrainingArguments
:members:
+
TFTrainingArguments
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFTrainingArguments
:members:
-
-Utilities
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: transformers.EvalPrediction
-
-.. autofunction:: transformers.set_seed
-
-.. autofunction:: transformers.torch_distributed_zero_first
diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index 0f585eb26218..293244df41d2 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -9,7 +9,7 @@
from transformers.configuration_fsmt import FSMTConfig
from transformers.file_utils import is_torch_tpu_available
from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup
-from transformers.trainer import get_tpu_sampler
+from transformers.trainer_pt_utils import get_tpu_sampler
try:
diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py
index 517e76b2327e..9002f0284a69 100644
--- a/examples/seq2seq/test_finetune_trainer.py
+++ b/examples/seq2seq/test_finetune_trainer.py
@@ -4,7 +4,8 @@
from unittest.mock import patch
from transformers.testing_utils import slow
-from transformers.trainer_utils import TrainerState, set_seed
+from transformers.trainer_callback import TrainerState
+from transformers.trainer_utils import set_seed
from .finetune_trainer import main
from .test_seq2seq_examples import MBART_TINY
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 9ab37830cfea..b8bb0f8a8eb9 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -205,7 +205,15 @@
from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
# Trainer
-from .trainer_utils import EvalPrediction, TrainerState, set_seed
+from .trainer_callback import (
+ DefaultFlowCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed
from .training_args import TrainingArguments
from .training_args_tf import TFTrainingArguments
from .utils import logging
@@ -528,7 +536,8 @@
from .tokenization_marian import MarianTokenizer
# Trainer
- from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first
+ from .trainer import Trainer
+ from .trainer_pt_utils import torch_distributed_zero_first
else:
from .utils.dummy_pt_objects import *
diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 15badb59cc01..9e0ee0cbb01b 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -2,6 +2,11 @@
import math
import os
+from .file_utils import is_torch_tpu_available
+from .trainer_callback import TrainerCallback
+from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun
+from .utils import logging
+
try:
import comet_ml # noqa: F401
@@ -36,15 +41,6 @@
except (ImportError):
_has_ray = False
-
-# No ML framework or transformer imports above this point
-
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip
-from .utils import logging # isort:skip
-
-logger = logging.get_logger(__name__)
-
-
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
@@ -57,9 +53,10 @@
except ImportError:
_has_tensorboard = False
-# Integration functions:
+logger = logging.get_logger(__name__)
+# Integration functions:
def is_wandb_available():
return _has_wandb
@@ -128,8 +125,8 @@ def _objective(trial, checkpoint_dir=None):
# The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
# while doing the ray hp search.
- _tb_writer = trainer.tb_writer
- trainer.tb_writer = None
+
+ _tb_writer = trainer.pop_callback(TensorBoardCallback)
trainer.model = None
# Setup default `resources_per_trial` and `reporter`.
if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
@@ -182,5 +179,159 @@ def _objective(trial, checkpoint_dir=None):
analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs)
best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
- trainer.tb_writer = _tb_writer
+ if _tb_writer is not None:
+ trainer.add_callback(_tb_writer)
return best_run
+
+
+class TensorBoardCallback(TrainerCallback):
+ """
+ A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
+ `__.
+
+ Args:
+ tb_writer (:obj:`SummaryWriter`, `optional`):
+ The writer to use. Will instatiate one if not set.
+ """
+
+ def __init__(self, tb_writer=None):
+ assert (
+ _has_tensorboard
+ ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
+ self.tb_writer = tb_writer
+
+ def on_init_end(self, args, state, control, **kwargs):
+ if self.tb_writer is None and state.is_world_process_zero:
+ self.tb_writer = SummaryWriter(log_dir=args.logging_dir)
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if self.tb_writer is not None:
+ self.tb_writer.add_text("args", args.to_json_string())
+ self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if self.tb_writer:
+ for k, v in logs.items():
+ if isinstance(v, (int, float)):
+ self.tb_writer.add_scalar(k, v, state.global_step)
+ else:
+ logger.warning(
+ "Trainer is attempting to log a value of "
+ '"%s" of type %s for key "%s" as a scalar. '
+ "This invocation of Tensorboard's writer.add_scalar() "
+ "is incorrect so we dropped this attribute.",
+ v,
+ type(v),
+ k,
+ )
+ self.tb_writer.flush()
+
+ def on_train_end(self, args, state, control, **kwargs):
+ if self.tb_writer:
+ self.tb_writer.close()
+
+
+class WandbCallback(TrainerCallback):
+ """
+ A :class:`~transformers.TrainerCallback` that sends the logs to `Weight and Biases
+ `__.
+ """
+
+ def __init__(self):
+ assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
+ self._initialized = False
+
+ def setup(self, args, state, model):
+ """
+ Setup the optional Weights & Biases (`wandb`) integration.
+
+ One can subclass and override this method to customize the setup if needed. Find more information
+ `here `__. You can also override the following environment variables:
+
+ Environment:
+ WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
+ Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
+ logging or :obj:`"all"` to log gradients and parameters.
+ WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
+ Set this to a custom string to store results in a different project.
+ WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not to disable wandb entirely.
+ """
+ self._initialized = True
+ if state.is_world_process_zero:
+ logger.info(
+ 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
+ )
+ combined_dict = {**args.to_sanitized_dict()}
+ if hasattr(model, "config"):
+ combined_dict = {**model.config.to_dict(), **combined_dict}
+ wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name)
+ # keep track of model topology and gradients, unsupported on TPU
+ if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
+ wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
+
+ def on_train_begin(self, args, state, control, model=None, **kwargs):
+ if not self._initialized:
+ self.setup(args, state, model)
+
+ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+ if not self._initialized:
+ self.setup(args, state, model)
+ if state.is_world_process_zero:
+ wandb.log(logs, step=state.global_step)
+
+
+class CometCallback(TrainerCallback):
+ """
+ A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML
+ `__.
+ """
+
+ def __init__(self):
+ assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
+ self._initialized = False
+
+ def setup(self, args, state, model):
+ """
+ Setup the optional Comet.ml integration.
+
+ Environment:
+ COMET_MODE (:obj:`str`, `optional`):
+ "OFFLINE", "ONLINE", or "DISABLED"
+ COMET_PROJECT_NAME (:obj:`str`, `optional`):
+ Comet.ml project name for experiments
+ COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
+ Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"
+
+ For a number of configurable items in the environment,
+ see `here `__.
+ """
+ self._initialized = True
+ if state.is_world_process_zero:
+ comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
+ args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
+ experiment = None
+ if comet_mode == "ONLINE":
+ experiment = comet_ml.Experiment(**args)
+ logger.info("Automatic Comet.ml online logging enabled")
+ elif comet_mode == "OFFLINE":
+ args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
+ experiment = comet_ml.OfflineExperiment(**args)
+ logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
+ if experiment is not None:
+ experiment._set_model_graph(model, framework="transformers")
+ experiment._log_parameters(args, prefix="args/", framework="transformers")
+ if hasattr(model, "config"):
+ experiment._log_parameters(model.config, prefix="config/", framework="transformers")
+
+ def on_train_begin(self, args, state, control, model=None, **kwargs):
+ if not self._initialized:
+ self.setup(args, state, model)
+
+ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+ if not self._initialized:
+ self.setup(args, state, model)
+ if state.is_world_process_zero:
+ experiment = comet_ml.config.get_global_experiment()
+ if experiment is not None:
+ experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 262ed6df59c1..b8e6e494b8f7 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1,10 +1,26 @@
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
+"""
+
import inspect
-import math
import os
import re
import shutil
import warnings
-from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -15,8 +31,7 @@
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler
-from tqdm.auto import tqdm, trange
+from torch.utils.data.sampler import RandomSampler, SequentialSampler
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available
@@ -34,23 +49,35 @@
from .modeling_utils import PreTrainedModel
from .optimization import AdamW, get_linear_schedule_with_warmup
from .tokenization_utils_base import PreTrainedTokenizerBase
+from .trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from .trainer_pt_utils import (
+ SequentialDistributedSampler,
+ distributed_broadcast_scalars,
+ distributed_concat,
+ get_tpu_sampler,
+ nested_concat,
+ nested_detach,
+ nested_numpify,
+ nested_xla_mesh_reduce,
+ reissue_pt_warnings,
+)
from .trainer_utils import (
PREFIX_CHECKPOINT_DIR,
BestRun,
EvalPrediction,
- EvaluationStrategy,
HPSearchBackend,
PredictionOutput,
- TrainerState,
TrainOutput,
default_compute_objective,
default_hp_space,
- distributed_broadcast_scalars,
- distributed_concat,
- nested_concat,
- nested_detach,
- nested_numpify,
- nested_xla_mesh_reduce,
set_seed,
)
from .training_args import TrainingArguments
@@ -60,7 +87,8 @@
_use_native_amp = False
_use_apex = False
-PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+
# Check if Pytorch version >= 1.6 to switch between Native AMP and Apex
if version.parse(torch.__version__) < version.parse("1.6"):
@@ -82,16 +110,20 @@
import torch_xla.distributed.parallel_loader as pl
if is_tensorboard_available():
- try:
- from torch.utils.tensorboard import SummaryWriter
- except ImportError:
- from tensorboardX import SummaryWriter
+ from .integrations import TensorBoardCallback
+
+ DEFAULT_CALLBACKS.append(TensorBoardCallback)
+
if is_wandb_available():
- import wandb
+ from .integrations import WandbCallback
+
+ DEFAULT_CALLBACKS.append(WandbCallback)
if is_comet_available():
- import comet_ml
+ from .integrations import CometCallback
+
+ DEFAULT_CALLBACKS.append(CometCallback)
if is_optuna_available():
import optuna
@@ -102,91 +134,20 @@
logger = logging.get_logger(__name__)
-def reissue_pt_warnings(caught_warnings):
- # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
- if len(caught_warnings) > 1:
- for w in caught_warnings:
- if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
- warnings.warn(w.message, w.category)
-
-
-@contextmanager
-def torch_distributed_zero_first(local_rank: int):
- """
- Decorator to make all processes in distributed training wait for each local_master to do something.
-
- Args:
- local_rank (:obj:`int`): The rank of the local process.
- """
- if local_rank not in [-1, 0]:
- torch.distributed.barrier()
- yield
- if local_rank == 0:
- torch.distributed.barrier()
-
-
-class SequentialDistributedSampler(Sampler):
- """
- Distributed Sampler that subsamples indicies sequentially,
- making it easier to collate all results at the end.
-
- Even though we only use this sampler for eval and predict (no training),
- which means that the model params won't have to be synced (i.e. will not hang
- for synchronization even if varied number of forward passes), we still add extra
- samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
- to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
- """
-
- def __init__(self, dataset, num_replicas=None, rank=None):
- if num_replicas is None:
- if not torch.distributed.is_available():
- raise RuntimeError("Requires distributed package to be available")
- num_replicas = torch.distributed.get_world_size()
- if rank is None:
- if not torch.distributed.is_available():
- raise RuntimeError("Requires distributed package to be available")
- rank = torch.distributed.get_rank()
- self.dataset = dataset
- self.num_replicas = num_replicas
- self.rank = rank
- self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
- self.total_size = self.num_samples * self.num_replicas
-
- def __iter__(self):
- indices = list(range(len(self.dataset)))
-
- # add extra samples to make it evenly divisible
- indices += indices[: (self.total_size - len(indices))]
- assert (
- len(indices) == self.total_size
- ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
-
- # subsample
- indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
- assert (
- len(indices) == self.num_samples
- ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
-
- return iter(indices)
-
- def __len__(self):
- return self.num_samples
-
-
-def get_tpu_sampler(dataset: Dataset):
- if xm.xrt_world_size() <= 1:
- return RandomSampler(dataset)
- return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
-
-
class Trainer:
"""
Trainer is a simple but feature-complete training and eval loop for PyTorch,
optimized for 🤗 Transformers.
Args:
- model (:class:`~transformers.PreTrainedModel`, `optional`):
+ model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`):
The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed.
+
+ .. note::
+
+ :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel`
+ provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as
+ they work the same way as the 🤗 Transformers models.
args (:class:`~transformers.TrainingArguments`, `optional`):
The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments`
with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided.
@@ -210,8 +171,11 @@ class Trainer:
compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
The function that will be used to compute metrics at evaluation. Must take a
:class:`~transformers.EvalPrediction` and return a dictionary string to metric values.
- tb_writer (:obj:`SummaryWriter`, `optional`):
- Object to write to TensorBoard.
+ callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`):
+ A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in :doc:`here `.
+
+ If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method.
optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of
:class:`~transformers.AdamW` on your model and a scheduler given by
@@ -222,7 +186,7 @@ class Trainer:
def __init__(
self,
- model: PreTrainedModel = None,
+ model: Union[PreTrainedModel, torch.nn.Module] = None,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
@@ -230,7 +194,7 @@ def __init__(
tokenizer: Optional["PreTrainedTokenizerBase"] = None,
model_init: Callable[[], PreTrainedModel] = None,
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
- tb_writer: Optional["SummaryWriter"] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
**kwargs,
):
@@ -259,7 +223,21 @@ def __init__(
"Passing a `model_init` is incompatible with providing the `optimizers` argument."
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
- self.tb_writer = tb_writer
+ callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
+ self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback)
+
+ # Deprecated arguments
+ if "tb_writer" in kwargs:
+ warnings.warn(
+ "Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a "
+ + "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`"
+ + "argument",
+ FutureWarning,
+ )
+ tb_writer = kwargs.pop("tb_writer")
+ self.remove_callback(TensorBoardCallback)
+ self.add_callback(TensorBoardCallback(tb_writer=tb_writer))
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a "
@@ -270,13 +248,6 @@ def __init__(
self.args.prediction_loss_only = kwargs.pop("prediction_loss_only")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
- if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
- self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
- if not is_tensorboard_available():
- logger.warning(
- "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
- )
-
# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
self._loggers_initialized = False
@@ -304,6 +275,7 @@ def __init__(
self._remove_unused_columns(self.eval_dataset, description="evaluation")
self.state = TrainerState()
+ self.control = TrainerControl()
# Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the
# state at each call to self.log.
self._total_flos = None
@@ -317,6 +289,45 @@ def __init__(
else ["labels"]
)
self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ def add_callback(self, callback):
+ """
+ Add a callback to the current list of :class:`~transformer.TrainerCallback`.
+
+ Args:
+ callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
+ A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
+ In the first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+ Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it.
+
+ If the callback is not found, returns :obj:`None` (and no error is raised).
+
+ Args:
+ callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
+ A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
+ In the first case, will pop the first member of that class found in the list of callbacks.
+
+ Returns:
+ :class:`~transformer.TrainerCallback`: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+ Remove a callback from the current list of :class:`~transformer.TrainerCallback`.
+
+ Args:
+ callback (:obj:`type` or :class:`~transformer.TrainerCallback`):
+ A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`.
+ In the first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
if not self.args.remove_unused_columns:
@@ -465,102 +476,12 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
)
- def setup_wandb(self):
- """
- Setup the optional Weights & Biases (`wandb`) integration.
-
- One can subclass and override this method to customize the setup if needed. Find more information
- `here `__. You can also override the following environment variables:
-
- Environment:
- WANDB_WATCH:
- (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
- or "all" to log gradients and parameters
- WANDB_PROJECT:
- (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
- WANDB_DISABLED:
- (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
- """
- if hasattr(self, "_setup_wandb"):
- warnings.warn(
- "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
- FutureWarning,
- )
- return self._setup_wandb()
-
- if self.is_world_process_zero():
- logger.info(
- 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
- )
- combined_dict = {**self.args.to_sanitized_dict()}
- if isinstance(self.model, PreTrainedModel):
- combined_dict = {**self.model.config.to_dict(), **combined_dict}
- wandb.init(
- project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
- )
- # keep track of model topology and gradients, unsupported on TPU
- if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
- wandb.watch(
- self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
- )
-
- def setup_comet(self):
- """
- Setup the optional Comet.ml integration.
-
- Environment:
- COMET_MODE:
- (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
- COMET_PROJECT_NAME:
- (Optional): str - Comet.ml project name for experiments
- COMET_OFFLINE_DIRECTORY:
- (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
-
- For a number of configurable items in the environment,
- see `here `__
- """
- if self.is_world_master():
- comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
- args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
- experiment = None
- if comet_mode == "ONLINE":
- experiment = comet_ml.Experiment(**args)
- logger.info("Automatic Comet.ml online logging enabled")
- elif comet_mode == "OFFLINE":
- args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
- experiment = comet_ml.OfflineExperiment(**args)
- logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
- if experiment is not None:
- experiment._set_model_graph(self.model, framework="transformers")
- experiment._log_parameters(self.args, prefix="args/", framework="transformers")
- if isinstance(self.model, PreTrainedModel):
- experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")
-
def num_examples(self, dataloader: DataLoader) -> int:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
"""
return len(dataloader.dataset)
- def _setup_loggers(self):
- if self._loggers_initialized:
- return
- if is_wandb_available():
- self.setup_wandb()
- elif os.environ.get("WANDB_DISABLED") != "true":
- logger.info(
- "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
- "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
- )
- if is_comet_available():
- self.setup_comet()
- elif os.environ.get("COMET_MODE") != "DISABLED":
- logger.info(
- "To use comet_ml logging, run `pip/conda install comet_ml` "
- "see https://www.comet.ml/docs/python-sdk/huggingface/"
- )
- self._loggers_initialized = True
-
def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
""" HP search setup code """
if self.hp_search_backend is None or trial is None:
@@ -661,7 +582,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
reissue_pt_warnings(caught_warnings)
- # Moxed precision training with apex (torch < 1.6)
+ # Mixed precision training with apex (torch < 1.6)
model = self.model
if self.args.fp16 and _use_apex:
if not is_apex_available():
@@ -687,10 +608,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
# find_unused_parameters breaks checkpointing as per
# https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
- if self.tb_writer is not None:
- self.tb_writer.add_text("args", self.args.to_json_string())
- self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
-
# Train!
if is_torch_tpu_available():
total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
@@ -723,17 +640,25 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
logger.info(" Continuing training from global step %d", self.state.global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
# This should be the same if the state has been saved but in case the training arguments changed, it's safer
# to set this after the load.
self.state.max_steps = max_steps
self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
tr_loss = torch.tensor(0.0).to(self.args.device)
+ self._logging_loss_scalar = 0
self._total_flos = self.state.total_flos
- logging_loss_scalar = 0.0
model.zero_grad()
- disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
- train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm)
+
+ self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control)
+
for epoch in range(epochs_trained, num_train_epochs):
if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch)
@@ -750,15 +675,18 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
if self.args.past_index >= 0:
self._past = None
- epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
+ self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
+
for step, inputs in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
- epoch_pbar.update(1)
continue
+ if (step + 1) % self.args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control)
+
tr_loss += self.training_step(model, inputs)
self._total_flos += self.floating_point_ops(inputs)
@@ -787,50 +715,15 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / len(epoch_iterator)
+ self.control = self.callback_handler.on_step_end(self.args, self.state, self.control)
+
+ self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
- if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or (
- self.state.global_step == 1 and self.args.logging_first_step
- ):
- logs: Dict[str, float] = {}
- tr_loss_scalar = tr_loss.item()
- logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps
- # backward compatibility for pytorch schedulers
- logs["learning_rate"] = (
- self.lr_scheduler.get_last_lr()[0]
- if version.parse(torch.__version__) >= version.parse("1.4")
- else self.lr_scheduler.get_lr()[0]
- )
- logging_loss_scalar = tr_loss_scalar
-
- self.log(logs)
-
- if (
- self.args.evaluation_strategy == EvaluationStrategy.STEPS
- and self.state.global_step % self.args.eval_steps == 0
- ):
- metrics = self.evaluate()
- self._report_to_hp_search(trial, epoch, metrics)
- if self.args.load_best_model_at_end:
- self._save_training(model, trial, metrics=metrics)
-
- if (
- not self.args.load_best_model_at_end
- and self.args.save_steps > 0
- and self.state.global_step % self.args.save_steps == 0
- ):
- self._save_training(model, trial)
-
- epoch_pbar.update(1)
- if self.state.global_step >= max_steps:
+ if self.control.should_epoch_stop or self.control.should_training_stop:
break
- epoch_pbar.close()
- train_pbar.update(1)
- if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
- metrics = self.evaluate()
- self._report_to_hp_search(trial, epoch, metrics)
- if self.args.load_best_model_at_end:
- self._save_training(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control)
+ self._maybe_log_save_evalute(tr_loss, model, trial, epoch)
if self.args.tpu_metrics_debug or self.args.debug:
if is_torch_tpu_available():
@@ -841,12 +734,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
"You enabled PyTorch/XLA debug metrics but you don't have a TPU "
"configured. Check your training configuration if this is unexpected."
)
- if self.state.global_step >= max_steps:
+ if self.control.should_training_stop:
break
- train_pbar.close()
- if self.tb_writer:
- self.tb_writer.close()
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of training
delattr(self, "_past")
@@ -863,9 +753,36 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
+ self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
+
return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step)
- def _save_training(self, model, trial, metrics=None):
+ def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch):
+ if self.control.should_log:
+ logs: Dict[str, float] = {}
+ tr_loss_scalar = tr_loss.item()
+ logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps
+ # backward compatibility for pytorch schedulers
+ logs["learning_rate"] = (
+ self.lr_scheduler.get_last_lr()[0]
+ if version.parse(torch.__version__) >= version.parse("1.4")
+ else self.lr_scheduler.get_lr()[0]
+ )
+ self._logging_loss_scalar = tr_loss_scalar
+
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ metrics = self.evaluate()
+ self._report_to_hp_search(trial, epoch, metrics)
+ self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def _save_checkpoint(self, model, trial, metrics=None):
# In all cases (even distributed/parallel), self.model is always a reference
# to the model we want to save.
if hasattr(model, "module"):
@@ -896,7 +813,7 @@ def _save_training(self, model, trial, metrics=None):
reissue_pt_warnings(caught_warnings)
# Determine the new best metric / best model checkpoint
- if metrics is not None:
+ if metrics is not None and self.args.metric_for_best_model is not None:
metric_to_check = self.args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
@@ -998,7 +915,7 @@ def hyperparameter_search(
self.hp_search_backend = None
return best_run
- def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
+ def log(self, logs: Dict[str, float]) -> None:
"""
Log :obj:`logs` on the various objects watching training.
@@ -1007,55 +924,22 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
Args:
logs (:obj:`Dict[str, float]`):
The values to log.
- iterator (:obj:`tqdm`, `optional`):
- A potential tqdm progress bar to write the logs on.
"""
- # Set up loggers like W&B or Comet ML
- self._setup_loggers()
-
if hasattr(self, "_log"):
warnings.warn(
"The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
FutureWarning,
)
- return self._log(logs, iterator=iterator)
+ return self._log(logs)
if self.state.epoch is not None:
logs["epoch"] = self.state.epoch
if self._total_flos is not None:
self.store_flos()
logs["total_flos"] = self.state.total_flos
- if self.tb_writer:
- for k, v in logs.items():
- if isinstance(v, (int, float)):
- self.tb_writer.add_scalar(k, v, self.state.global_step)
- else:
- logger.warning(
- "Trainer is attempting to log a value of "
- '"%s" of type %s for key "%s" as a scalar. '
- "This invocation of Tensorboard's writer.add_scalar() "
- "is incorrect so we dropped this attribute.",
- v,
- type(v),
- k,
- )
- self.tb_writer.flush()
- if is_wandb_available():
- if self.is_world_process_zero():
- wandb.log(logs, step=self.state.global_step)
- if is_comet_available():
- if self.is_world_process_zero():
- experiment = comet_ml.config.get_global_experiment()
- if experiment is not None:
- experiment._log_metrics(
- logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers"
- )
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
output = {**logs, **{"step": self.state.global_step}}
self.state.log_history.append(output)
- if iterator is not None:
- iterator.write(output)
- else:
- print(output)
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
@@ -1372,8 +1256,9 @@ def prediction_loop(
if self.args.past_index >= 0:
self._past = None
- disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
- for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
+ self.callback_handler.eval_dataloader = dataloader
+
+ for inputs in dataloader:
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None:
@@ -1382,6 +1267,7 @@ def prediction_loop(
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
+ self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
new file mode 100644
index 000000000000..bc14d99f25e1
--- /dev/null
+++ b/src/transformers/trainer_callback.py
@@ -0,0 +1,468 @@
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Callbacks to use with the Trainer class and customize the training loop.
+"""
+
+import dataclasses
+import json
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from tqdm.auto import tqdm
+
+from .trainer_utils import EvaluationStrategy
+from .training_args import TrainingArguments
+from .utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+@dataclass
+class TrainerState:
+ """
+ A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer
+ when checkpointing and passed to the :class:`~transformers.TrainerCallback`.
+
+ .. note::
+
+ In all this class, one step is to be understood as one update step. When using gradient accumulation, one
+ update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
+ then one update step requires going throuch `n` batches.
+
+ Args:
+ epoch (:obj:`float`, `optional`):
+ Only set during training, will represent the epoch the training is at (the decimal part being the
+ percentage of the current epoch completed).
+ global_step (:obj:`int`, `optional`, defaults to 0):
+ During training, represents the number of update steps completed.
+ max_steps (:obj:`int`, `optional`, defaults to 0):
+ The number of update steps to do during the current training.
+ total_flos (:obj:`int`, `optional`, defaults to 0):
+ The total number of floating operations done by the model since the beginning of training.
+ log_history (:obj:`List[Dict[str, float]]`, `optional`):
+ The list of logs done since the beginning of training.
+ best_metric (:obj:`float`, `optional`):
+ When tracking the best model, the value of the best metric encountered so far.
+ best_model_checkpoint (:obj:`str`, `optional`):
+ When tracking the best model, the value of the name of the checkpoint for the best model encountered so
+ far.
+ is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on
+ several machines) main process.
+ is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
+ Whether or not this process is the global main process (when training in a distributed fashion on
+ several machines, this is only going to be :obj:`True` for one process).
+ """
+
+ epoch: Optional[float] = None
+ global_step: int = 0
+ max_steps: int = 0
+ num_train_epochs: int = 0
+ total_flos: int = 0
+ log_history: List[Dict[str, float]] = None
+ best_metric: Optional[float] = None
+ best_model_checkpoint: Optional[str] = None
+ is_local_process_zero: bool = True
+ is_world_process_zero: bool = True
+
+ def __post_init__(self):
+ if self.log_history is None:
+ self.log_history = []
+
+ def save_to_json(self, json_path: str):
+ """ Save the content of this instance in JSON format inside :obj:`json_path`."""
+ json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
+ with open(json_path, "w", encoding="utf-8") as f:
+ f.write(json_string)
+
+ @classmethod
+ def load_from_json(cls, json_path: str):
+ """ Create an instance from the content of :obj:`json_path`."""
+ with open(json_path, "r", encoding="utf-8") as f:
+ text = f.read()
+ return cls(**json.loads(text))
+
+
+@dataclass
+class TrainerControl:
+ """
+ A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the
+ :class:`~transformers.TrainerCallback` to activate some switches in the training loop.
+
+ Args:
+ should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the training should be interrupted.
+
+ If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop.
+ should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the current epoch should be interrupted.
+
+ If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch.
+ should_save (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the model should be saved at this step.
+
+ If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
+ should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the model should be evaluated at this step.
+
+ If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
+ should_log (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the logs should be reported at this step.
+
+ If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step.
+ """
+
+ should_training_stop: bool = False
+ should_epoch_stop: bool = False
+ should_save: bool = False
+ should_evaluate: bool = False
+ should_log: bool = False
+
+ def _new_training(self):
+ """ Internal method that resets the variable for a new training. """
+ self.should_training_stop = False
+
+ def _new_epoch(self):
+ """ Internal method that resets the variable for a new epoch. """
+ self.should_epoch_stop = False
+
+ def _new_step(self):
+ """ Internal method that resets the variable for a new step. """
+ self.should_save_model = False
+ self.should_evaluate = False
+ self.should_log = False
+
+
+class TrainerCallback:
+ """
+ A class for objects that will inspect the state of the training loop at some events and take some decisions. At
+ each of those events the following arguments are available:
+
+ Args:
+ args (:class:`~transformers.TrainingArguments`):
+ The training arguments used to instantiate the :class:`~transformers.Trainer`.
+ state (:class:`~transformers.TrainerState`):
+ The current state of the :class:`~transformers.Trainer`.
+ control (:class:`~transformers.TrainerControl`):
+ The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
+ model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
+ The model being trained.
+ optimizer (:obj:`torch.optim.Optimizer`):
+ The optimizer used for the training steps.
+ lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
+ The scheduler used for setting the learning rate.
+ train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+ The current dataloader used for training.
+ eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`):
+ The current dataloader used for training.
+ metrics (:obj:`Dict[str, float]`):
+ The metrics computed by the last evaluation phase.
+
+ Those are only accessible in the event :obj:`on_evaluate`.
+ logs (:obj:`Dict[str, float]`):
+ The values to log.
+
+ Those are only accessible in the event :obj:`on_log`.
+
+ The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes
+ it should return the modified version.
+
+ The argument :obj:`args`, :obj:`state` and :obj:`control` are positionals for all events, all the others are
+ grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example,
+ see the code of the simple :class:`~transformer.PrinterCallback`.
+
+ Example::
+
+ class PrinterCallback(TrainerCallback):
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ _ = logs.pop("total_flos", None)
+ if state.is_local_process_zero:
+ print(logs)
+ """
+
+ def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of the initialization of the :class:`~transformers.Trainer`.
+ """
+ pass
+
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of training.
+ """
+ pass
+
+ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of training.
+ """
+ pass
+
+ def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of an epoch.
+ """
+ pass
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of an epoch.
+ """
+ pass
+
+ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the beginning of a training step. If using gradient accumulation, one training step might take
+ several inputs.
+ """
+ pass
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called at the end of a training step. If using gradient accumulation, one training step might take
+ several inputs.
+ """
+ pass
+
+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after an evaluation phase.
+ """
+ pass
+
+ def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after a checkpoint save.
+ """
+ pass
+
+ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after logging the last logs.
+ """
+ pass
+
+ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ """
+ Event called after a prediction step.
+ """
+ pass
+
+
+class CallbackHandler(TrainerCallback):
+ """ Internal class that just calls the list of callbacks in order. """
+
+ def __init__(self, callbacks, model, optimizer, lr_scheduler):
+ self.callbacks = []
+ for cb in callbacks:
+ self.add_callback(cb)
+ self.model = model
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ self.train_dataloader = None
+ self.eval_dataloader = None
+
+ if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks):
+ logger.warn(
+ "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n"
+ + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of"
+ + "callbacks is\n:"
+ + self.callback_list
+ )
+
+ def add_callback(self, callback):
+ cb = callback() if isinstance(callback, type) else callback
+ cb_class = callback if isinstance(callback, type) else callback.__class__
+ if cb_class in [c.__class__ for c in self.callbacks]:
+ logger.warn(
+ f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current"
+ + "list of callbacks is\n:"
+ + self.callback_list
+ )
+ self.callbacks.append(cb)
+
+ def pop_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return cb
+ else:
+ for cb in self.callbacks:
+ if cb == callback:
+ self.callbacks.remove(cb)
+ return cb
+
+ def remove_callback(self, callback):
+ if isinstance(callback, type):
+ for cb in self.callbacks:
+ if isinstance(cb, callback):
+ self.callbacks.remove(cb)
+ return
+ else:
+ self.callbacks.remove(callback)
+
+ @property
+ def callback_list(self):
+ return "\n".join(self.callbacks)
+
+ def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_init_end", args, state, control)
+
+ def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_training_stop = False
+ return self.call_event("on_train_begin", args, state, control)
+
+ def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_train_end", args, state, control)
+
+ def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_epoch_stop = False
+ return self.call_event("on_epoch_begin", args, state, control)
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_epoch_end", args, state, control)
+
+ def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_log = False
+ control.should_evaluate = False
+ control.should_save = False
+ return self.call_event("on_step_begin", args, state, control)
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_step_end", args, state, control)
+
+ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics):
+ control.should_evaluate = False
+ return self.call_event("on_evaluate", args, state, control, metrics=metrics)
+
+ def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ control.should_save = False
+ return self.call_event("on_save", args, state, control)
+
+ def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs):
+ control.should_log = False
+ return self.call_event("on_log", args, state, control, logs=logs)
+
+ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
+ return self.call_event("on_prediction_step", args, state, control)
+
+ def call_event(self, event, args, state, control, **kwargs):
+ for callback in self.callbacks:
+ result = getattr(callback, event)(
+ args,
+ state,
+ control,
+ model=self.model,
+ optimizer=self.optimizer,
+ lr_scheduler=self.lr_scheduler,
+ train_dataloader=self.train_dataloader,
+ eval_dataloader=self.eval_dataloader,
+ **kwargs,
+ )
+ # A Callback can skip the return of `control` if it doesn't change it.
+ if result is not None:
+ control = result
+ return control
+
+
+class DefaultFlowCallback(TrainerCallback):
+ """
+ A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
+ and checkpoints.
+ """
+
+ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ # Log
+ if state.global_step == 1 and args.logging_first_step:
+ control.should_log = True
+ if args.logging_steps > 0 and state.global_step % args.logging_steps == 0:
+ control.should_log = True
+
+ # Evaluate
+ if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0:
+ control.should_evaluate = True
+ if args.load_best_model_at_end:
+ control.should_save = True
+
+ # Save
+ if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0:
+ control.should_save = True
+
+ # End training
+ if state.global_step >= state.max_steps:
+ control.should_training_stop = True
+
+ return control
+
+ def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
+ if args.evaluation_strategy == EvaluationStrategy.EPOCH:
+ control.should_evaluate = True
+ if args.load_best_model_at_end:
+ control.should_save = True
+ return control
+
+
+class ProgressCallback(TrainerCallback):
+ """
+ A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation.
+ """
+
+ def __init__(self):
+ self.training_bar = None
+ self.prediction_bar = None
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if state.is_local_process_zero:
+ self.training_bar = tqdm(total=state.max_steps)
+
+ def on_step_end(self, args, state, control, **kwargs):
+ if state.is_local_process_zero:
+ self.training_bar.update(1)
+
+ def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
+ if state.is_local_process_zero:
+ if self.prediction_bar is None:
+ self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None)
+ self.prediction_bar.update(1)
+
+ def on_evaluate(self, args, state, control, **kwargs):
+ if state.is_local_process_zero:
+ self.prediction_bar.close()
+ self.prediction_bar = None
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ if state.is_local_process_zero and self.training_bar is not None:
+ _ = logs.pop("total_flos", None)
+ self.training_bar.write(str(logs))
+
+ def on_train_end(self, args, state, control, **kwargs):
+ if state.is_local_process_zero:
+ self.training_bar.close()
+ self.training_bar = None
+
+
+class PrinterCallback(TrainerCallback):
+ """
+ A bare :class:`~transformers.TrainerCallback` that just prints the logs.
+ """
+
+ def on_log(self, args, state, control, logs=None, **kwargs):
+ _ = logs.pop("total_flos", None)
+ if state.is_local_process_zero:
+ print(logs)
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
new file mode 100644
index 000000000000..74a93f8286b8
--- /dev/null
+++ b/src/transformers/trainer_pt_utils.py
@@ -0,0 +1,179 @@
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Torch utilities for the Trainer class.
+"""
+
+import math
+import warnings
+from contextlib import contextmanager
+from typing import List, Optional, Union
+
+import torch
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import RandomSampler, Sampler
+
+from .file_utils import is_torch_tpu_available
+
+
+if is_torch_tpu_available():
+ import torch_xla.core.xla_model as xm
+
+PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."
+
+
+def nested_concat(tensors, new_tensors, dim=0):
+ "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
+ assert type(tensors) == type(
+ new_tensors
+ ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
+ return torch.cat((tensors, new_tensors), dim=dim)
+
+
+def nested_numpify(tensors):
+ "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_numpify(t) for t in tensors)
+ return tensors.cpu().numpy()
+
+
+def nested_detach(tensors):
+ "Detach `tensors` (even if it's a nested list/tuple of tensors)."
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_detach(t) for t in tensors)
+ return tensors.detach()
+
+
+def nested_xla_mesh_reduce(tensors, name):
+ if is_torch_tpu_available():
+ import torch_xla.core.xla_model as xm
+
+ if isinstance(tensors, (list, tuple)):
+ return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+ return xm.mesh_reduce(name, tensors, torch.cat)
+ else:
+ raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
+
+
+def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor:
+ try:
+ if isinstance(tensor, (tuple, list)):
+ return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
+ output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(output_tensors, tensor)
+ concat = torch.cat(output_tensors, dim=0)
+
+ # truncate the dummy elements added by SequentialDistributedSampler
+ if num_total_examples is not None:
+ concat = concat[:num_total_examples]
+ return concat
+ except AssertionError:
+ raise AssertionError("Not currently using distributed training")
+
+
+def distributed_broadcast_scalars(
+ scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
+) -> torch.Tensor:
+ try:
+ tensorized_scalar = torch.tensor(scalars).cuda()
+ output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(output_tensors, tensorized_scalar)
+ concat = torch.cat(output_tensors, dim=0)
+
+ # truncate the dummy elements added by SequentialDistributedSampler
+ if num_total_examples is not None:
+ concat = concat[:num_total_examples]
+ return concat
+ except AssertionError:
+ raise AssertionError("Not currently using distributed training")
+
+
+def reissue_pt_warnings(caught_warnings):
+ # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING
+ if len(caught_warnings) > 1:
+ for w in caught_warnings:
+ if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING:
+ warnings.warn(w.message, w.category)
+
+
+@contextmanager
+def torch_distributed_zero_first(local_rank: int):
+ """
+ Decorator to make all processes in distributed training wait for each local_master to do something.
+
+ Args:
+ local_rank (:obj:`int`): The rank of the local process.
+ """
+ if local_rank not in [-1, 0]:
+ torch.distributed.barrier()
+ yield
+ if local_rank == 0:
+ torch.distributed.barrier()
+
+
+class SequentialDistributedSampler(Sampler):
+ """
+ Distributed Sampler that subsamples indicies sequentially,
+ making it easier to collate all results at the end.
+
+ Even though we only use this sampler for eval and predict (no training),
+ which means that the model params won't have to be synced (i.e. will not hang
+ for synchronization even if varied number of forward passes), we still add extra
+ samples to the sampler to make it evenly divisible (like in `DistributedSampler`)
+ to make it easy to `gather` or `reduce` resulting tensors at the end of the loop.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None):
+ if num_replicas is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = torch.distributed.get_world_size()
+ if rank is None:
+ if not torch.distributed.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = torch.distributed.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ indices = list(range(len(self.dataset)))
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert (
+ len(indices) == self.total_size
+ ), f"Indices length {len(indices)} and total size {self.total_size} mismatched"
+
+ # subsample
+ indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
+ assert (
+ len(indices) == self.num_samples
+ ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
+def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
+ if xm.xrt_world_size() <= 1:
+ return RandomSampler(dataset)
+ return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index e816b0772a9c..96757a922e6a 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -1,19 +1,30 @@
-import dataclasses
-import json
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
+"""
+
import random
-from dataclasses import dataclass
-from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import numpy as np
-from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available
+from .file_utils import is_tf_available, is_torch_available
from .tokenization_utils_base import ExplicitEnum
-if is_torch_available():
- import torch
-
-
def set_seed(seed: int):
"""
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf``
@@ -139,144 +150,3 @@ class HPSearchBackend(ExplicitEnum):
HPSearchBackend.OPTUNA: default_hp_space_optuna,
HPSearchBackend.RAY: default_hp_space_ray,
}
-
-
-def nested_concat(tensors, new_tensors, dim=0):
- "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
- if is_torch_available():
- assert type(tensors) == type(
- new_tensors
- ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
- if isinstance(tensors, (list, tuple)):
- return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
- return torch.cat((tensors, new_tensors), dim=dim)
- else:
- raise ImportError("Torch must be installed to use `nested_concat`")
-
-
-def nested_deatch(tensors):
- "Detach `tensors` (even if it's a nested list/tuple of tensors)."
- if isinstance(tensors, (list, tuple)):
- return type(tensors)(nested_detach(t) for t in tensors)
- return tensors.detach()
-
-
-def nested_numpify(tensors):
- "Numpify `tensors` (even if it's a nested list/tuple of tensors)."
- if isinstance(tensors, (list, tuple)):
- return type(tensors)(nested_numpify(t) for t in tensors)
- return tensors.cpu().numpy()
-
-
-def nested_detach(tensors):
- "Detach `tensors` (even if it's a nested list/tuple of tensors)."
- if isinstance(tensors, (list, tuple)):
- return type(tensors)(nested_detach(t) for t in tensors)
- return tensors.detach()
-
-
-def nested_xla_mesh_reduce(tensors, name):
- if is_torch_tpu_available():
- import torch_xla.core.xla_model as xm
-
- if isinstance(tensors, (list, tuple)):
- return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
- return xm.mesh_reduce(name, tensors, torch.cat)
- else:
- raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
-
-
-def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor":
- if is_torch_available():
- try:
- if isinstance(tensor, (tuple, list)):
- return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor)
- output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
- torch.distributed.all_gather(output_tensors, tensor)
- concat = torch.cat(output_tensors, dim=0)
-
- # truncate the dummy elements added by SequentialDistributedSampler
- if num_total_examples is not None:
- concat = concat[:num_total_examples]
- return concat
- except AssertionError:
- raise AssertionError("Not currently using distributed training")
- else:
- raise ImportError("Torch must be installed to use `distributed_concat`")
-
-
-def distributed_broadcast_scalars(
- scalars: List[Union[int, float]], num_total_examples: Optional[int] = None
-) -> "torch.Tensor":
- if is_torch_available():
- try:
- tensorized_scalar = torch.tensor(scalars).cuda()
- output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())]
- torch.distributed.all_gather(output_tensors, tensorized_scalar)
- concat = torch.cat(output_tensors, dim=0)
-
- # truncate the dummy elements added by SequentialDistributedSampler
- if num_total_examples is not None:
- concat = concat[:num_total_examples]
- return concat
- except AssertionError:
- raise AssertionError("Not currently using distributed training")
- else:
- raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`")
-
-
-@dataclass
-class TrainerState:
- """
- A class containing the `Trainer` inner state that will be saved along the model and optimizer.
-
- .. note::
-
- In all this class, one step is to be understood as one update step. When using gradient accumulation, one
- update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`,
- then one update step requires going throuch `n` batches.
-
- Args:
- epoch (:obj:`float`, `optional`):
- Only set during training, will represent the epoch the training is at (the decimal part being the
- percentage of the current epoch completed).
- global_step (:obj:`int`, `optional`, defaults to 0):
- During training, represents the number of update steps completed.
- max_steps (:obj:`int`, `optional`, defaults to 0):
- The number of update steps to do during the current training.
- total_flos (:obj:`int`, `optional`, defaults to 0):
- The total number of floating operations done by the model since the beginning of training.
- log_history (:obj:`List[Dict[str, float]]`, `optional`):
- The list of logs done since the beginning of training.
- best_metric (:obj:`float`, `optional`):
- When tracking the best model, the value of the best metric encountered so far.
- best_model_checkpoint (:obj:`str`, `optional`):
- When tracking the best model, the value of the name of the checkpoint for the best model encountered so
- far.
- """
-
- epoch: Optional[float] = None
- global_step: int = 0
- max_steps: int = 0
- num_train_epochs: int = 0
- total_flos: int = 0
- log_history: List[Dict[str, float]] = None
- best_metric: Optional[float] = None
- best_model_checkpoint: Optional[str] = None
-
- def __post_init__(self):
- if self.log_history is None:
- self.log_history = []
-
- def save_to_json(self, json_path: str):
- """ Save the content of this instance in JSON format inside :obj:`json_path`."""
- json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
- with open(json_path, "w", encoding="utf-8") as f:
- f.write(json_string)
-
- @classmethod
- def load_from_json(cls, json_path: str):
- """ Create an instance from the content of :obj:`json_path`."""
- with open(json_path, "r", encoding="utf-8") as f:
- text = f.read()
- return cls(**json.loads(text))
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 2d75e7518303..9359a9f17ced 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -54,7 +54,7 @@ class TrainingArguments:
:obj:`"no"`.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not.
- evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
+ evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
The evaluation strategy to adopt during training. Possible values are:
* :obj:`"no"`: No evaluation is done during training.
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index ab854e5e7477..0521051a0a96 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -1860,19 +1860,10 @@ def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
-class EvalPrediction:
- def __init__(self, *args, **kwargs):
- requires_pytorch(self)
-
-
class Trainer:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
-def set_seed(*args, **kwargs):
- requires_pytorch(set_seed)
-
-
def torch_distributed_zero_first(*args, **kwargs):
requires_pytorch(torch_distributed_zero_first)
diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py
new file mode 100644
index 000000000000..0469c077fa75
--- /dev/null
+++ b/tests/test_trainer_callback.py
@@ -0,0 +1,214 @@
+import shutil
+import tempfile
+import unittest
+
+from transformers import (
+ DefaultFlowCallback,
+ EvaluationStrategy,
+ PrinterCallback,
+ ProgressCallback,
+ Trainer,
+ TrainerCallback,
+ TrainingArguments,
+ is_torch_available,
+)
+from transformers.testing_utils import require_torch
+
+
+if is_torch_available():
+ from transformers.trainer import DEFAULT_CALLBACKS
+
+ from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel
+
+
+class TestTrainerCallback(TrainerCallback):
+ "A callback that registers the events that goes through."
+
+ def __init__(self):
+ self.events = []
+
+ def on_init_end(self, args, state, control, **kwargs):
+ self.events.append("on_init_end")
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ self.events.append("on_train_begin")
+
+ def on_train_end(self, args, state, control, **kwargs):
+ self.events.append("on_train_end")
+
+ def on_epoch_begin(self, args, state, control, **kwargs):
+ self.events.append("on_epoch_begin")
+
+ def on_epoch_end(self, args, state, control, **kwargs):
+ self.events.append("on_epoch_end")
+
+ def on_step_begin(self, args, state, control, **kwargs):
+ self.events.append("on_step_begin")
+
+ def on_step_end(self, args, state, control, **kwargs):
+ self.events.append("on_step_end")
+
+ def on_evaluate(self, args, state, control, **kwargs):
+ self.events.append("on_evaluate")
+
+ def on_save(self, args, state, control, **kwargs):
+ self.events.append("on_save")
+
+ def on_log(self, args, state, control, **kwargs):
+ self.events.append("on_log")
+
+ def on_prediction_step(self, args, state, control, **kwargs):
+ self.events.append("on_prediction_step")
+
+
+@require_torch
+class TrainerCallbackTest(unittest.TestCase):
+ def setUp(self):
+ self.output_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.output_dir)
+
+ def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs):
+ # disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure
+ # its set to False since the tests later on depend on its value.
+ train_dataset = RegressionDataset(length=train_len)
+ eval_dataset = RegressionDataset(length=eval_len)
+ config = RegressionModelConfig(a=a, b=b)
+ model = RegressionPreTrainedModel(config)
+
+ args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs)
+ return Trainer(
+ model,
+ args,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ callbacks=callbacks,
+ )
+
+ def check_callbacks_equality(self, cbs1, cbs2):
+ self.assertEqual(len(cbs1), len(cbs2))
+
+ # Order doesn't matter
+ cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
+ cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__))
+
+ for cb1, cb2 in zip(cbs1, cbs2):
+ if isinstance(cb1, type) and isinstance(cb2, type):
+ self.assertEqual(cb1, cb2)
+ elif isinstance(cb1, type) and not isinstance(cb2, type):
+ self.assertEqual(cb1, cb2.__class__)
+ elif not isinstance(cb1, type) and isinstance(cb2, type):
+ self.assertEqual(cb1.__class__, cb2)
+ else:
+ self.assertEqual(cb1, cb2)
+
+ def get_expected_events(self, trainer):
+ expected_events = ["on_init_end", "on_train_begin"]
+ step = 0
+ train_dl_len = len(trainer.get_eval_dataloader())
+ evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"]
+ for _ in range(trainer.state.num_train_epochs):
+ expected_events.append("on_epoch_begin")
+ for _ in range(train_dl_len):
+ step += 1
+ expected_events += ["on_step_begin", "on_step_end"]
+ if step % trainer.args.logging_steps == 0:
+ expected_events.append("on_log")
+ if (
+ trainer.args.evaluation_strategy == EvaluationStrategy.STEPS
+ and step % trainer.args.eval_steps == 0
+ ):
+ expected_events += evaluation_events.copy()
+ if step % trainer.args.save_steps == 0:
+ expected_events.append("on_save")
+ expected_events.append("on_epoch_end")
+ if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH:
+ expected_events += evaluation_events.copy()
+ expected_events.append("on_train_end")
+ return expected_events
+
+ def test_init_callback(self):
+ trainer = self.get_trainer()
+ expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ # Callbacks passed at init are added to the default callbacks
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback])
+ expected_callbacks.append(TestTrainerCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback
+ trainer = self.get_trainer(disable_tqdm=True)
+ expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback]
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ def test_add_remove_callback(self):
+ expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback]
+ trainer = self.get_trainer()
+
+ # We can add, pop, or remove by class name
+ trainer.remove_callback(DefaultFlowCallback)
+ expected_callbacks.remove(DefaultFlowCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ trainer = self.get_trainer()
+ cb = trainer.pop_callback(DefaultFlowCallback)
+ self.assertEqual(cb.__class__, DefaultFlowCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ trainer.add_callback(DefaultFlowCallback)
+ expected_callbacks.insert(0, DefaultFlowCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ # We can also add, pop, or remove by instance
+ trainer = self.get_trainer()
+ cb = trainer.callback_handler.callbacks[0]
+ trainer.remove_callback(cb)
+ expected_callbacks.remove(DefaultFlowCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ trainer = self.get_trainer()
+ cb1 = trainer.callback_handler.callbacks[0]
+ cb2 = trainer.pop_callback(cb1)
+ self.assertEqual(cb1, cb2)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ trainer.add_callback(cb1)
+ expected_callbacks.insert(0, DefaultFlowCallback)
+ self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks)
+
+ def test_event_flow(self):
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback])
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))
+
+ # Independent log/save/eval
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5)
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))
+
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5)
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))
+
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps")
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))
+
+ trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch")
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))
+
+ # A bit of everything
+ trainer = self.get_trainer(
+ callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps"
+ )
+ trainer.train()
+ events = trainer.callback_handler.callbacks[-2].events
+ self.assertEqual(events, self.get_expected_events(trainer))