src/transformers/training_args.py (20 changes: 11 additions & 9 deletions)

@@ -582,7 +582,7 @@ class TrainingArguments:
     )
     no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
     seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
-    data_seed: int = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
+    data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
     bf16: bool = field(
         default=False,
         metadata={
@@ -616,14 +616,14 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
     )
-    tf32: bool = field(
+    tf32: Optional[bool] = field(
         default=None,
         metadata={
             "help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
         },
     )
     local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
-    xpu_backend: str = field(
+    xpu_backend: Optional[str] = field(
         default=None,
         metadata={"help": "The backend to be used for distributed training on Intel XPU.", "choices": ["mpi", "ccl"]},
     )
@@ -648,7 +648,7 @@ class TrainingArguments:
     dataloader_drop_last: bool = field(
         default=False, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."}
     )
-    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
+    eval_steps: Optional[int] = field(default=None, metadata={"help": "Run an evaluation every X steps."})
     dataloader_num_workers: int = field(
         default=0,
         metadata={
@@ -770,14 +770,14 @@ class TrainingArguments:
         default=None,
         metadata={"help": "The path to a folder with a valid checkpoint for your model."},
     )
-    hub_model_id: str = field(
+    hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
     )
     hub_strategy: HubStrategy = field(
         default="every_save",
         metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
     )
-    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    hub_token: Optional[str] = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
     hub_private_repo: bool = field(default=False, metadata={"help": "Whether the model repository is private or not."})
     gradient_checkpointing: bool = field(
         default=False,
@@ -793,13 +793,15 @@ class TrainingArguments:
         default="auto",
         metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
     )
-    push_to_hub_model_id: str = field(
+    push_to_hub_model_id: Optional[str] = field(
         default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
     )
-    push_to_hub_organization: str = field(
+    push_to_hub_organization: Optional[str] = field(
         default=None, metadata={"help": "The name of the organization in with to which push the `Trainer`."}
     )
-    push_to_hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
+    push_to_hub_token: Optional[str] = field(
+        default=None, metadata={"help": "The token to use to push to the Model Hub."}
+    )
     _n_gpu: int = field(init=False, repr=False, default=-1)
     mp_parameters: str = field(
         default="",
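Every hunk above applies the same fix: a field whose default is `None` was annotated with a bare type (`int`, `str`, `bool`), which PEP 484 disallows because `None` is not a member of those types; the annotation has to be `Optional[...]`. A minimal sketch of the before/after pattern, using a hypothetical `ExampleArguments` dataclass rather than the real `TrainingArguments`:

```python
# Minimal sketch of the pattern this PR applies; ExampleArguments is a
# hypothetical stand-in for TrainingArguments.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ExampleArguments:
    # Before: `data_seed: int = field(default=None, ...)`.
    # Strict type checkers (e.g. mypy with --no-implicit-optional) reject
    # that, because None is not a valid int.
    # After: the annotation admits both None (unset) and a concrete value.
    data_seed: Optional[int] = field(
        default=None, metadata={"help": "Random seed to be used with data samplers."}
    )


args = ExampleArguments()
assert args.data_seed is None  # None is now a legal, correctly typed default
```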
src/transformers/training_args_tf.py (8 changes: 4 additions & 4 deletions)

@@ -14,7 +14,7 @@

 import warnings
 from dataclasses import dataclass, field
-from typing import Tuple
+from typing import Optional, Tuple

 from .training_args import TrainingArguments
 from .utils import cached_property, is_tf_available, logging, tf_required
@@ -161,17 +161,17 @@ class TFTrainingArguments(TrainingArguments):
             Whether to activate the XLA compilation or not.
     """

-    tpu_name: str = field(
+    tpu_name: Optional[str] = field(
        default=None,
        metadata={"help": "Name of TPU"},
    )

-    tpu_zone: str = field(
+    tpu_zone: Optional[str] = field(
        default=None,
        metadata={"help": "Zone of TPU"},
    )

-    gcp_project: str = field(
+    gcp_project: Optional[str] = field(
        default=None,
        metadata={"help": "Name of Cloud TPU-enabled project"},
    )
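The same reasoning applies to the TF-specific arguments. Beyond satisfying type checkers, the `Optional[...]` annotation is what `HfArgumentParser` inspects when it turns these dataclasses into command-line flags. A short sketch under stated assumptions: `DemoArguments` is a hypothetical stand-in for the real argument classes, while `HfArgumentParser` and `parse_args_into_dataclasses` are the library's actual API.

```python
# Sketch of how HfArgumentParser consumes an Optional[...] field: the flag is
# optional on the command line, and the value stays None when it is omitted.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DemoArguments:  # hypothetical, for illustration only
    tpu_name: Optional[str] = field(default=None, metadata={"help": "Name of TPU"})


parser = HfArgumentParser(DemoArguments)

(args,) = parser.parse_args_into_dataclasses(args=["--tpu_name", "my-tpu"])
print(args.tpu_name)  # "my-tpu"

(args,) = parser.parse_args_into_dataclasses(args=[])
print(args.tpu_name)  # None -- the flag was omitted, the default survives
```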