From 1f01db8b303e647b102f48c36e91ddb17784414f Mon Sep 17 00:00:00 2001 From: Daniel Stancl <46073029+stancld@users.noreply.github.com> Date: Fri, 30 Jul 2021 18:36:55 +0200 Subject: [PATCH] Fix mypy in utilities.argparse (#8124) Co-authored-by: tchaton Co-authored-by: Carlos Mocholi --- pyproject.toml | 1 + pytorch_lightning/utilities/argparse.py | 45 ++++++++++++++++--------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 874781367ddd0..d0d9aaf383ebe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ module = [ "pytorch_lightning.callbacks.pruning", "pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.connectors.logger_connector", + "pytorch_lightning.utilities.argparse", "pytorch_lightning.utilities.cli", "pytorch_lightning.utilities.device_dtype_mixin", "pytorch_lightning.utilities.device_parser", diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 655fc52a4bdfb..bed2461395c98 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -13,14 +13,27 @@ # limitations under the License. import inspect import os +from abc import ABC from argparse import _ArgumentGroup, ArgumentParser, Namespace from contextlib import suppress -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, Type, Union +import pytorch_lightning as pl from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str -def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): +class ParseArgparserDataType(ABC): + def __init__(self, *_: Any, **__: Any) -> None: + pass + + @classmethod + def parse_argparser(cls, args: "ArgumentParser") -> Any: + pass + + +def from_argparse_args( + cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any +) -> ParseArgparserDataType: """Create an instance from CLI arguments. Eventually use varibles from OS environement which are defined as "PL__" @@ -52,7 +65,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): return cls(**trainer_kwargs) -def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: +def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace: """Parse CLI arguments, required for custom bool types.""" args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser @@ -77,7 +90,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp return Namespace(**modified_args) -def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: +def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace: """Parse environment arguments if they are defined. Example: @@ -106,7 +119,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") return Namespace(**env_args) -def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: +def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]: r"""Scans the class signature and returns argument names, types and default values. Returns: @@ -134,7 +147,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: return name_type_default -def _get_abbrev_qualified_cls_name(cls): +def _get_abbrev_qualified_cls_name(cls: Any) -> str: assert isinstance(cls, type), repr(cls) if cls.__module__.startswith("pytorch_lightning."): # Abbreviate. @@ -143,7 +156,9 @@ def _get_abbrev_qualified_cls_name(cls): return f"{cls.__module__}.{cls.__qualname__}" -def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=True) -> ArgumentParser: +def add_argparse_args( + cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True +) -> Union[_ArgumentGroup, ArgumentParser]: r"""Extends existing argparse by default attributes for ``cls``. Args: @@ -187,7 +202,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group= raise RuntimeError("Please only pass an ArgumentParser instance.") if use_argument_group: group_name = _get_abbrev_qualified_cls_name(cls) - parser = parent_parser.add_argument_group(group_name) + parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name) else: parser = ArgumentParser(parents=[parent_parser], add_help=False) @@ -207,16 +222,16 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group= args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "") for arg, arg_types, arg_default in args_and_types: - arg_types = [at for at in allowed_types if at in arg_types] + arg_types = tuple(at for at in allowed_types if at in arg_types) if not arg_types: # skip argument with not supported type continue - arg_kwargs = {} + arg_kwargs: Dict[str, Any] = {} if bool in arg_types: arg_kwargs.update(nargs="?", const=True) # if the only arg type is bool if len(arg_types) == 1: - use_type = str_to_bool + use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool elif int in arg_types: use_type = str_to_bool_or_int elif str in arg_types: @@ -249,7 +264,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group= def _parse_args_from_docstring(docstring: str) -> Dict[str, str]: arg_block_indent = None - current_arg = None + current_arg = "" parsed = {} for line in docstring.split("\n"): stripped = line.lstrip() @@ -270,20 +285,20 @@ def _parse_args_from_docstring(docstring: str) -> Dict[str, str]: return parsed -def _gpus_allowed_type(x) -> Union[int, str]: +def _gpus_allowed_type(x: str) -> Union[int, str]: if "," in x: return str(x) return int(x) -def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover +def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover # unused, but here for backward compatibility with old checkpoints that need to be able to # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 pass -def _int_or_float_type(x) -> Union[int, float]: +def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]: if "." in str(x): return float(x) return int(x)