From ed6f35ab7ae8d97a9112c6214248b71bcff90bcd Mon Sep 17 00:00:00 2001 From: Jannis Born Date: Wed, 29 Jun 2022 12:57:08 +0200 Subject: [PATCH 1/3] doc: Unify training arg type annotations --- src/transformers/training_args.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 36448cfd54b9..8d858a1a3328 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -20,7 +20,7 @@ from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from .debug_utils import DebugOption from .trainer_utils import ( @@ -493,7 +493,7 @@ class TrainingArguments: do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) - evaluation_strategy: IntervalStrategy = field( + evaluation_strategy: Union[IntervalStrategy, str] = field( default="no", metadata={"help": "The evaluation strategy to use."}, ) @@ -559,7 +559,7 @@ class TrainingArguments: default=-1, metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, ) - lr_scheduler_type: SchedulerType = field( + lr_scheduler_type: Union[SchedulerType, str] = field( default="linear", metadata={"help": "The scheduler type to use."}, ) @@ -596,14 +596,14 @@ class TrainingArguments: }, ) logging_dir: Optional[str] = field(default=None, metadata={"help": "Tensorboard log dir."}) - logging_strategy: IntervalStrategy = field( + logging_strategy: Union[IntervalStrategy, str] = field( default="steps", metadata={"help": "The logging strategy to use."}, ) logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"}) logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."}) logging_nan_inf_filter: bool = field(default=True, metadata={"help": "Filter nan and inf losses for logging."}) - save_strategy: IntervalStrategy = field( + save_strategy: Union[IntervalStrategy, str] = field( default="steps", metadata={"help": "The checkpoint save strategy to use."}, ) @@ -815,7 +815,7 @@ class TrainingArguments: label_smoothing_factor: float = field( default=0.0, metadata={"help": "The label smoothing epsilon to apply (zero means no label smoothing)."} ) - optim: OptimizerNames = field( + optim: Union[OptimizerNames, str] = field( default="adamw_hf", metadata={"help": "The optimizer to use."}, ) @@ -868,7 +868,7 @@ class TrainingArguments: hub_model_id: Optional[str] = field( default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."} ) - hub_strategy: HubStrategy = field( + hub_strategy: Union[HubStrategy, str] = field( default="every_save", metadata={"help": "The hub strategy to use when `--push_to_hub` is activated."}, ) From adb7df4ccebbbb7648c14115c05d07a02cd9eb19 Mon Sep 17 00:00:00 2001 From: Jannis Born Date: Thu, 30 Jun 2022 00:23:20 +0200 Subject: [PATCH 2/3] wip: extracting enum type from Union --- src/transformers/hf_argparser.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 1316ff3ba993..9b5084401ea3 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -92,7 +92,14 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): " the argument parser only supports one type per argument." f" Problem encountered in field '{field.name}'." ) - if bool not in field.type.__args__: + if type(None) not in field.type.__args__: + # filter `str` in Union + field.type = ( + field.type.__args__[0] + if field.type.__args__[1] == str else field.type.__args__[1] + ) + origin_type = getattr(field.type, "__origin__", field.type) + elif bool not in field.type.__args__: # filter `NoneType` in Union (except for `Union[bool, NoneType]`) field.type = ( field.type.__args__[0] if isinstance(None, field.type.__args__[1]) else field.type.__args__[1] From b2ffc67216924304a0a61fa45976dafbac654385 Mon Sep 17 00:00:00 2001 From: Jannis Born Date: Thu, 30 Jun 2022 07:36:15 +0200 Subject: [PATCH 3/3] blackening --- src/transformers/hf_argparser.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index 9b5084401ea3..ac3245a29c89 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -94,10 +94,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): ) if type(None) not in field.type.__args__: # filter `str` in Union - field.type = ( - field.type.__args__[0] - if field.type.__args__[1] == str else field.type.__args__[1] - ) + field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1] origin_type = getattr(field.type, "__origin__", field.type) elif bool not in field.type.__args__: # filter `NoneType` in Union (except for `Union[bool, NoneType]`)