Custom argparser extension with Trainer arguments (argument types added) #1147

Merged
Changes from all commits.

Commits (59):
ea6b8f1  `add_argparse_args` method fixed (argument types added)  (Mar 14, 2020)
28e9348  CHANGELOG.md upd  (Mar 14, 2020)
6c290dd  autopep8 fixes  (Mar 14, 2020)
4444701  --gpus=0 removed from test (for ci tests)  (Mar 14, 2020)
0ec101f  typo fixed  (Mar 14, 2020)
bc7dd6f  reduce on plateau scheduler fixed  (Mar 14, 2020)
57a22ea  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
2d220d5  refactored: get_init_arguments_and_types is a public classmethod of t…  (Mar 16, 2020)
5b5042e  test_get_init_arguments_and_types added  (Mar 16, 2020)
dd81a39  autopep8 fixes  (Mar 16, 2020)
b4d9605  Merge remote-tracking branch 'upstream/master'  (Mar 16, 2020)
977b2de  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
7a9ba50  refactored: get_init_arguments_and_types is a public classmethod of t…  (Mar 16, 2020)
551ff24  test_get_init_arguments_and_types added  (Mar 16, 2020)
d3c1fdc  autopep8 fixes  (Mar 16, 2020)
1f92d92  Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…  (Mar 16, 2020)
2b27e3c  merged  (Mar 17, 2020)
f501ce6  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
4bccfd7  refactored: get_init_arguments_and_types is a public classmethod of t…  (Mar 16, 2020)
38dcba7  test_get_init_arguments_and_types added  (Mar 16, 2020)
e1abf4d  autopep8 fixes  (Mar 16, 2020)
ea3334a  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
f79dcbf  test_get_init_arguments_and_types added  (Mar 16, 2020)
299574d  autopep8 fixes  (Mar 16, 2020)
4782fe6  Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…  (Mar 17, 2020)
20c4073  Apply suggestions from code review  (Borda, Mar 17, 2020)
201d855  cosmetics  (Mar 17, 2020)
567d9b8  Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…  (Mar 17, 2020)
0c292b4  cosmetics  (Mar 17, 2020)
f45fa6f  Update pytorch_lightning/trainer/trainer.py  (alexeykarnachev, Mar 17, 2020)
eb23637  `Trainer.get_init_arguments_and_types` now returns arg types wrapped …  (Mar 17, 2020)
5f4a5fe  deprecated args are now ignored in argparser  (Mar 17, 2020)
f48f17e  get_deprecated_arg_names small refactor  (Mar 17, 2020)
2207c4e  get_deprecated_arg_names bug fixed  (Mar 17, 2020)
0a73a34  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
0cd9aaa  refactored: get_init_arguments_and_types is a public classmethod of t…  (Mar 16, 2020)
25bd99d  test_get_init_arguments_and_types added  (Mar 16, 2020)
5253a50  autopep8 fixes  (Mar 16, 2020)
3fd9a99  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
36ca20b  autopep8 fixes  (Mar 16, 2020)
f7cc4e1  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
ce837fe  autopep8 fixes  (Mar 16, 2020)
8aa062a  Trainer cli related tests moved to test_trainer_cli.py  (Mar 16, 2020)
d447aa3  test_get_init_arguments_and_types added  (Mar 16, 2020)
aaada85  autopep8 fixes  (Mar 16, 2020)
87b6cbb  Apply suggestions from code review  (Borda, Mar 17, 2020)
a9f996b  cosmetics  (Mar 17, 2020)
a4702ff  cosmetics  (Mar 17, 2020)
162f3ab  Update pytorch_lightning/trainer/trainer.py  (alexeykarnachev, Mar 17, 2020)
a8b21d9  `Trainer.get_init_arguments_and_types` now returns arg types wrapped …  (Mar 17, 2020)
2a5d46d  deprecated args are now ignored in argparser  (Mar 17, 2020)
3b7c865  get_deprecated_arg_names small refactor  (Mar 17, 2020)
d07870d  get_deprecated_arg_names bug fixed  (Mar 17, 2020)
012d425  Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…  (Mar 17, 2020)
9449b98  Merge branch 'master' into trainer-argparser-types  (Borda, Mar 18, 2020)
94c5e09  Update pytorch_lightning/trainer/trainer.py  (alexeykarnachev, Mar 20, 2020)
4f8314d  Update pytorch_lightning/trainer/trainer.py  (alexeykarnachev, Mar 20, 2020)
77847b7  Merge branch 'master' into trainer-argparser-types  (Borda, Mar 20, 2020)
63d30fd  Merge branch 'master' into trainer-argparser-types  (williamFalcon, Mar 24, 2020)
Files changed:
CHANGELOG.md (3 additions, 1 deletion)
@@ -29,7 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))

- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
- Fixed a bug to ensure lightning checkpoints are backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))

@@ -13,6 +13,7 @@ class UNet(nn.Module):
bilinear (bool) - Whether to use bilinear interpolation or transposed
convolutions for upsampling.
'''

def __init__(self, num_classes=19, bilinear=False):
super().__init__()
self.layer1 = DoubleConv(3, 64)
@@ -8,6 +8,7 @@ class DoubleConv(nn.Module):
Double Convolution and BN and ReLU
(3x3 conv -> BN -> ReLU) ** 2
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
@@ -27,6 +28,7 @@ class Down(nn.Module):
'''
Combination of MaxPool2d and DoubleConv in series
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
@@ -44,6 +46,7 @@ class Up(nn.Module):
followed by concatenation of feature map from contracting path,
followed by double 3x3 convolution.
'''

def __init__(self, in_ch, out_ch, bilinear=False):
super().__init__()
self.upsample = None
pl_examples/full_examples/semantic_segmentation/semseg.py (2 additions)
@@ -34,6 +34,7 @@ class KITTI(Dataset):
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
'''

def __init__(
self,
root_path,
@@ -120,6 +121,7 @@ class SegModel(pl.LightningModule):

Adam optimizer is used along with Cosine Annealing learning rate scheduler.
'''

def __init__(self, hparams):
super(SegModel, self).__init__()
self.root_path = hparams.root
pytorch_lightning/trainer/trainer.py (92 additions, 20 deletions)
@@ -3,33 +3,30 @@
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
import distutils

import torch
from torch import optim
import torch.distributed as torch_distrib
import torch.multiprocessing as mp
from torch import optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.profiler.profiler import BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (
TrainerDPMixin,
parse_gpu_ids,
determine_root_gpu_device
)
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -70,6 +67,11 @@ class Trainer(
TrainerCallbackHookMixin,
TrainerDeprecatedAPITillVer0_8,
):
DEPRECATED_IN_0_8 = (
'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
'add_row_log_interval', 'nb_sanity_val_steps'
)
DEPRECATED_IN_0_9 = ('use_amp',)

def __init__(
self,
@@ -466,21 +468,91 @@ def default_attributes(cls):

return args

@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.

Returns:
List with tuples of 3 values:
(argument name, tuple with argument types, argument default value).

Examples:
>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[('accumulate_grad_batches',
(<class 'int'>, typing.Dict[int, int], typing.List[list]),
1),
...
('callbacks', (<class 'pytorch_lightning.callbacks.base.Callback'>,), []),
('check_val_every_n_epoch', (<class 'int'>,), 1),
...
('max_epochs', (<class 'int'>,), 1000),
...
('precision', (<class 'int'>,), 32),
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
<class 'NoneType'>),
None),
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
for arg in trainer_default_params:
arg_type = trainer_default_params[arg].annotation
arg_default = trainer_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)

name_type_default.append((arg, arg_types, arg_default))

return name_type_default

@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names

 @classmethod
 def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-    """Extend existing argparse by default `Trainer` attributes."""
-    parser = ArgumentParser(parents=[parent_parser], add_help=False)
-
-    trainer_default_params = Trainer.default_attributes()
+    r"""Extends existing argparse by default `Trainer` attributes.
+
+    Args:
+        parent_parser:
+            The custom cli arguments parser, which will be extended by
+            the Trainer default arguments.
+
+    Only arguments of the allowed types (str, float, int, bool) will
+    extend the `parent_parser`.
+    """
+    parser = ArgumentParser(parents=[parent_parser], add_help=False)
+
+    depr_arg_names = cls.get_deprecated_arg_names()
+
+    allowed_types = (str, float, int, bool)
     # TODO: get "help" from docstring :)
-    for arg in trainer_default_params:
-        parser.add_argument(
-            f'--{arg}',
-            default=trainer_default_params[arg],
-            dest=arg,
-            help='autogenerated by pl.Trainer'
-        )
+    for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
+        if arg not in depr_arg_names:
+            for allowed_type in allowed_types:
+                if allowed_type in arg_types:
+                    if allowed_type is bool:
+                        allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                    parser.add_argument(
+                        f'--{arg}',
+                        default=arg_default,
+                        type=allowed_type,
+                        dest=arg,
+                        help='autogenerated by pl.Trainer'
+                    )
+                    break

     return parser

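The introspection in `get_init_arguments_and_types` above relies on a general trick: `typing.Union` and `typing.Optional` annotations expose their member types through `__args__`, while plain classes do not, hence the try/except. A minimal standalone sketch of the same technique (the `toy_trainer` function and its names are illustrative, not part of the PR):

import inspect
from typing import Any, List, Optional, Tuple


def init_arguments_and_types(obj) -> List[Tuple[str, Tuple, Any]]:
    """Collect (name, types, default) for every parameter in obj's signature."""
    out = []
    for name, param in inspect.signature(obj).parameters.items():
        try:
            # Union/Optional annotations carry their member types in __args__.
            types = tuple(param.annotation.__args__)
        except AttributeError:
            # Plain annotations (int, bool, ...) become a one-element tuple.
            types = (param.annotation,)
        out.append((name, types, param.default))
    return out


def toy_trainer(max_epochs: int = 1000, gpus: Optional[int] = None):
    pass


print(init_arguments_and_types(toy_trainer))
# [('max_epochs', (<class 'int'>,), 1000),
#  ('gpus', (<class 'int'>, <class 'NoneType'>), None)]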
tests/trainer/test_trainer.py (2 additions, 22 deletions)
@@ -1,8 +1,7 @@
import glob
import math
import os
from argparse import ArgumentParser, Namespace
from unittest import mock
from argparse import Namespace

import pytest
import torch
@@ -251,6 +250,7 @@ def test_dp_output_reduce():

def test_model_checkpoint_options(tmpdir):
"""Test ModelCheckpoint options."""

def mock_save_function(filepath):
open(filepath, 'a').close()

@@ -624,23 +624,3 @@ def test_epoch_end(self, outputs):

model = LightningTestModel(hparams)
Trainer().test(model)


@mock.patch('argparse.ArgumentParser.parse_args',
return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
"""Tests default argument parser for Trainer"""
tutils.reset_seed()

# logger file to get meta
logger = tutils.get_test_tube_logger(tmpdir, False)

parser = ArgumentParser(add_help=False)
args = parser.parse_args()
args.logger = logger

args.max_epochs = 5
trainer = Trainer.from_argparse_args(args)

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5
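The `test_default_args` test removed above reappears in the new `tests/trainer/test_trainer_cli.py` file below. For orientation, the user-facing pattern these tests exercise looks roughly like this; a minimal sketch assuming a pytorch_lightning version that includes this PR (the script name and the printed attribute are illustrative):

# train.py (hypothetical user script)
from argparse import ArgumentParser

import pytorch_lightning as pl

parser = ArgumentParser(description='my training script')
# Extend the script's own parser with Trainer's non-deprecated
# arguments of the allowed types (str, float, int, bool).
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

# With this PR, values arrive typed: `--max_epochs=5` parses to int(5),
# not the string '5', so Trainer can use the arguments directly.
trainer = pl.Trainer.from_argparse_args(args)
print(trainer.max_epochs)

Run as, e.g., `python train.py --max_epochs=5 --print_nan_grads=1`; both values arrive with their proper types.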
tests/trainer/test_trainer_cli.py (new file, 85 additions)
@@ -0,0 +1,85 @@
import inspect
from argparse import ArgumentParser, Namespace
from unittest import mock

import pytest

import tests.models.utils as tutils
from pytorch_lightning import Trainer


@mock.patch('argparse.ArgumentParser.parse_args',
return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
"""Tests default argument parser for Trainer"""
tutils.reset_seed()

# logger file to get meta
logger = tutils.get_test_tube_logger(tmpdir, False)

parser = ArgumentParser(add_help=False)
args = parser.parse_args()
args.logger = logger

args.max_epochs = 5
trainer = Trainer.from_argparse_args(args)

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5


@pytest.mark.parametrize('cli_args', [
['--accumulate_grad_batches=22'],
['--print_nan_grads=1', '--weights_save_path=./'],
[]
])
def test_add_argparse_args_redefined(cli_args):
"""Redefines some default Trainer arguments via the cli and
tests the Trainer initialization correctness.
"""
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)

args = parser.parse_args(cli_args)

# Check that a few deprecated args are not in the namespace:
for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'):
assert depr_name not in args

trainer = Trainer.from_argparse_args(args=args)
assert isinstance(trainer, Trainer)


def test_get_init_arguments_and_types():
"""Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod."""
args = Trainer.get_init_arguments_and_types()
parameters = inspect.signature(Trainer).parameters
assert len(parameters) == len(args)
for arg in args:
assert parameters[arg[0]].default == arg[2]

kwargs = {arg[0]: arg[2] for arg in args}
trainer = Trainer(**kwargs)
assert isinstance(trainer, Trainer)


@pytest.mark.parametrize('cli_args', [
['--callbacks=1', '--logger'],
['--foo', '--bar=1']
])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
"""Asserts thar an error raised in case of passing not default cli arguments."""

class _UnkArgError(Exception):
pass

def _raise():
raise _UnkArgError

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)

monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True)

with pytest.raises(_UnkArgError):
parser.parse_args(cli_args)
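The bool branch in `add_argparse_args` is what lets `--print_nan_grads=1` above parse to a real `True`: a naive `type=bool` would make `bool('0')` truthy, so the PR routes bool arguments through `distutils.util.strtobool`. A stdlib-only sketch of that coercion (note that `distutils.util.strtobool` was standard at the time of this PR but was removed in Python 3.12):

from argparse import ArgumentParser
from distutils.util import strtobool


def str_to_bool(value: str) -> bool:
    # strtobool maps 'y/yes/t/true/on/1' to 1 and 'n/no/f/false/off/0' to 0
    # (case-insensitive) and raises ValueError for anything else.
    return bool(strtobool(value))


parser = ArgumentParser()
parser.add_argument('--print_nan_grads', type=str_to_bool, default=False)

print(parser.parse_args(['--print_nan_grads=1']).print_nan_grads)      # True
print(parser.parse_args(['--print_nan_grads=false']).print_nan_grads)  # False
print(parser.parse_args([]).print_nan_grads)                           # False (default)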