diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 631fe0b3ec57..4eb9b274e1d0 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -582,7 +582,7 @@ class TrainingArguments:
     )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
-    data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+    data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
     bf16: bool = field(
         default=False,
         metadata={
@@ -616,14 +616,14 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
     )
-    tf32: bool = field(
+    tf32: Optional[bool] = field(
         default=None,
         metadata={
             "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
         },
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
-    xpu_backend: str = field(
+    xpu_backend: Optional[str] = field(
         default=None,
         metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
     )
@@ -648,7 +648,7 @@ class TrainingArguments:
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )
-    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
     dataloader_num_workers: int = field(
         default=0,
         metadata={
@@ -770,14 +770,14 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
-    hub_model_id: str = field(
+    hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_strategy: HubStrategy = field(
         default="every_save",
         metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
     )
-    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
     hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
     gradient_checkpointing: bool = field(
         default=False,
@@ -793,13 +793,15 @@ class TrainingArguments:
         default="auto",
         metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
     )
-    push_to_hub_model_id: str = field(
+    push_to_hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
     )
-    push_to_hub_organization: str = field(
+    push_to_hub_organization: Optional[str] = field(
         default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
     )
-    push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    push_to_hub_token: Optional[str] = field(
+        default=None, metadata={"help": "The token to use to push to the Model Hub."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
diff --git a/src/transformers/training_args_tf.py b/src/transformers/training_args_tf.py
index 4f3c41e2cab2..4b35b66b07e2 100644
--- a/src/transformers/training_args_tf.py
+++ b/src/transformers/training_args_tf.py
@@ -14,7 +14,7 @@
 
 import warnings
 from dataclasses import dataclass, field
-from typing import Tuple
+from typing import Optional, Tuple
 
 from .training_args import TrainingArguments
 from .utils import cached_property, is_tf_available, logging, tf_required
@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
             Whether to activate the XLA compilation or not.
     """
 
-    tpu_name: str = field(
+    tpu_name: Optional[str] = field(
         default=None,
         metadata={"help": "Name of TPU"},
     )
 
-    tpu_zone: str = field(
+    tpu_zone: Optional[str] = field(
         default=None,
         metadata={"help": "Zone of TPU"},
     )
 
-    gcp_project: str = field(
+    gcp_project: Optional[str] = field(
         default=None,
         metadata={"help": "Name of Cloud TPU-enabled project"},
     )
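The pattern applied throughout the diff is purely about annotations: every field whose default is `None` is now declared as `Optional[...]`, so the declared type matches the actual default value. Below is a minimal sketch of how such a dataclass is consumed; the `ExampleArguments` class is hypothetical (not part of the diff) and only illustrates the annotation style with transformers' `HfArgumentParser`.

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ExampleArguments:
    # A bare `str`/`int` annotation combined with `default=None` mis-declares
    # the field's type; `Optional[...]` matches the real default.
    hub_token: Optional[str] = field(
        default=None, metadata={"help": "The token to use to push to the Model Hub."}
    )
    eval_steps: Optional[int] = field(
        default=None, metadata={"help": "Run an evaluation every X steps."}
    )


parser = HfArgumentParser(ExampleArguments)
# With no command-line flags passed, both fields keep their None defaults.
(args,) = parser.parse_args_into_dataclasses(args=[])
print(args.hub_token, args.eval_steps)  # None None
```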