Skip to content

Commit

Permalink
Fix mypy in utilities.argparse (#8124)
Browse files Browse the repository at this point in the history
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
3 people authored Jul 30, 2021
1 parent 16392a7 commit 1f01db8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
Expand Down
45 changes: 30 additions & 15 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,27 @@
# limitations under the License.
import inspect
import os
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
class ParseArgparserDataType(ABC):
def __init__(self, *_: Any, **__: Any) -> None:
pass

@classmethod
def parse_argparser(cls, args: "ArgumentParser") -> Any:
pass


def from_argparse_args(
cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any
) -> ParseArgparserDataType:
"""Create an instance from CLI arguments.
Eventually use varibles from OS environement which are defined as "PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"
Expand Down Expand Up @@ -52,7 +65,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
return cls(**trainer_kwargs)


def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

Expand All @@ -77,7 +90,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
return Namespace(**modified_args)


def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.
Example:
Expand Down Expand Up @@ -106,7 +119,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
return Namespace(**env_args)


def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the class signature and returns argument names, types and default values.
Returns:
Expand Down Expand Up @@ -134,7 +147,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
return name_type_default


def _get_abbrev_qualified_cls_name(cls):
def _get_abbrev_qualified_cls_name(cls: Any) -> str:
assert isinstance(cls, type), repr(cls)
if cls.__module__.startswith("pytorch_lightning."):
# Abbreviate.
Expand All @@ -143,7 +156,9 @@ def _get_abbrev_qualified_cls_name(cls):
return f"{cls.__module__}.{cls.__qualname__}"


def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=True) -> ArgumentParser:
def add_argparse_args(
cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
) -> Union[_ArgumentGroup, ArgumentParser]:
r"""Extends existing argparse by default attributes for ``cls``.
Args:
Expand Down Expand Up @@ -187,7 +202,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=
raise RuntimeError("Please only pass an ArgumentParser instance.")
if use_argument_group:
group_name = _get_abbrev_qualified_cls_name(cls)
parser = parent_parser.add_argument_group(group_name)
parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
else:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

Expand All @@ -207,16 +222,16 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=
args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

for arg, arg_types, arg_default in args_and_types:
arg_types = [at for at in allowed_types if at in arg_types]
arg_types = tuple(at for at in allowed_types if at in arg_types)
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
arg_kwargs: Dict[str, Any] = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = str_to_bool
use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool
elif int in arg_types:
use_type = str_to_bool_or_int
elif str in arg_types:
Expand Down Expand Up @@ -249,7 +264,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=

def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
arg_block_indent = None
current_arg = None
current_arg = ""
parsed = {}
for line in docstring.split("\n"):
stripped = line.lstrip()
Expand All @@ -270,20 +285,20 @@ def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
return parsed


def _gpus_allowed_type(x) -> Union[int, str]:
def _gpus_allowed_type(x: str) -> Union[int, str]:
if "," in x:
return str(x)
return int(x)


def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover
def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover
# unused, but here for backward compatibility with old checkpoints that need to be able to
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
pass


def _int_or_float_type(x) -> Union[int, float]:
def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
if "." in str(x):
return float(x)
return int(x)

0 comments on commit 1f01db8

Please sign in to comment.