Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
" the argument parser only supports one type per argument."
f" Problem encountered in field '{field.name}'."
)
if bool not in field.type.__args__:
if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
field.type = (
field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1]
Expand Down
14 changes: 7 additions & 7 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from .debug_utils import DebugOption
from .trainer_utils import (
Expand Down Expand Up @@ -493,7 +493,7 @@ class TrainingArguments:
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluation_strategy: IntervalStrategy = field(
evaluation_strategy: Union[IntervalStrategy, str] = field(
default="no",
metadata={"help": "The evaluation strategy to use."},
)
Expand Down Expand Up @@ -559,7 +559,7 @@ class TrainingArguments:
default=-1,
metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."},
)
lr_scheduler_type: SchedulerType = field(
lr_scheduler_type: Union[SchedulerType, str] = field(
default="linear",
metadata={"help": "The scheduler type to use."},
)
Expand Down Expand Up @@ -596,14 +596,14 @@ class TrainingArguments:
},
)
logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."})
logging_strategy: IntervalStrategy = field(
logging_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The logging strategy to use."},
)
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."})
save_strategy: IntervalStrategy = field(
save_strategy: Union[IntervalStrategy, str] = field(
default="steps",
metadata={"help": "The checkpoint save strategy to use."},
)
Expand Down Expand Up @@ -815,7 +815,7 @@ class TrainingArguments:
label_smoothing_factor: float = field(
default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."}
)
optim: OptimizerNames = field(
optim: Union[OptimizerNames, str] = field(
default="adamw_hf",
metadata={"help": "The optimizer to use."},
)
Expand Down Expand Up @@ -868,7 +868,7 @@ class TrainingArguments:
hub_model_id: Optional[str] = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_strategy: HubStrategy = field(
hub_strategy: Union[HubStrategy, str] = field(
default="every_save",
metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."},
)
Expand Down