diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index de9b36a445c2..67a796e713e6 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -474,8 +474,6 @@ class TrainingArguments: The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, adamw_anyprecision or adafactor. optim_args (`str`, *optional*): Optional arguments that are supplied to AnyPrecisionAdamW. - adafactor (`bool`, *optional*, defaults to `False`): - This argument is deprecated. Use `--optim adafactor` instead. group_by_length (`bool`, *optional*, defaults to `False`): Whether or not to group together samples of roughly the same length in the training dataset (to minimize padding applied and be more efficient). Only useful if applying dynamic padding. @@ -1905,6 +1903,524 @@ def to_sanitized_dict(self) -> Dict[str, Any]: return {k: v if type(v) in valid_types else str(v) for k, v in d.items()} + # The following methods are there to simplify the instantiation of `TrainingArguments` + def set_training( + self, + learning_rate: float = 5e-5, + batch_size: int = 8, + weight_decay: float = 0, + num_epochs: float = 3, + max_steps: int = -1, + gradient_accumulation_steps: int = 1, + seed: int = 42, + gradient_checkpointing: bool = False, + ): + """ + A method that regroups all basic arguments linked to the training. + + + + Calling this method will automatically set `self.do_train` to `True`. + + + + Args: + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate for the optimizer. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for training. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the + optimizer. + num_train_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides + `num_train_epochs`. In case of using a finite iterable dataset the training may stop before reaching + the set number of steps when all data is exhausted. + gradient_accumulation_steps (`int`, *optional*, defaults to 1): + Number of updates steps to accumulate the gradients for, before performing a backward/update pass. + + + + When using gradient accumulation, one step is counted as one step with backward pass. Therefore, + logging, evaluation, save will be conducted every `gradient_accumulation_steps * xxx_step` training + examples. + + + + seed (`int`, *optional*, defaults to 42): + Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use + the [`~Trainer.model_init`] function to instantiate the model if it has some randomly initialized + parameters. + gradient_checkpointing (`bool`, *optional*, defaults to `False`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_training(learning_rate=1e-4, batch_size=32) + >>> args.learning_rate + 1e-4 + ``` + """ + self.do_train = True + self.learning_rate = learning_rate + self.per_device_train_batch_size = batch_size + self.weight_decay = weight_decay + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.gradient_accumulation_steps = gradient_accumulation_steps + self.seed = seed + self.gradient_checkpointing = gradient_checkpointing + return self + + def set_evaluate( + self, + strategy: Union[str, IntervalStrategy] = "no", + steps: int = 500, + batch_size: int = 8, + accumulation_steps: Optional[int] = None, + delay: Optional[float] = None, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all arguments linked to the evaluation. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"no"`): + The evaluation strategy to adopt during training. Possible values are: + + - `"no"`: No evaluation is done during training. + - `"steps"`: Evaluation is done (and logged) every `steps`. + - `"epoch"`: Evaluation is done at the end of each epoch. + + Setting a `strategy` different from `"no"` will set `self.do_eval` to `True`. + steps (`int`, *optional*, defaults to 500): + Number of update steps between two evaluations if `strategy="steps"`. + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for evaluation. + accumulation_steps (`int`, *optional*): + Number of predictions steps to accumulate the output tensors for, before moving the results to the CPU. + If left unset, the whole predictions are accumulated on GPU/TPU before being moved to the CPU (faster + but requires more memory). + delay (`float`, *optional*): + Number of epochs or steps to wait for before the first evaluation can be performed, depending on the + evaluation_strategy. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_evaluate(strategy="steps", steps=100) + >>> args.eval_steps + 100 + ``` + """ + self.evaluation_strategy = IntervalStrategy(strategy) + if self.evaluation_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.do_eval = self.evaluation_strategy != IntervalStrategy.NO + self.eval_steps = steps + self.per_device_eval_batch_size = batch_size + self.eval_accumulation_steps = accumulation_steps + self.eval_delay = delay + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_testing( + self, + batch_size: int = 8, + loss_only: bool = False, + jit_mode: bool = False, + ): + """ + A method that regroups all basic arguments linked to testing on a held-out dataset. + + + + Calling this method will automatically set `self.do_predict` to `True`. + + + + Args: + batch_size (`int` *optional*, defaults to 8): + The batch size per device (GPU/TPU core/CPU...) used for testing. + loss_only (`bool`, *optional*, defaults to `False`): + Ignores all outputs except the loss. + jit_mode (`bool`, *optional*): + Whether or not to use PyTorch jit trace for inference. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_testing(batch_size=32) + >>> args.per_device_eval_batch_size + 32 + ``` + """ + self.do_predict = True + self.per_device_eval_batch_size = batch_size + self.prediction_loss_only = loss_only + self.jit_mode_eval = jit_mode + return self + + def set_save( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + total_limit: Optional[int] = None, + on_each_node: bool = False, + ): + """ + A method that regroups all arguments linked to the evaluation. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The checkpoint save strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of updates steps before two checkpoint saves if `strategy="steps"`. + total_limit (`int`, *optional*): + If a value is passed, will limit the total amount of checkpoints. Deletes the older checkpoints in + `output_dir`. + on_each_node (`bool`, *optional*, defaults to `False`): + When doing multi-node distributed training, whether to save models and checkpoints on each node, or + only on the main one. + + This should not be activated when the different nodes use the same storage as the files will be saved + with the same names for each node. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_save(strategy="steps", steps=100) + >>> args.save_steps + 100 + ``` + """ + self.save_strategy = IntervalStrategy(strategy) + if self.save_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.save_steps = steps + self.save_total_limit = total_limit + self.save_on_each_node = on_each_node + return self + + def set_logging( + self, + strategy: Union[str, IntervalStrategy] = "steps", + steps: int = 500, + report_to: Union[str, List[str]] = "none", + level: str = "passive", + first_step: bool = False, + nan_inf_filter: bool = False, + on_each_node: bool = False, + replica_level: str = "passive", + ): + """ + A method that regroups all arguments linked to the evaluation. + + Args: + strategy (`str` or [`~trainer_utils.IntervalStrategy`], *optional*, defaults to `"steps"`): + The logging strategy to adopt during training. Possible values are: + + - `"no"`: No save is done during training. + - `"epoch"`: Save is done at the end of each epoch. + - `"steps"`: Save is done every `save_steps`. + + steps (`int`, *optional*, defaults to 500): + Number of update steps between two logs if `strategy="steps"`. + level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on the main process. Possible choices are the log levels as strings: `"debug"`, + `"info"`, `"warning"`, `"error"` and `"critical"`, plus a `"passive"` level which doesn't set anything + and lets the application set the level. + report_to (`str` or `List[str]`, *optional*, defaults to `"none"`): + The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`, + `"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. Use `"all"` to report + to all integrations installed, `"none"` for no integrations. + first_step (`bool`, *optional*, defaults to `False`): + Whether to log and evaluate the first `global_step` or not. + nan_inf_filter (`bool`, *optional*, defaults to `True`): + Whether to filter `nan` and `inf` losses for logging. If set to `True` the loss of every step that is + `nan` or `inf` is filtered and the average loss of the current logging window is taken instead. + + + + `nan_inf_filter` only influences the logging of loss values, it does not change the behavior the + gradient is computed or applied to the model. + + + + on_each_node (`bool`, *optional*, defaults to `True`): + In multinode distributed training, whether to log using `log_level` once per node, or only on the main + node. + replica_level (`str`, *optional*, defaults to `"passive"`): + Logger log level to use on replicas. Same choices as `log_level` + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_logging(strategy="steps", steps=100) + >>> args.logging_steps + 100 + ``` + """ + self.logging_strategy = IntervalStrategy(strategy) + if self.logging_strategy == IntervalStrategy.STEPS and steps == 0: + raise ValueError("Setting `strategy` as 'steps' requires a positive value for `steps`.") + self.logging_steps = steps + self.report_to = report_to + self.log_level = level + self.logging_first_step = first_step + self.logging_nan_inf_filter = nan_inf_filter + self.log_on_each_node = on_each_node + self.log_level_replica = replica_level + return self + + def set_push_to_hub( + self, + model_id: str, + strategy: Union[str, HubStrategy] = "every_save", + token: Optional[str] = None, + private_repo: bool = False, + ): + """ + A method that regroups all arguments linked to synchronizing checkpoints with the Hub. + + + + Calling this method will set `self.push_to_hub` to `True`, which means the `output_dir` will begin a git + directory synced with the repo (determined by `model_id`) and the content will be pushed each time a save is + triggered (depending on`self.save_strategy`). Calling [`~Trainer.save_model`] will also trigger a push. + + + + Args: + model_id (`str`): + The name of the repository to keep in sync with the local *output_dir*. It can be a simple model ID in + which case the model will be pushed in your namespace. Otherwise it should be the whole repository + name, for instance `"user_name/model"`, which allows you to push to an organization you are a member of + with `"organization_name/model"`. + strategy (`str` or [`~trainer_utils.HubStrategy`], *optional*, defaults to `"every_save"`): + Defines the scope of what is pushed to the Hub and when. Possible values are: + + - `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a + draft of a model card when the [`~Trainer.save_model`] method is called. + - `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) + and + a draft of a model card each time there is a model save. The pushes are asynchronous to not block + training, and in case the save are very frequent, a new push is only attempted if the previous one is + finished. A last push is made with the final model at the end of training. + - `"checkpoint"`: like `"every_save"` but the latest checkpoint is also pushed in a subfolder named + last-checkpoint, allowing you to resume training easily with + `trainer.train(resume_from_checkpoint="last-checkpoint")`. + - `"all_checkpoints"`: like `"checkpoint"` but all checkpoints are pushed like they appear in the + output + folder (so you will get one checkpoint folder per folder in your final repository) + + token (`str`, *optional*): + The token to use to push the model to the Hub. Will default to the token in the cache folder obtained + with `huggingface-cli login`. + private_repo (`bool`, *optional*, defaults to `False`): + If True, the Hub repo will be set to private. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_push_to_hub("me/awesome-model") + >>> args.hub_model_id + 'me/awesome-model' + ``` + """ + self.push_to_hub = True + self.hub_model_id = model_id + self.hub_strategy = HubStrategy(strategy) + self.hub_token = token + self.hub_private_repo = private_repo + return self + + def set_optimizer( + self, + name: Union[str, OptimizerNames] = "adamw_hf", + learning_rate: float = 5e-5, + weight_decay: float = 0, + beta1: float = 0.9, + beta2: float = 0.999, + epsilon: float = 1e-8, + args: Optional[str] = None, + ): + """ + A method that regroups all arguments linked to the optimizer and its hyperparameters. + + Args: + name (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`): + The optimizer to use: `"adamw_hf"`, `"adamw_torch"`, `"adamw_apex_fused"`, `"adamw_anyprecision"` or + `"adafactor"`. + learning_rate (`float`, *optional*, defaults to 5e-5): + The initial learning rate. + weight_decay (`float`, *optional*, defaults to 0): + The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights. + beta1 (`float`, *optional*, defaults to 0.9): + The beta1 hyperparameter for the adam optimizer or its variants. + beta2 (`float`, *optional*, defaults to 0.999): + The beta2 hyperparameter for the adam optimizer or its variants. + epsilon (`float`, *optional*, defaults to 1e-8): + The epsilon hyperparameter for the adam optimizer or its variants. + args (`str`, *optional*): + Optional arguments that are supplied to AnyPrecisionAdamW (only useful when + `optim="adamw_anyprecision"`). + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_optimizer(name="adamw_torch", beta1=0.8) + >>> args.optim + 'adamw_torch' + ``` + """ + self.optim = OptimizerNames(name) + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.adam_beta1 = beta1 + self.adam_beta2 = beta2 + self.adam_epsilon = epsilon + self.optim_args = args + return self + + def set_lr_scheduler( + self, + name: Union[str, SchedulerType] = "linear", + num_epochs: float = 3.0, + max_steps: int = -1, + warmup_ratio: float = 0, + warmup_steps: int = 0, + ): + """ + A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters. + + Args: + name (`str` or [`SchedulerType`], *optional*, defaults to `"linear"`): + The scheduler type to use. See the documentation of [`SchedulerType`] for all possible values. + num_epochs(`float`, *optional*, defaults to 3.0): + Total number of training epochs to perform (if not an integer, will perform the decimal part percents + of the last epoch before stopping training). + max_steps (`int`, *optional*, defaults to -1): + If set to a positive number, the total number of training steps to perform. Overrides + `num_train_epochs`. In case of using a finite iterable dataset the training may stop before reaching + the set number of steps when all data is exhausted. + warmup_ratio (`float`, *optional*, defaults to 0.0): + Ratio of total training steps used for a linear warmup from 0 to `learning_rate`. + warmup_steps (`int`, *optional*, defaults to 0): + Number of steps used for a linear warmup from 0 to `learning_rate`. Overrides any effect of + `warmup_ratio`. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05) + >>> args.warmup_ratio + 0.05 + ``` + """ + self.lr_scheduler_type = SchedulerType(name) + self.num_train_epochs = num_epochs + self.max_steps = max_steps + self.warmup_ratio = warmup_ratio + self.warmup_steps = warmup_steps + return self + + def set_dataloader( + self, + train_batch_size: int = 8, + eval_batch_size: int = 8, + drop_last: bool = False, + num_workers: int = 0, + pin_memory: bool = True, + auto_find_batch_size: bool = False, + ignore_data_skip: bool = False, + sampler_seed: Optional[int] = None, + ): + """ + A method that regroups all arguments linked to the dataloaders creation. + + Args: + drop_last (`bool`, *optional*, defaults to `False`): + Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch + size) or not. + num_workers (`int`, *optional*, defaults to 0): + Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in + the main process. + pin_memory (`bool`, *optional*, defaults to `True`): + Whether you want to pin memory in data loaders or not. Will default to `True`. + auto_find_batch_size (`bool`, *optional*, defaults to `False`) + Whether to find a batch size that will fit into memory automatically through exponential decay, + avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`) + ignore_data_skip (`bool`, *optional*, defaults to `False`): + When resuming training, whether or not to skip the epochs and batches to get the data loading at the + same stage as in the previous training. If set to `True`, the training will begin faster (as that + skipping step can take a long time) but will not yield the same results as the interrupted training + would have. + sampler_seed (`int`, *optional*): + Random seed to be used with data samplers. If not set, random generators for data sampling will use the + same seed as `self.seed`. This can be used to ensure reproducibility of data sampling, independent of + the model seed. + + Example: + + ```py + >>> from transformers import TrainingArguments + + >>> args = TrainingArguments("working_dir") + >>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64) + >>> args.per_device_train_batch_size + 16 + ``` + """ + self.per_device_train_batch_size = train_batch_size + self.per_device_eval_batch_size = eval_batch_size + self.dataloader_drop_last = drop_last + self.dataloader_num_workers = num_workers + self.dataloader_pin_memory = pin_memory + self.auto_find_batch_size = auto_find_batch_size + self.ignore_data_skip = ignore_data_skip + self.data_seed = sampler_seed + return self + class ParallelMode(Enum): NOT_PARALLEL = "not_parallel"