Skip to content

Commit

Permalink
Add required for positional arguments in argparse logic (#12504)
Browse files Browse the repository at this point in the history
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
fschlatt and awaelchli authored Apr 22, 2022
1 parent 5b511da commit f4f70a8
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Marked `swa_lrs` argument in `StochasticWeightAveraging` callback as required ([#12556](https://github.com/PyTorchLightning/pytorch-lightning/pull/12556))


-
- Make positional arguments required for classes passed into the `add_argparse_args` function. ([#1250](https://github.com/PyTorchLightning/pytorch-lightning/pull/12504))


-
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,13 @@ def add_argparse_args(
use_type = _precision_allowed_type

parser.add_argument(
f"--{arg}", dest=arg, default=arg_default, type=use_type, help=args_help.get(arg), **arg_kwargs
f"--{arg}",
dest=arg,
default=arg_default,
type=use_type,
help=args_help.get(arg),
required=(arg_default == inspect._empty),
**arg_kwargs,
)

if use_argument_group:
Expand Down
22 changes: 21 additions & 1 deletion tests/utilities/test_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@


class ArgparseExample:
def __init__(self, a: int = 0, b: str = "", c: bool = False):
def __init__(self, a: int, b: str = "", c: bool = False):
self.a = a
self.b = b
self.c = c
Expand Down Expand Up @@ -147,6 +147,16 @@ def __init__(self, invalid_class: SomeClass):
pass


class AddArgparseArgsExampleClassNoDefault:
"""
Args:
my_parameter: A thing.
"""

def __init__(self, my_parameter: int):
pass


def extract_help_text(parser):
help_str_buffer = io.StringIO()
parser.print_help(file=help_str_buffer)
Expand All @@ -160,6 +170,7 @@ def extract_help_text(parser):
[AddArgparseArgsExampleClass, "AddArgparseArgsExampleClass"],
[AddArgparseArgsExampleClassViaInit, "AddArgparseArgsExampleClassViaInit"],
[AddArgparseArgsExampleClassNoDoc, "AddArgparseArgsExampleClassNoDoc"],
[AddArgparseArgsExampleClassNoDefault, "AddArgparseArgsExampleClassNoDefault"],
],
)
def test_add_argparse_args(cls, name):
Expand All @@ -185,6 +196,15 @@ def test_add_argparse_args(cls, name):
assert args.main_arg == "abc"
assert args.my_parameter == 2

fake_argv = ["--main_arg=abc"]
if cls is AddArgparseArgsExampleClassNoDefault:
with pytest.raises(SystemExit):
parser.parse_args(fake_argv)
else:
args = parser.parse_args(fake_argv)
assert args.main_arg == "abc"
assert args.my_parameter == 0


def test_negative_add_argparse_args():
with pytest.raises(RuntimeError, match="Please only pass an ArgumentParser instance."):
Expand Down

0 comments on commit f4f70a8

Please sign in to comment.