diff --git a/src/transformers/hf_argparser.py b/src/transformers/hf_argparser.py index d2b9f22eb004..1316ff3ba993 100644 --- a/src/transformers/hf_argparser.py +++ b/src/transformers/hf_argparser.py @@ -84,7 +84,9 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field): origin_type = getattr(field.type, "__origin__", field.type) if origin_type is Union: - if len(field.type.__args__) != 2 or type(None) not in field.type.__args__: + if str not in field.type.__args__ and ( + len(field.type.__args__) != 2 or type(None) not in field.type.__args__ + ): raise ValueError( "Only `Union[X, NoneType]` (i.e., `Optional[X]`) is allowed for `Union` because" " the argument parser only supports one type per argument." diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index 136762a37858..aa6e586c075a 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -240,7 +240,7 @@ def to_tuple(self) -> Tuple[Any]: return tuple(self[k] for k in self.keys()) -class ExplicitEnum(Enum): +class ExplicitEnum(str, Enum): """ Enum with more explicit error message for missing values. """