diff --git a/docs/source/en/internal/trainer_utils.mdx b/docs/source/en/internal/trainer_utils.mdx
index 054bd69b440c..bba182d5ab64 100644
--- a/docs/source/en/internal/trainer_utils.mdx
+++ b/docs/source/en/internal/trainer_utils.mdx
@@ -22,6 +22,8 @@ Most of those are only useful if you are studying the code of the Trainer in the
 
 [[autodoc]] IntervalStrategy
 
+[[autodoc]] enable_full_determinism
+
 [[autodoc]] set_seed
 
 [[autodoc]] torch_distributed_zero_first
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 248415be95bd..61a2538844a9 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -372,7 +372,7 @@
         "TrainerControl",
         "TrainerState",
     ],
-    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "set_seed"],
+    "trainer_utils": ["EvalPrediction", "IntervalStrategy", "SchedulerType", "enable_full_determinism", "set_seed"],
     "training_args": ["TrainingArguments"],
     "training_args_seq2seq": ["Seq2SeqTrainingArguments"],
     "training_args_tf": ["TFTrainingArguments"],
@@ -2809,7 +2809,7 @@
         TrainerControl,
         TrainerState,
     )
-    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, set_seed
+    from .trainer_utils import EvalPrediction, IntervalStrategy, SchedulerType, enable_full_determinism, set_seed
     from .training_args import TrainingArguments
     from .training_args_seq2seq import Seq2SeqTrainingArguments
     from .training_args_tf import TFTrainingArguments
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index dda278471811..a4db18dfc186 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -115,10 +115,12 @@
     default_compute_objective,
     default_hp_space,
     denumpify_detensorize,
+    enable_full_determinism,
     find_executable_batch_size,
     get_last_checkpoint,
     has_length,
     number_of_arguments,
+    seed_worker,
     set_seed,
     speed_metrics,
 )
@@ -300,7 +302,7 @@ def __init__(
             args = TrainingArguments(output_dir=output_dir)
         self.args = args
         # Seed must be set before instantiating the model when using model
-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
         self.hp_name = None
         self.deepspeed = None
         self.is_in_train = False
@@ -746,6 +748,7 @@ def get_train_dataloader(self) -> DataLoader:
             drop_last=self.args.dataloader_drop_last,
             num_workers=self.args.dataloader_num_workers,
             pin_memory=self.args.dataloader_pin_memory,
+            worker_init_fn=seed_worker,
         )
 
     def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
@@ -1250,7 +1253,7 @@ def train(
         model_reloaded = False
         if self.model_init is not None:
             # Seed must be set before instantiating the model when using model_init.
-            set_seed(args.seed)
+            enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
             self.model = self.call_model_init(trial)
             model_reloaded = True
             # Reinitializes optimizer and scheduler
diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py
index 71c2e691d2a7..737dd4deaf68 100644
--- a/src/transformers/trainer_tf.py
+++ b/src/transformers/trainer_tf.py
@@ -34,7 +34,14 @@
 from .modeling_tf_utils import TFPreTrainedModel
 from .optimization_tf import GradientAccumulator, create_optimizer
-from .trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, IntervalStrategy, PredictionOutput, set_seed
+from .trainer_utils import (
+    PREFIX_CHECKPOINT_DIR,
+    EvalPrediction,
+    IntervalStrategy,
+    PredictionOutput,
+    enable_full_determinism,
+    set_seed,
+)
 from .training_args_tf import TFTrainingArguments
 from .utils import logging
 
 
@@ -134,7 +141,7 @@ def __init__(
                 "see https://www.comet.ml/docs/python-sdk/huggingface/"
             )
 
-        set_seed(self.args.seed)
+        enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
 
     def get_train_tfdataset(self) -> tf.data.Dataset:
         """
diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py
index 62cab858b7e1..d74d0aed9fc6 100644
--- a/src/transformers/trainer_utils.py
+++ b/src/transformers/trainer_utils.py
@@ -47,6 +47,39 @@
     import tensorflow as tf
 
 
+def seed_worker(_):
+    """
+    Helper function to set worker seed during Dataloader initialization.
+    """
+    worker_seed = torch.initial_seed() % 2**32
+    set_seed(worker_seed)
+
+
+def enable_full_determinism(seed: int):
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for pytorch
+    - https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism for tensorflow
+    """
+    # set seed first
+    set_seed(seed)
+
+    if is_torch_available():
+        # Enable PyTorch deterministic mode. This potentially requires either the environment
+        # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+        # depending on the CUDA version, so we set them both here
+        os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+        torch.use_deterministic_algorithms(True)
+
+        # Enable CUDNN deterministic mode
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    if is_tf_available():
+        tf.config.experimental.enable_op_determinism()
+
+
 def set_seed(seed: int):
     """
     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch` and/or `tf` (if installed).
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 4eb9b274e1d0..d3f381e7d8de 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -446,6 +446,9 @@ class TrainingArguments:
         auto_find_batch_size (`bool`, *optional*, defaults to `False`)
             Whether to find a batch size that will fit into memory automatically through exponential decay, avoiding
             CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
+        full_determinism (`bool`, *optional*, defaults to `False`)
+            If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
+            distributed training
     """
 
     output_dir: str = field(
@@ -814,6 +817,12 @@ class TrainingArguments:
             "help": "Whether to automatically decrease the batch size in half and rerun the training loop again each time a CUDA Out-of-Memory was reached"
         },
     )
+    full_determinism: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to call enable_full_determinism instead of set_seed for reproducibility in distributed training"
+        },
+    )
 
     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
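For reference, a minimal usage sketch of the behavior introduced by this diff (not part of the patch itself; the output directory and seed value below are arbitrary placeholders):

# Hypothetical usage of the full-determinism support added above.
from transformers import TrainingArguments, enable_full_determinism

# Option 1: let the Trainer switch seeding behavior via the new argument.
# With full_determinism=True the Trainer calls enable_full_determinism(seed)
# instead of set_seed(seed) before the model is instantiated.
training_args = TrainingArguments(
    output_dir="./outputs",  # placeholder path
    seed=42,
    full_determinism=True,
)

# Option 2: call the helper directly in a custom script. It seeds the Python,
# NumPy, PyTorch and TensorFlow RNGs via set_seed, enables
# torch.use_deterministic_algorithms and the cuDNN deterministic mode, sets
# CUDA_LAUNCH_BLOCKING / CUBLAS_WORKSPACE_CONFIG, and turns on TensorFlow op
# determinism when TensorFlow is available.
enable_full_determinism(42)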