diff --git a/docs/source/index.rst b/docs/source/index.rst index 0545d46240bc..400c40827568 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -213,6 +213,7 @@ conversion utilities for the following models: :maxdepth: 2 :caption: Main Classes + main_classes/callback main_classes/configuration main_classes/logging main_classes/model @@ -270,3 +271,4 @@ conversion utilities for the following models: internal/modeling_utils internal/pipelines_utils internal/tokenization_utils + internal/trainer_utils diff --git a/docs/source/internal/trainer_utils.rst b/docs/source/internal/trainer_utils.rst new file mode 100644 index 000000000000..48e8568b9530 --- /dev/null +++ b/docs/source/internal/trainer_utils.rst @@ -0,0 +1,21 @@ +Utilities for Trainer +----------------------------------------------------------------------------------------------------------------------- + +This page lists all the utility functions used by :class:`~transformers.Trainer`. + +Most of those are only useful if you are studying the code of the Trainer in the library. + +Utilities +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.EvalPrediction + +.. autofunction:: transformers.set_seed + +.. autofunction:: transformers.torch_distributed_zero_first + + +Callbacks internals +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.trainer_callback.CallbackHandler diff --git a/docs/source/main_classes/callback.rst b/docs/source/main_classes/callback.rst new file mode 100644 index 000000000000..1b31c5464521 --- /dev/null +++ b/docs/source/main_classes/callback.rst @@ -0,0 +1,68 @@ +Callbacks +----------------------------------------------------------------------------------------------------------------------- + +Callbacks are objects that can customize the behavior of the training loop in the PyTorch +:class:`~transformers.Trainer` (this feature is not yet implemented in TensorFlow) that can inspect the training loop +state (for progress reporting, logging on TensorBoard or other ML platforms...) and take decisions (like early +stopping). + +Callbacks are "read only" pieces of code, apart from the :class:`~transformers.TrainerControl` object they return, they +cannot change anything in the training loop. For customizations that require changes in the training loop, you should +subclass :class:`~transformers.Trainer` and override the methods you need (see :doc:`trainer` for examples). + +By default a :class:`~transformers.Trainer` will use the following callbacks: + +- :class:`~transformers.DefaultFlowCallback` which handles the default beahvior for logging, saving and evaluation. +- :class:`~transformers.PrinterCallback` or :class:`~transformers.ProrgressCallback` to display progress and print the + logs (the first one is used if you deactivate tqdm through the :class:`~transformers.TrainingArguments`, otherwise + it's the second one). +- :class:`~transformers.integrations.TensorBoardCallback` if tensorboard is accessible (either through PyTorch >= 1.4 + or tensorboardX). +- :class:`~transformers.integrations.WandbCallback` if `wandb `__ is installed. +- :class:`~transformers.integrations.CometCallback` if `comet_ml `__ is installed. + +The main class that implements callbacks is :class:`~transformers.TrainerCallback`. It gets the +:class:`~transformers.TrainingArguments` used to instantiate the :class:`~transformers.Trainer`, can access that +Trainer's internal state via :class:`~transformers.TrainerState`, and can take some actions on the training loop via +:class:`~transformers.TrainerControl`. + + +Available Callbacks +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Here is the list of the available :class:`~transformers.TrainerCallback` in the library: + +.. autoclass:: transformers.integrations.CometCallback + :members: setup + +.. autoclass:: transformers.DefaultFlowCallback + +.. autoclass:: transformers.PrinterCallback + +.. autoclass:: transformers.ProgressCallback + +.. autoclass:: transformers.integrations.TensorBoardCallback + +.. autoclass:: transformers.integrations.WandbCallback + :members: setup + + +TrainerCallback +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TrainerCallback + :members: + + +TrainerState +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TrainerState + :members: + + +TrainerControl +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TrainerControl + :members: diff --git a/docs/source/main_classes/trainer.rst b/docs/source/main_classes/trainer.rst index 76cd1f34d75c..07050d1707aa 100644 --- a/docs/source/main_classes/trainer.rst +++ b/docs/source/main_classes/trainer.rst @@ -18,7 +18,6 @@ previous features. To inject custom behavior you can subclass them and override - **get_eval_dataloader**/**get_eval_tfdataset** -- Creates the evaulation DataLoader (PyTorch) or TF Dataset. - **get_test_dataloader**/**get_test_tfdataset** -- Creates the test DataLoader (PyTorch) or TF Dataset. - **log** -- Logs information on the various objects watching training. -- **setup_wandb** -- Setups wandb (see `here `__ for more information). - **create_optimizer_and_scheduler** -- Setups the optimizer and learning rate scheduler if they were not passed at init. - **compute_loss** - Computes the loss on a batch of training inputs. @@ -40,6 +39,10 @@ Here is an example of how to customize :class:`~transformers.Trainer` using a cu logits = outputs[0] return my_custom_loss(logits, labels) +Another way to customize the training loop behavior for the PyTorch :class:`~transformers.Trainer` is to use +:doc:`callbacks ` that can inspect the training loop state (for progress reporting, logging on TensorBoard or +other ML platforms...) and take decisions (like early stopping). + Trainer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -47,29 +50,23 @@ Trainer .. autoclass:: transformers.Trainer :members: + TFTrainer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.TFTrainer :members: + TrainingArguments ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.TrainingArguments :members: + TFTrainingArguments ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.TFTrainingArguments :members: - -Utilities -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autoclass:: transformers.EvalPrediction - -.. autofunction:: transformers.set_seed - -.. autofunction:: transformers.torch_distributed_zero_first diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py index 0f585eb26218..293244df41d2 100644 --- a/examples/seq2seq/seq2seq_trainer.py +++ b/examples/seq2seq/seq2seq_trainer.py @@ -9,7 +9,7 @@ from transformers.configuration_fsmt import FSMTConfig from transformers.file_utils import is_torch_tpu_available from transformers.optimization import Adafactor, AdamW, get_linear_schedule_with_warmup -from transformers.trainer import get_tpu_sampler +from transformers.trainer_pt_utils import get_tpu_sampler try: diff --git a/examples/seq2seq/test_finetune_trainer.py b/examples/seq2seq/test_finetune_trainer.py index 517e76b2327e..9002f0284a69 100644 --- a/examples/seq2seq/test_finetune_trainer.py +++ b/examples/seq2seq/test_finetune_trainer.py @@ -4,7 +4,8 @@ from unittest.mock import patch from transformers.testing_utils import slow -from transformers.trainer_utils import TrainerState, set_seed +from transformers.trainer_callback import TrainerState +from transformers.trainer_utils import set_seed from .finetune_trainer import main from .test_seq2seq_examples import MBART_TINY diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9ab37830cfea..b8bb0f8a8eb9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -205,7 +205,15 @@ from .tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer # Trainer -from .trainer_utils import EvalPrediction, TrainerState, set_seed +from .trainer_callback import ( + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from .trainer_utils import EvalPrediction, EvaluationStrategy, set_seed from .training_args import TrainingArguments from .training_args_tf import TFTrainingArguments from .utils import logging @@ -528,7 +536,8 @@ from .tokenization_marian import MarianTokenizer # Trainer - from .trainer import EvalPrediction, Trainer, set_seed, torch_distributed_zero_first + from .trainer import Trainer + from .trainer_pt_utils import torch_distributed_zero_first else: from .utils.dummy_pt_objects import * diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index 15badb59cc01..9e0ee0cbb01b 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -2,6 +2,11 @@ import math import os +from .file_utils import is_torch_tpu_available +from .trainer_callback import TrainerCallback +from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun +from .utils import logging + try: import comet_ml # noqa: F401 @@ -36,15 +41,6 @@ except (ImportError): _has_ray = False - -# No ML framework or transformer imports above this point - -from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun # isort:skip -from .utils import logging # isort:skip - -logger = logging.get_logger(__name__) - - try: from torch.utils.tensorboard import SummaryWriter # noqa: F401 @@ -57,9 +53,10 @@ except ImportError: _has_tensorboard = False -# Integration functions: +logger = logging.get_logger(__name__) +# Integration functions: def is_wandb_available(): return _has_wandb @@ -128,8 +125,8 @@ def _objective(trial, checkpoint_dir=None): # The model and TensorBoard writer do not pickle so we have to remove them (if they exists) # while doing the ray hp search. - _tb_writer = trainer.tb_writer - trainer.tb_writer = None + + _tb_writer = trainer.pop_callback(TensorBoardCallback) trainer.model = None # Setup default `resources_per_trial` and `reporter`. if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0: @@ -182,5 +179,159 @@ def _objective(trial, checkpoint_dir=None): analysis = ray.tune.run(_objective, config=trainer.hp_space(None), num_samples=n_trials, **kwargs) best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3]) best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config) - trainer.tb_writer = _tb_writer + if _tb_writer is not None: + trainer.add_callback(_tb_writer) return best_run + + +class TensorBoardCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard + `__. + + Args: + tb_writer (:obj:`SummaryWriter`, `optional`): + The writer to use. Will instatiate one if not set. + """ + + def __init__(self, tb_writer=None): + assert ( + _has_tensorboard + ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX." + self.tb_writer = tb_writer + + def on_init_end(self, args, state, control, **kwargs): + if self.tb_writer is None and state.is_world_process_zero: + self.tb_writer = SummaryWriter(log_dir=args.logging_dir) + + def on_train_begin(self, args, state, control, **kwargs): + if self.tb_writer is not None: + self.tb_writer.add_text("args", args.to_json_string()) + self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={}) + + def on_log(self, args, state, control, logs=None, **kwargs): + if self.tb_writer: + for k, v in logs.items(): + if isinstance(v, (int, float)): + self.tb_writer.add_scalar(k, v, state.global_step) + else: + logger.warning( + "Trainer is attempting to log a value of " + '"%s" of type %s for key "%s" as a scalar. ' + "This invocation of Tensorboard's writer.add_scalar() " + "is incorrect so we dropped this attribute.", + v, + type(v), + k, + ) + self.tb_writer.flush() + + def on_train_end(self, args, state, control, **kwargs): + if self.tb_writer: + self.tb_writer.close() + + +class WandbCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that sends the logs to `Weight and Biases + `__. + """ + + def __init__(self): + assert _has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`." + self._initialized = False + + def setup(self, args, state, model): + """ + Setup the optional Weights & Biases (`wandb`) integration. + + One can subclass and override this method to customize the setup if needed. Find more information + `here `__. You can also override the following environment variables: + + Environment: + WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`): + Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient + logging or :obj:`"all"` to log gradients and parameters. + WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`): + Set this to a custom string to store results in a different project. + WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to disable wandb entirely. + """ + self._initialized = True + if state.is_world_process_zero: + logger.info( + 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' + ) + combined_dict = {**args.to_sanitized_dict()} + if hasattr(model, "config"): + combined_dict = {**model.config.to_dict(), **combined_dict} + wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=args.run_name) + # keep track of model topology and gradients, unsupported on TPU + if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": + wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)) + + def on_train_begin(self, args, state, control, model=None, **kwargs): + if not self._initialized: + self.setup(args, state, model) + + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + if not self._initialized: + self.setup(args, state, model) + if state.is_world_process_zero: + wandb.log(logs, step=state.global_step) + + +class CometCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML + `__. + """ + + def __init__(self): + assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`." + self._initialized = False + + def setup(self, args, state, model): + """ + Setup the optional Comet.ml integration. + + Environment: + COMET_MODE (:obj:`str`, `optional`): + "OFFLINE", "ONLINE", or "DISABLED" + COMET_PROJECT_NAME (:obj:`str`, `optional`): + Comet.ml project name for experiments + COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`): + Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE" + + For a number of configurable items in the environment, + see `here `__. + """ + self._initialized = True + if state.is_world_process_zero: + comet_mode = os.getenv("COMET_MODE", "ONLINE").upper() + args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")} + experiment = None + if comet_mode == "ONLINE": + experiment = comet_ml.Experiment(**args) + logger.info("Automatic Comet.ml online logging enabled") + elif comet_mode == "OFFLINE": + args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./") + experiment = comet_ml.OfflineExperiment(**args) + logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished") + if experiment is not None: + experiment._set_model_graph(model, framework="transformers") + experiment._log_parameters(args, prefix="args/", framework="transformers") + if hasattr(model, "config"): + experiment._log_parameters(model.config, prefix="config/", framework="transformers") + + def on_train_begin(self, args, state, control, model=None, **kwargs): + if not self._initialized: + self.setup(args, state, model) + + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + if not self._initialized: + self.setup(args, state, model) + if state.is_world_process_zero: + experiment = comet_ml.config.get_global_experiment() + if experiment is not None: + experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers") diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 262ed6df59c1..b8e6e494b8f7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1,10 +1,26 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + import inspect -import math import os import re import shutil import warnings -from contextlib import contextmanager from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -15,8 +31,7 @@ from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataset import Dataset from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import RandomSampler, Sampler, SequentialSampler -from tqdm.auto import tqdm, trange +from torch.utils.data.sampler import RandomSampler, SequentialSampler from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator from .file_utils import WEIGHTS_NAME, is_datasets_available, is_torch_tpu_available @@ -34,23 +49,35 @@ from .modeling_utils import PreTrainedModel from .optimization import AdamW, get_linear_schedule_with_warmup from .tokenization_utils_base import PreTrainedTokenizerBase +from .trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from .trainer_pt_utils import ( + SequentialDistributedSampler, + distributed_broadcast_scalars, + distributed_concat, + get_tpu_sampler, + nested_concat, + nested_detach, + nested_numpify, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) from .trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalPrediction, - EvaluationStrategy, HPSearchBackend, PredictionOutput, - TrainerState, TrainOutput, default_compute_objective, default_hp_space, - distributed_broadcast_scalars, - distributed_concat, - nested_concat, - nested_detach, - nested_numpify, - nested_xla_mesh_reduce, set_seed, ) from .training_args import TrainingArguments @@ -60,7 +87,8 @@ _use_native_amp = False _use_apex = False -PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." +DEFAULT_CALLBACKS = [DefaultFlowCallback] + # Check if Pytorch version >= 1.6 to switch between Native AMP and Apex if version.parse(torch.__version__) < version.parse("1.6"): @@ -82,16 +110,20 @@ import torch_xla.distributed.parallel_loader as pl if is_tensorboard_available(): - try: - from torch.utils.tensorboard import SummaryWriter - except ImportError: - from tensorboardX import SummaryWriter + from .integrations import TensorBoardCallback + + DEFAULT_CALLBACKS.append(TensorBoardCallback) + if is_wandb_available(): - import wandb + from .integrations import WandbCallback + + DEFAULT_CALLBACKS.append(WandbCallback) if is_comet_available(): - import comet_ml + from .integrations import CometCallback + + DEFAULT_CALLBACKS.append(CometCallback) if is_optuna_available(): import optuna @@ -102,91 +134,20 @@ logger = logging.get_logger(__name__) -def reissue_pt_warnings(caught_warnings): - # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING - if len(caught_warnings) > 1: - for w in caught_warnings: - if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING: - warnings.warn(w.message, w.category) - - -@contextmanager -def torch_distributed_zero_first(local_rank: int): - """ - Decorator to make all processes in distributed training wait for each local_master to do something. - - Args: - local_rank (:obj:`int`): The rank of the local process. - """ - if local_rank not in [-1, 0]: - torch.distributed.barrier() - yield - if local_rank == 0: - torch.distributed.barrier() - - -class SequentialDistributedSampler(Sampler): - """ - Distributed Sampler that subsamples indicies sequentially, - making it easier to collate all results at the end. - - Even though we only use this sampler for eval and predict (no training), - which means that the model params won't have to be synced (i.e. will not hang - for synchronization even if varied number of forward passes), we still add extra - samples to the sampler to make it evenly divisible (like in `DistributedSampler`) - to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. - """ - - def __init__(self, dataset, num_replicas=None, rank=None): - if num_replicas is None: - if not torch.distributed.is_available(): - raise RuntimeError("Requires distributed package to be available") - num_replicas = torch.distributed.get_world_size() - if rank is None: - if not torch.distributed.is_available(): - raise RuntimeError("Requires distributed package to be available") - rank = torch.distributed.get_rank() - self.dataset = dataset - self.num_replicas = num_replicas - self.rank = rank - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) - self.total_size = self.num_samples * self.num_replicas - - def __iter__(self): - indices = list(range(len(self.dataset))) - - # add extra samples to make it evenly divisible - indices += indices[: (self.total_size - len(indices))] - assert ( - len(indices) == self.total_size - ), f"Indices length {len(indices)} and total size {self.total_size} mismatched" - - # subsample - indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] - assert ( - len(indices) == self.num_samples - ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" - - return iter(indices) - - def __len__(self): - return self.num_samples - - -def get_tpu_sampler(dataset: Dataset): - if xm.xrt_world_size() <= 1: - return RandomSampler(dataset) - return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) - - class Trainer: """ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. Args: - model (:class:`~transformers.PreTrainedModel`, `optional`): + model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`, `optional`): The model to train, evaluate or use for predictions. If not provided, a ``model_init`` must be passed. + + .. note:: + + :class:`~transformers.Trainer` is optimized to work with the :class:`~transformers.PreTrainedModel` + provided by the library. You can still use your own models defined as :obj:`torch.nn.Module` as long as + they work the same way as the 🤗 Transformers models. args (:class:`~transformers.TrainingArguments`, `optional`): The arguments to tweak for training. Will default to a basic instance of :class:`~transformers.TrainingArguments` with the ``output_dir`` set to a directory named `tmp_trainer` in the current directory if not provided. @@ -210,8 +171,11 @@ class Trainer: compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`): The function that will be used to compute metrics at evaluation. Must take a :class:`~transformers.EvalPrediction` and return a dictionary string to metric values. - tb_writer (:obj:`SummaryWriter`, `optional`): - Object to write to TensorBoard. + callbacks (List of :obj:`~transformers.TrainerCallback`, `optional`): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in :doc:`here `. + + If you want to remove one of the default callbacks used, use the :meth:`Trainer.remove_callback` method. optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR`, `optional`): A tuple containing the optimizer and the scheduler to use. Will default to an instance of :class:`~transformers.AdamW` on your model and a scheduler given by @@ -222,7 +186,7 @@ class Trainer: def __init__( self, - model: PreTrainedModel = None, + model: Union[PreTrainedModel, torch.nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, train_dataset: Optional[Dataset] = None, @@ -230,7 +194,7 @@ def __init__( tokenizer: Optional["PreTrainedTokenizerBase"] = None, model_init: Callable[[], PreTrainedModel] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, - tb_writer: Optional["SummaryWriter"] = None, + callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), **kwargs, ): @@ -259,7 +223,21 @@ def __init__( "Passing a `model_init` is incompatible with providing the `optimizers` argument." "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) - self.tb_writer = tb_writer + callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks + self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler) + self.add_callback(PrinterCallback if self.args.disable_tqdm else ProgressCallback) + + # Deprecated arguments + if "tb_writer" in kwargs: + warnings.warn( + "Passing `tb_writer` as a keyword argument is deprecated and won't be possible in a " + + "future version. Use `TensorBoardCallback(tb_writer=...)` instead and pass it to the `callbacks`" + + "argument", + FutureWarning, + ) + tb_writer = kwargs.pop("tb_writer") + self.remove_callback(TensorBoardCallback) + self.add_callback(TensorBoardCallback(tb_writer=tb_writer)) if "prediction_loss_only" in kwargs: warnings.warn( "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a " @@ -270,13 +248,6 @@ def __init__( self.args.prediction_loss_only = kwargs.pop("prediction_loss_only") assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." - if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero(): - self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir) - if not is_tensorboard_available(): - logger.warning( - "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it." - ) - # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. self._loggers_initialized = False @@ -304,6 +275,7 @@ def __init__( self._remove_unused_columns(self.eval_dataset, description="evaluation") self.state = TrainerState() + self.control = TrainerControl() # Internal variable for total_flos used to count as tensors (for distributed + TPU), will be sent in the # state at each call to self.log. self._total_flos = None @@ -317,6 +289,45 @@ def __init__( else ["labels"] ) self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + def add_callback(self, callback): + """ + Add a callback to the current list of :class:`~transformer.TrainerCallback`. + + Args: + callback (:obj:`type` or :class:`~transformer.TrainerCallback`): + A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. + In the first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of :class:`~transformer.TrainerCallback` and returns it. + + If the callback is not found, returns :obj:`None` (and no error is raised). + + Args: + callback (:obj:`type` or :class:`~transformer.TrainerCallback`): + A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. + In the first case, will pop the first member of that class found in the list of callbacks. + + Returns: + :class:`~transformer.TrainerCallback`: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of :class:`~transformer.TrainerCallback`. + + Args: + callback (:obj:`type` or :class:`~transformer.TrainerCallback`): + A :class:`~transformer.TrainerCallback` class or an instance of a :class:`~transformer.TrainerCallback`. + In the first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): if not self.args.remove_unused_columns: @@ -465,102 +476,12 @@ def create_optimizer_and_scheduler(self, num_training_steps: int): self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps ) - def setup_wandb(self): - """ - Setup the optional Weights & Biases (`wandb`) integration. - - One can subclass and override this method to customize the setup if needed. Find more information - `here `__. You can also override the following environment variables: - - Environment: - WANDB_WATCH: - (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging - or "all" to log gradients and parameters - WANDB_PROJECT: - (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project - WANDB_DISABLED: - (Optional): boolean - defaults to false, set to "true" to disable wandb entirely - """ - if hasattr(self, "_setup_wandb"): - warnings.warn( - "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.", - FutureWarning, - ) - return self._setup_wandb() - - if self.is_world_process_zero(): - logger.info( - 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' - ) - combined_dict = {**self.args.to_sanitized_dict()} - if isinstance(self.model, PreTrainedModel): - combined_dict = {**self.model.config.to_dict(), **combined_dict} - wandb.init( - project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name - ) - # keep track of model topology and gradients, unsupported on TPU - if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false": - wandb.watch( - self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps) - ) - - def setup_comet(self): - """ - Setup the optional Comet.ml integration. - - Environment: - COMET_MODE: - (Optional): str - "OFFLINE", "ONLINE", or "DISABLED" - COMET_PROJECT_NAME: - (Optional): str - Comet.ml project name for experiments - COMET_OFFLINE_DIRECTORY: - (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE" - - For a number of configurable items in the environment, - see `here `__ - """ - if self.is_world_master(): - comet_mode = os.getenv("COMET_MODE", "ONLINE").upper() - args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")} - experiment = None - if comet_mode == "ONLINE": - experiment = comet_ml.Experiment(**args) - logger.info("Automatic Comet.ml online logging enabled") - elif comet_mode == "OFFLINE": - args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./") - experiment = comet_ml.OfflineExperiment(**args) - logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished") - if experiment is not None: - experiment._set_model_graph(self.model, framework="transformers") - experiment._log_parameters(self.args, prefix="args/", framework="transformers") - if isinstance(self.model, PreTrainedModel): - experiment._log_parameters(self.model.config, prefix="config/", framework="transformers") - def num_examples(self, dataloader: DataLoader) -> int: """ Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset. """ return len(dataloader.dataset) - def _setup_loggers(self): - if self._loggers_initialized: - return - if is_wandb_available(): - self.setup_wandb() - elif os.environ.get("WANDB_DISABLED") != "true": - logger.info( - "You are instantiating a Trainer but W&B is not installed. To use wandb logging, " - "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface." - ) - if is_comet_available(): - self.setup_comet() - elif os.environ.get("COMET_MODE") != "DISABLED": - logger.info( - "To use comet_ml logging, run `pip/conda install comet_ml` " - "see https://www.comet.ml/docs/python-sdk/huggingface/" - ) - self._loggers_initialized = True - def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): """ HP search setup code """ if self.hp_search_backend is None or trial is None: @@ -661,7 +582,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt"))) reissue_pt_warnings(caught_warnings) - # Moxed precision training with apex (torch < 1.6) + # Mixed precision training with apex (torch < 1.6) model = self.model if self.args.fp16 and _use_apex: if not is_apex_available(): @@ -687,10 +608,6 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D # find_unused_parameters breaks checkpointing as per # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 - if self.tb_writer is not None: - self.tb_writer.add_text("args", self.args.to_json_string()) - self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={}) - # Train! if is_torch_tpu_available(): total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size() @@ -723,17 +640,25 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D logger.info(" Continuing training from global step %d", self.state.global_step) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader # This should be the same if the state has been saved but in case the training arguments changed, it's safer # to set this after the load. self.state.max_steps = max_steps self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() tr_loss = torch.tensor(0.0).to(self.args.device) + self._logging_loss_scalar = 0 self._total_flos = self.state.total_flos - logging_loss_scalar = 0.0 model.zero_grad() - disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero() - train_pbar = trange(epochs_trained, num_train_epochs, desc="Epoch", disable=disable_tqdm) + + self.control = self.callback_handler.on_train_begin(self.args, self.state, self.control) + for epoch in range(epochs_trained, num_train_epochs): if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): train_dataloader.sampler.set_epoch(epoch) @@ -750,15 +675,18 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D if self.args.past_index >= 0: self._past = None - epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm) + self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control) + for step, inputs in enumerate(epoch_iterator): # Skip past any already trained steps if resuming training if steps_trained_in_current_epoch > 0: steps_trained_in_current_epoch -= 1 - epoch_pbar.update(1) continue + if (step + 1) % self.args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(self.args, self.state, self.control) + tr_loss += self.training_step(model, inputs) self._total_flos += self.floating_point_ops(inputs) @@ -787,50 +715,15 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D model.zero_grad() self.state.global_step += 1 self.state.epoch = epoch + (step + 1) / len(epoch_iterator) + self.control = self.callback_handler.on_step_end(self.args, self.state, self.control) + + self._maybe_log_save_evalute(tr_loss, model, trial, epoch) - if (self.args.logging_steps > 0 and self.state.global_step % self.args.logging_steps == 0) or ( - self.state.global_step == 1 and self.args.logging_first_step - ): - logs: Dict[str, float] = {} - tr_loss_scalar = tr_loss.item() - logs["loss"] = (tr_loss_scalar - logging_loss_scalar) / self.args.logging_steps - # backward compatibility for pytorch schedulers - logs["learning_rate"] = ( - self.lr_scheduler.get_last_lr()[0] - if version.parse(torch.__version__) >= version.parse("1.4") - else self.lr_scheduler.get_lr()[0] - ) - logging_loss_scalar = tr_loss_scalar - - self.log(logs) - - if ( - self.args.evaluation_strategy == EvaluationStrategy.STEPS - and self.state.global_step % self.args.eval_steps == 0 - ): - metrics = self.evaluate() - self._report_to_hp_search(trial, epoch, metrics) - if self.args.load_best_model_at_end: - self._save_training(model, trial, metrics=metrics) - - if ( - not self.args.load_best_model_at_end - and self.args.save_steps > 0 - and self.state.global_step % self.args.save_steps == 0 - ): - self._save_training(model, trial) - - epoch_pbar.update(1) - if self.state.global_step >= max_steps: + if self.control.should_epoch_stop or self.control.should_training_stop: break - epoch_pbar.close() - train_pbar.update(1) - if self.args.evaluation_strategy == EvaluationStrategy.EPOCH: - metrics = self.evaluate() - self._report_to_hp_search(trial, epoch, metrics) - if self.args.load_best_model_at_end: - self._save_training(model, trial, metrics=metrics) + self.control = self.callback_handler.on_epoch_end(self.args, self.state, self.control) + self._maybe_log_save_evalute(tr_loss, model, trial, epoch) if self.args.tpu_metrics_debug or self.args.debug: if is_torch_tpu_available(): @@ -841,12 +734,9 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D "You enabled PyTorch/XLA debug metrics but you don't have a TPU " "configured. Check your training configuration if this is unexpected." ) - if self.state.global_step >= max_steps: + if self.control.should_training_stop: break - train_pbar.close() - if self.tb_writer: - self.tb_writer.close() if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of training delattr(self, "_past") @@ -863,9 +753,36 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)) self.model.load_state_dict(state_dict) + self.control = self.callback_handler.on_train_end(self.args, self.state, self.control) + return TrainOutput(self.state.global_step, tr_loss.item() / self.state.global_step) - def _save_training(self, model, trial, metrics=None): + def _maybe_log_save_evalute(self, tr_loss, model, trial, epoch): + if self.control.should_log: + logs: Dict[str, float] = {} + tr_loss_scalar = tr_loss.item() + logs["loss"] = (tr_loss_scalar - self._logging_loss_scalar) / self.args.logging_steps + # backward compatibility for pytorch schedulers + logs["learning_rate"] = ( + self.lr_scheduler.get_last_lr()[0] + if version.parse(torch.__version__) >= version.parse("1.4") + else self.lr_scheduler.get_lr()[0] + ) + self._logging_loss_scalar = tr_loss_scalar + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + metrics = self.evaluate() + self._report_to_hp_search(trial, epoch, metrics) + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _save_checkpoint(self, model, trial, metrics=None): # In all cases (even distributed/parallel), self.model is always a reference # to the model we want to save. if hasattr(model, "module"): @@ -896,7 +813,7 @@ def _save_training(self, model, trial, metrics=None): reissue_pt_warnings(caught_warnings) # Determine the new best metric / best model checkpoint - if metrics is not None: + if metrics is not None and self.args.metric_for_best_model is not None: metric_to_check = self.args.metric_for_best_model if not metric_to_check.startswith("eval_"): metric_to_check = f"eval_{metric_to_check}" @@ -998,7 +915,7 @@ def hyperparameter_search( self.hp_search_backend = None return best_run - def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: + def log(self, logs: Dict[str, float]) -> None: """ Log :obj:`logs` on the various objects watching training. @@ -1007,55 +924,22 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: Args: logs (:obj:`Dict[str, float]`): The values to log. - iterator (:obj:`tqdm`, `optional`): - A potential tqdm progress bar to write the logs on. """ - # Set up loggers like W&B or Comet ML - self._setup_loggers() - if hasattr(self, "_log"): warnings.warn( "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.", FutureWarning, ) - return self._log(logs, iterator=iterator) + return self._log(logs) if self.state.epoch is not None: logs["epoch"] = self.state.epoch if self._total_flos is not None: self.store_flos() logs["total_flos"] = self.state.total_flos - if self.tb_writer: - for k, v in logs.items(): - if isinstance(v, (int, float)): - self.tb_writer.add_scalar(k, v, self.state.global_step) - else: - logger.warning( - "Trainer is attempting to log a value of " - '"%s" of type %s for key "%s" as a scalar. ' - "This invocation of Tensorboard's writer.add_scalar() " - "is incorrect so we dropped this attribute.", - v, - type(v), - k, - ) - self.tb_writer.flush() - if is_wandb_available(): - if self.is_world_process_zero(): - wandb.log(logs, step=self.state.global_step) - if is_comet_available(): - if self.is_world_process_zero(): - experiment = comet_ml.config.get_global_experiment() - if experiment is not None: - experiment._log_metrics( - logs, step=self.state.global_step, epoch=self.state.epoch, framework="transformers" - ) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) output = {**logs, **{"step": self.state.global_step}} self.state.log_history.append(output) - if iterator is not None: - iterator.write(output) - else: - print(output) def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ @@ -1372,8 +1256,9 @@ def prediction_loop( if self.args.past_index >= 0: self._past = None - disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm - for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm): + self.callback_handler.eval_dataloader = dataloader + + for inputs in dataloader: loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) batch_size = inputs[list(inputs.keys())[0]].shape[0] if loss is not None: @@ -1382,6 +1267,7 @@ def prediction_loop( preds = logits if preds is None else nested_concat(preds, logits, dim=0) if labels is not None: label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0) + self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py new file mode 100644 index 000000000000..bc14d99f25e1 --- /dev/null +++ b/src/transformers/trainer_callback.py @@ -0,0 +1,468 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Callbacks to use with the Trainer class and customize the training loop. +""" + +import dataclasses +import json +from dataclasses import dataclass +from typing import Dict, List, Optional + +from tqdm.auto import tqdm + +from .trainer_utils import EvaluationStrategy +from .training_args import TrainingArguments +from .utils import logging + + +logger = logging.get_logger(__name__) + + +@dataclass +class TrainerState: + """ + A class containing the :class:`~transformers.Trainer` inner state that will be saved along the model and optimizer + when checkpointing and passed to the :class:`~transformers.TrainerCallback`. + + .. note:: + + In all this class, one step is to be understood as one update step. When using gradient accumulation, one + update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`, + then one update step requires going throuch `n` batches. + + Args: + epoch (:obj:`float`, `optional`): + Only set during training, will represent the epoch the training is at (the decimal part being the + percentage of the current epoch completed). + global_step (:obj:`int`, `optional`, defaults to 0): + During training, represents the number of update steps completed. + max_steps (:obj:`int`, `optional`, defaults to 0): + The number of update steps to do during the current training. + total_flos (:obj:`int`, `optional`, defaults to 0): + The total number of floating operations done by the model since the beginning of training. + log_history (:obj:`List[Dict[str, float]]`, `optional`): + The list of logs done since the beginning of training. + best_metric (:obj:`float`, `optional`): + When tracking the best model, the value of the best metric encountered so far. + best_model_checkpoint (:obj:`str`, `optional`): + When tracking the best model, the value of the name of the checkpoint for the best model encountered so + far. + is_local_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on + several machines) main process. + is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not this process is the global main process (when training in a distributed fashion on + several machines, this is only going to be :obj:`True` for one process). + """ + + epoch: Optional[float] = None + global_step: int = 0 + max_steps: int = 0 + num_train_epochs: int = 0 + total_flos: int = 0 + log_history: List[Dict[str, float]] = None + best_metric: Optional[float] = None + best_model_checkpoint: Optional[str] = None + is_local_process_zero: bool = True + is_world_process_zero: bool = True + + def __post_init__(self): + if self.log_history is None: + self.log_history = [] + + def save_to_json(self, json_path: str): + """ Save the content of this instance in JSON format inside :obj:`json_path`.""" + json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" + with open(json_path, "w", encoding="utf-8") as f: + f.write(json_string) + + @classmethod + def load_from_json(cls, json_path: str): + """ Create an instance from the content of :obj:`json_path`.""" + with open(json_path, "r", encoding="utf-8") as f: + text = f.read() + return cls(**json.loads(text)) + + +@dataclass +class TrainerControl: + """ + A class that handles the :class:`~transformers.Trainer` control flow. This class is used by the + :class:`~transformers.TrainerCallback` to activate some switches in the training loop. + + Args: + should_training_stop (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the training should be interrupted. + + If :obj:`True`, this variable will not be set back to :obj:`False`. The training will just stop. + should_epoch_stop (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the current epoch should be interrupted. + + If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next epoch. + should_save (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the model should be saved at this step. + + If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step. + should_evaluate (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the model should be evaluated at this step. + + If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step. + should_log (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the logs should be reported at this step. + + If :obj:`True`, this variable will be set back to :obj:`False` at the beginning of the next step. + """ + + should_training_stop: bool = False + should_epoch_stop: bool = False + should_save: bool = False + should_evaluate: bool = False + should_log: bool = False + + def _new_training(self): + """ Internal method that resets the variable for a new training. """ + self.should_training_stop = False + + def _new_epoch(self): + """ Internal method that resets the variable for a new epoch. """ + self.should_epoch_stop = False + + def _new_step(self): + """ Internal method that resets the variable for a new step. """ + self.should_save_model = False + self.should_evaluate = False + self.should_log = False + + +class TrainerCallback: + """ + A class for objects that will inspect the state of the training loop at some events and take some decisions. At + each of those events the following arguments are available: + + Args: + args (:class:`~transformers.TrainingArguments`): + The training arguments used to instantiate the :class:`~transformers.Trainer`. + state (:class:`~transformers.TrainerState`): + The current state of the :class:`~transformers.Trainer`. + control (:class:`~transformers.TrainerControl`): + The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions. + model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`): + The model being trained. + optimizer (:obj:`torch.optim.Optimizer`): + The optimizer used for the training steps. + lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`): + The scheduler used for setting the learning rate. + train_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`): + The current dataloader used for training. + eval_dataloader (:obj:`torch.utils.data.dataloader.DataLoader`, `optional`): + The current dataloader used for training. + metrics (:obj:`Dict[str, float]`): + The metrics computed by the last evaluation phase. + + Those are only accessible in the event :obj:`on_evaluate`. + logs (:obj:`Dict[str, float]`): + The values to log. + + Those are only accessible in the event :obj:`on_log`. + + The :obj:`control` object is the only one that can be changed by the callback, in which case the event that changes + it should return the modified version. + + The argument :obj:`args`, :obj:`state` and :obj:`control` are positionals for all events, all the others are + grouped in :obj:`kwargs`. You can unpack the ones you need in the signature of the event using them. As an example, + see the code of the simple :class:`~transformer.PrinterCallback`. + + Example:: + + class PrinterCallback(TrainerCallback): + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) + """ + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of the initialization of the :class:`~transformers.Trainer`. + """ + pass + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of training. + """ + pass + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of training. + """ + pass + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of an epoch. + """ + pass + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of an epoch. + """ + pass + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the beginning of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called at the end of a training step. If using gradient accumulation, one training step might take + several inputs. + """ + pass + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after an evaluation phase. + """ + pass + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a checkpoint save. + """ + pass + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after logging the last logs. + """ + pass + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + """ + Event called after a prediction step. + """ + pass + + +class CallbackHandler(TrainerCallback): + """ Internal class that just calls the list of callbacks in order. """ + + def __init__(self, callbacks, model, optimizer, lr_scheduler): + self.callbacks = [] + for cb in callbacks: + self.add_callback(cb) + self.model = model + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + self.train_dataloader = None + self.eval_dataloader = None + + if not any(isinstance(cb, DefaultFlowCallback) for cb in self.callbacks): + logger.warn( + "The Trainer will not work properly if you don't have a `DefaultFlowCallback` in its callbacks. You\n" + + "should add one before training with `trainer.add_callback(DefaultFlowCallback). The current list of" + + "callbacks is\n:" + + self.callback_list + ) + + def add_callback(self, callback): + cb = callback() if isinstance(callback, type) else callback + cb_class = callback if isinstance(callback, type) else callback.__class__ + if cb_class in [c.__class__ for c in self.callbacks]: + logger.warn( + f"You are adding a {cb_class} to the callbacks of this Trainer, but there is already one. The current" + + "list of callbacks is\n:" + + self.callback_list + ) + self.callbacks.append(cb) + + def pop_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + + def remove_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + self.callbacks.remove(callback) + + @property + def callback_list(self): + return "\n".join(self.callbacks) + + def on_init_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_init_end", args, state, control) + + def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_training_stop = False + return self.call_event("on_train_begin", args, state, control) + + def on_train_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_train_end", args, state, control) + + def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_epoch_stop = False + return self.call_event("on_epoch_begin", args, state, control) + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_epoch_end", args, state, control) + + def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_log = False + control.should_evaluate = False + control.should_save = False + return self.call_event("on_step_begin", args, state, control) + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_step_end", args, state, control) + + def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics): + control.should_evaluate = False + return self.call_event("on_evaluate", args, state, control, metrics=metrics) + + def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + control.should_save = False + return self.call_event("on_save", args, state, control) + + def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs): + control.should_log = False + return self.call_event("on_log", args, state, control, logs=logs) + + def on_prediction_step(self, args: TrainingArguments, state: TrainerState, control: TrainerControl): + return self.call_event("on_prediction_step", args, state, control) + + def call_event(self, event, args, state, control, **kwargs): + for callback in self.callbacks: + result = getattr(callback, event)( + args, + state, + control, + model=self.model, + optimizer=self.optimizer, + lr_scheduler=self.lr_scheduler, + train_dataloader=self.train_dataloader, + eval_dataloader=self.eval_dataloader, + **kwargs, + ) + # A Callback can skip the return of `control` if it doesn't change it. + if result is not None: + control = result + return control + + +class DefaultFlowCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation + and checkpoints. + """ + + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + # Log + if state.global_step == 1 and args.logging_first_step: + control.should_log = True + if args.logging_steps > 0 and state.global_step % args.logging_steps == 0: + control.should_log = True + + # Evaluate + if args.evaluation_strategy == EvaluationStrategy.STEPS and state.global_step % args.eval_steps == 0: + control.should_evaluate = True + if args.load_best_model_at_end: + control.should_save = True + + # Save + if not args.load_best_model_at_end and args.save_steps > 0 and state.global_step % args.save_steps == 0: + control.should_save = True + + # End training + if state.global_step >= state.max_steps: + control.should_training_stop = True + + return control + + def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): + if args.evaluation_strategy == EvaluationStrategy.EPOCH: + control.should_evaluate = True + if args.load_best_model_at_end: + control.should_save = True + return control + + +class ProgressCallback(TrainerCallback): + """ + A :class:`~transformers.TrainerCallback` that displays the progress of training or evaluation. + """ + + def __init__(self): + self.training_bar = None + self.prediction_bar = None + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_local_process_zero: + self.training_bar = tqdm(total=state.max_steps) + + def on_step_end(self, args, state, control, **kwargs): + if state.is_local_process_zero: + self.training_bar.update(1) + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if state.is_local_process_zero: + if self.prediction_bar is None: + self.prediction_bar = tqdm(total=len(eval_dataloader), leave=self.training_bar is None) + self.prediction_bar.update(1) + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_local_process_zero: + self.prediction_bar.close() + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_local_process_zero and self.training_bar is not None: + _ = logs.pop("total_flos", None) + self.training_bar.write(str(logs)) + + def on_train_end(self, args, state, control, **kwargs): + if state.is_local_process_zero: + self.training_bar.close() + self.training_bar = None + + +class PrinterCallback(TrainerCallback): + """ + A bare :class:`~transformers.TrainerCallback` that just prints the logs. + """ + + def on_log(self, args, state, control, logs=None, **kwargs): + _ = logs.pop("total_flos", None) + if state.is_local_process_zero: + print(logs) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py new file mode 100644 index 000000000000..74a93f8286b8 --- /dev/null +++ b/src/transformers/trainer_pt_utils.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Torch utilities for the Trainer class. +""" + +import math +import warnings +from contextlib import contextmanager +from typing import List, Optional, Union + +import torch +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler, Sampler + +from .file_utils import is_torch_tpu_available + + +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + +PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler." + + +def nested_concat(tensors, new_tensors, dim=0): + "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors." + assert type(tensors) == type( + new_tensors + ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors)) + return torch.cat((tensors, new_tensors), dim=dim) + + +def nested_numpify(tensors): + "Numpify `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_numpify(t) for t in tensors) + return tensors.cpu().numpy() + + +def nested_detach(tensors): + "Detach `tensors` (even if it's a nested list/tuple of tensors)." + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_detach(t) for t in tensors) + return tensors.detach() + + +def nested_xla_mesh_reduce(tensors, name): + if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + + if isinstance(tensors, (list, tuple)): + return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) + return xm.mesh_reduce(name, tensors, torch.cat) + else: + raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") + + +def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> torch.Tensor: + try: + if isinstance(tensor, (tuple, list)): + return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) + output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensor) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + + +def distributed_broadcast_scalars( + scalars: List[Union[int, float]], num_total_examples: Optional[int] = None +) -> torch.Tensor: + try: + tensorized_scalar = torch.tensor(scalars).cuda() + output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensorized_scalar) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + except AssertionError: + raise AssertionError("Not currently using distributed training") + + +def reissue_pt_warnings(caught_warnings): + # Reissue warnings that are not the PT_LR_SCHEDULER_WARNING + if len(caught_warnings) > 1: + for w in caught_warnings: + if w.category != UserWarning or w.message != PT_LR_SCHEDULER_WARNING: + warnings.warn(w.message, w.category) + + +@contextmanager +def torch_distributed_zero_first(local_rank: int): + """ + Decorator to make all processes in distributed training wait for each local_master to do something. + + Args: + local_rank (:obj:`int`): The rank of the local process. + """ + if local_rank not in [-1, 0]: + torch.distributed.barrier() + yield + if local_rank == 0: + torch.distributed.barrier() + + +class SequentialDistributedSampler(Sampler): + """ + Distributed Sampler that subsamples indicies sequentially, + making it easier to collate all results at the end. + + Even though we only use this sampler for eval and predict (no training), + which means that the model params won't have to be synced (i.e. will not hang + for synchronization even if varied number of forward passes), we still add extra + samples to the sampler to make it evenly divisible (like in `DistributedSampler`) + to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. + """ + + def __init__(self, dataset, num_replicas=None, rank=None): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = torch.distributed.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + indices = list(range(len(self.dataset))) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert ( + len(indices) == self.total_size + ), f"Indices length {len(indices)} and total size {self.total_size} mismatched" + + # subsample + indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] + assert ( + len(indices) == self.num_samples + ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" + + return iter(indices) + + def __len__(self): + return self.num_samples + + +def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset): + if xm.xrt_world_size() <= 1: + return RandomSampler(dataset) + return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index e816b0772a9c..96757a922e6a 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -1,19 +1,30 @@ -import dataclasses -import json +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow. +""" + import random -from dataclasses import dataclass -from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, NamedTuple, Optional, Tuple, Union import numpy as np -from .file_utils import is_tf_available, is_torch_available, is_torch_tpu_available +from .file_utils import is_tf_available, is_torch_available from .tokenization_utils_base import ExplicitEnum -if is_torch_available(): - import torch - - def set_seed(seed: int): """ Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` @@ -139,144 +150,3 @@ class HPSearchBackend(ExplicitEnum): HPSearchBackend.OPTUNA: default_hp_space_optuna, HPSearchBackend.RAY: default_hp_space_ray, } - - -def nested_concat(tensors, new_tensors, dim=0): - "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors." - if is_torch_available(): - assert type(tensors) == type( - new_tensors - ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}." - if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors)) - return torch.cat((tensors, new_tensors), dim=dim) - else: - raise ImportError("Torch must be installed to use `nested_concat`") - - -def nested_deatch(tensors): - "Detach `tensors` (even if it's a nested list/tuple of tensors)." - if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_detach(t) for t in tensors) - return tensors.detach() - - -def nested_numpify(tensors): - "Numpify `tensors` (even if it's a nested list/tuple of tensors)." - if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_numpify(t) for t in tensors) - return tensors.cpu().numpy() - - -def nested_detach(tensors): - "Detach `tensors` (even if it's a nested list/tuple of tensors)." - if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_detach(t) for t in tensors) - return tensors.detach() - - -def nested_xla_mesh_reduce(tensors, name): - if is_torch_tpu_available(): - import torch_xla.core.xla_model as xm - - if isinstance(tensors, (list, tuple)): - return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors)) - return xm.mesh_reduce(name, tensors, torch.cat) - else: - raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`") - - -def distributed_concat(tensor: "torch.Tensor", num_total_examples: Optional[int] = None) -> "torch.Tensor": - if is_torch_available(): - try: - if isinstance(tensor, (tuple, list)): - return type(tensor)(distributed_concat(t, num_total_examples) for t in tensor) - output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensor) - concat = torch.cat(output_tensors, dim=0) - - # truncate the dummy elements added by SequentialDistributedSampler - if num_total_examples is not None: - concat = concat[:num_total_examples] - return concat - except AssertionError: - raise AssertionError("Not currently using distributed training") - else: - raise ImportError("Torch must be installed to use `distributed_concat`") - - -def distributed_broadcast_scalars( - scalars: List[Union[int, float]], num_total_examples: Optional[int] = None -) -> "torch.Tensor": - if is_torch_available(): - try: - tensorized_scalar = torch.tensor(scalars).cuda() - output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensorized_scalar) - concat = torch.cat(output_tensors, dim=0) - - # truncate the dummy elements added by SequentialDistributedSampler - if num_total_examples is not None: - concat = concat[:num_total_examples] - return concat - except AssertionError: - raise AssertionError("Not currently using distributed training") - else: - raise ImportError("Torch must be installed to use `distributed_broadcast_scalars`") - - -@dataclass -class TrainerState: - """ - A class containing the `Trainer` inner state that will be saved along the model and optimizer. - - .. note:: - - In all this class, one step is to be understood as one update step. When using gradient accumulation, one - update step may require several forward and backward passes: if you use :obj:`gradient_accumulation_steps=n`, - then one update step requires going throuch `n` batches. - - Args: - epoch (:obj:`float`, `optional`): - Only set during training, will represent the epoch the training is at (the decimal part being the - percentage of the current epoch completed). - global_step (:obj:`int`, `optional`, defaults to 0): - During training, represents the number of update steps completed. - max_steps (:obj:`int`, `optional`, defaults to 0): - The number of update steps to do during the current training. - total_flos (:obj:`int`, `optional`, defaults to 0): - The total number of floating operations done by the model since the beginning of training. - log_history (:obj:`List[Dict[str, float]]`, `optional`): - The list of logs done since the beginning of training. - best_metric (:obj:`float`, `optional`): - When tracking the best model, the value of the best metric encountered so far. - best_model_checkpoint (:obj:`str`, `optional`): - When tracking the best model, the value of the name of the checkpoint for the best model encountered so - far. - """ - - epoch: Optional[float] = None - global_step: int = 0 - max_steps: int = 0 - num_train_epochs: int = 0 - total_flos: int = 0 - log_history: List[Dict[str, float]] = None - best_metric: Optional[float] = None - best_model_checkpoint: Optional[str] = None - - def __post_init__(self): - if self.log_history is None: - self.log_history = [] - - def save_to_json(self, json_path: str): - """ Save the content of this instance in JSON format inside :obj:`json_path`.""" - json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n" - with open(json_path, "w", encoding="utf-8") as f: - f.write(json_string) - - @classmethod - def load_from_json(cls, json_path: str): - """ Create an instance from the content of :obj:`json_path`.""" - with open(json_path, "r", encoding="utf-8") as f: - text = f.read() - return cls(**json.loads(text)) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 2d75e7518303..9359a9f17ced 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -54,7 +54,7 @@ class TrainingArguments: :obj:`"no"`. do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether to run predictions on the test set or not. - evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): + evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`): The evaluation strategy to adopt during training. Possible values are: * :obj:`"no"`: No evaluation is done during training. diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index ab854e5e7477..0521051a0a96 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1860,19 +1860,10 @@ def from_pretrained(self, *args, **kwargs): requires_pytorch(self) -class EvalPrediction: - def __init__(self, *args, **kwargs): - requires_pytorch(self) - - class Trainer: def __init__(self, *args, **kwargs): requires_pytorch(self) -def set_seed(*args, **kwargs): - requires_pytorch(set_seed) - - def torch_distributed_zero_first(*args, **kwargs): requires_pytorch(torch_distributed_zero_first) diff --git a/tests/test_trainer_callback.py b/tests/test_trainer_callback.py new file mode 100644 index 000000000000..0469c077fa75 --- /dev/null +++ b/tests/test_trainer_callback.py @@ -0,0 +1,214 @@ +import shutil +import tempfile +import unittest + +from transformers import ( + DefaultFlowCallback, + EvaluationStrategy, + PrinterCallback, + ProgressCallback, + Trainer, + TrainerCallback, + TrainingArguments, + is_torch_available, +) +from transformers.testing_utils import require_torch + + +if is_torch_available(): + from transformers.trainer import DEFAULT_CALLBACKS + + from .test_trainer import RegressionDataset, RegressionModelConfig, RegressionPreTrainedModel + + +class TestTrainerCallback(TrainerCallback): + "A callback that registers the events that goes through." + + def __init__(self): + self.events = [] + + def on_init_end(self, args, state, control, **kwargs): + self.events.append("on_init_end") + + def on_train_begin(self, args, state, control, **kwargs): + self.events.append("on_train_begin") + + def on_train_end(self, args, state, control, **kwargs): + self.events.append("on_train_end") + + def on_epoch_begin(self, args, state, control, **kwargs): + self.events.append("on_epoch_begin") + + def on_epoch_end(self, args, state, control, **kwargs): + self.events.append("on_epoch_end") + + def on_step_begin(self, args, state, control, **kwargs): + self.events.append("on_step_begin") + + def on_step_end(self, args, state, control, **kwargs): + self.events.append("on_step_end") + + def on_evaluate(self, args, state, control, **kwargs): + self.events.append("on_evaluate") + + def on_save(self, args, state, control, **kwargs): + self.events.append("on_save") + + def on_log(self, args, state, control, **kwargs): + self.events.append("on_log") + + def on_prediction_step(self, args, state, control, **kwargs): + self.events.append("on_prediction_step") + + +@require_torch +class TrainerCallbackTest(unittest.TestCase): + def setUp(self): + self.output_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.output_dir) + + def get_trainer(self, a=0, b=0, train_len=64, eval_len=64, callbacks=None, disable_tqdm=False, **kwargs): + # disable_tqdm in TrainingArguments has a flaky default since it depends on the level of logging. We make sure + # its set to False since the tests later on depend on its value. + train_dataset = RegressionDataset(length=train_len) + eval_dataset = RegressionDataset(length=eval_len) + config = RegressionModelConfig(a=a, b=b) + model = RegressionPreTrainedModel(config) + + args = TrainingArguments(self.output_dir, disable_tqdm=disable_tqdm, **kwargs) + return Trainer( + model, + args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + callbacks=callbacks, + ) + + def check_callbacks_equality(self, cbs1, cbs2): + self.assertEqual(len(cbs1), len(cbs2)) + + # Order doesn't matter + cbs1 = list(sorted(cbs1, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__)) + cbs2 = list(sorted(cbs2, key=lambda cb: cb.__name__ if isinstance(cb, type) else cb.__class__.__name__)) + + for cb1, cb2 in zip(cbs1, cbs2): + if isinstance(cb1, type) and isinstance(cb2, type): + self.assertEqual(cb1, cb2) + elif isinstance(cb1, type) and not isinstance(cb2, type): + self.assertEqual(cb1, cb2.__class__) + elif not isinstance(cb1, type) and isinstance(cb2, type): + self.assertEqual(cb1.__class__, cb2) + else: + self.assertEqual(cb1, cb2) + + def get_expected_events(self, trainer): + expected_events = ["on_init_end", "on_train_begin"] + step = 0 + train_dl_len = len(trainer.get_eval_dataloader()) + evaluation_events = ["on_prediction_step"] * len(trainer.get_eval_dataloader()) + ["on_log", "on_evaluate"] + for _ in range(trainer.state.num_train_epochs): + expected_events.append("on_epoch_begin") + for _ in range(train_dl_len): + step += 1 + expected_events += ["on_step_begin", "on_step_end"] + if step % trainer.args.logging_steps == 0: + expected_events.append("on_log") + if ( + trainer.args.evaluation_strategy == EvaluationStrategy.STEPS + and step % trainer.args.eval_steps == 0 + ): + expected_events += evaluation_events.copy() + if step % trainer.args.save_steps == 0: + expected_events.append("on_save") + expected_events.append("on_epoch_end") + if trainer.args.evaluation_strategy == EvaluationStrategy.EPOCH: + expected_events += evaluation_events.copy() + expected_events.append("on_train_end") + return expected_events + + def test_init_callback(self): + trainer = self.get_trainer() + expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback] + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + # Callbacks passed at init are added to the default callbacks + trainer = self.get_trainer(callbacks=[TestTrainerCallback]) + expected_callbacks.append(TestTrainerCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + # TrainingArguments.disable_tqdm controls if use ProgressCallback or PrinterCallback + trainer = self.get_trainer(disable_tqdm=True) + expected_callbacks = DEFAULT_CALLBACKS.copy() + [PrinterCallback] + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + def test_add_remove_callback(self): + expected_callbacks = DEFAULT_CALLBACKS.copy() + [ProgressCallback] + trainer = self.get_trainer() + + # We can add, pop, or remove by class name + trainer.remove_callback(DefaultFlowCallback) + expected_callbacks.remove(DefaultFlowCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + trainer = self.get_trainer() + cb = trainer.pop_callback(DefaultFlowCallback) + self.assertEqual(cb.__class__, DefaultFlowCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + trainer.add_callback(DefaultFlowCallback) + expected_callbacks.insert(0, DefaultFlowCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + # We can also add, pop, or remove by instance + trainer = self.get_trainer() + cb = trainer.callback_handler.callbacks[0] + trainer.remove_callback(cb) + expected_callbacks.remove(DefaultFlowCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + trainer = self.get_trainer() + cb1 = trainer.callback_handler.callbacks[0] + cb2 = trainer.pop_callback(cb1) + self.assertEqual(cb1, cb2) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + trainer.add_callback(cb1) + expected_callbacks.insert(0, DefaultFlowCallback) + self.check_callbacks_equality(trainer.callback_handler.callbacks, expected_callbacks) + + def test_event_flow(self): + trainer = self.get_trainer(callbacks=[TestTrainerCallback]) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + # Independent log/save/eval + trainer = self.get_trainer(callbacks=[TestTrainerCallback], logging_steps=5) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[TestTrainerCallback], save_steps=5) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[TestTrainerCallback], eval_steps=5, evaluation_strategy="steps") + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + trainer = self.get_trainer(callbacks=[TestTrainerCallback], evaluation_strategy="epoch") + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer)) + + # A bit of everything + trainer = self.get_trainer( + callbacks=[TestTrainerCallback], logging_steps=3, save_steps=10, eval_steps=5, evaluation_strategy="steps" + ) + trainer.train() + events = trainer.callback_handler.callbacks[-2].events + self.assertEqual(events, self.get_expected_events(trainer))