Default argparser for Trainer #952
Changes from 16 commits: ec07d11, 78b3552, 77a1fc3, a68baeb, 541e2ec, 679c9a7, a017312, 1c7fccc, c3942cd, b31295f, fc66c0c, 7eb4149, 7cf6a7c, 0ee7128, b417be7, e456cbc, df0a487, 9a8d05a, baf5219, 1a7c208, 60d4ca5, cd25d79
@@ -2,6 +2,7 @@
import sys
import warnings
import logging as log
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable

import torch

@@ -116,6 +117,7 @@ def __init__(
    profiler: Optional[BaseProfiler] = None,
    benchmark: bool = False,
    reload_dataloaders_every_epoch: bool = False,
    **kwargs
):
    r"""

@@ -626,6 +628,7 @@ def on_train_end(self):

# Transfer params
# Backward compatibility
self.num_nodes = num_nodes
if nb_gpu_nodes is not None:
    warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                  " and this method will be removed in v0.8.0", DeprecationWarning)

@@ -746,10 +749,12 @@ def on_train_end(self):
self.weights_save_path = weights_save_path

# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

# allow int, string and gpu list
self.data_parallel_device_ids = parse_gpu_ids(gpus)
self.gpus = gpus
self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

# tpu state flags

@@ -796,6 +801,7 @@ def on_train_end(self):
self.row_log_interval = row_log_interval

# how much of the data to use
self.overfit_pct = overfit_pct
self.determine_data_use_amount(train_percent_check, val_percent_check,
                               test_percent_check, overfit_pct)

@@ -818,6 +824,29 @@ def slurm_job_id(self) -> int:
        job_id = None
    return job_id

@classmethod
def default_attributes(cls):
    return vars(cls())

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:

    parser = ArgumentParser(parents=[parent_parser])

    trainer_default_params = Trainer.default_attributes()

    for arg in trainer_default_params:
        parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)

    return parser

@classmethod
def from_argparse_args(cls, args) -> 'Trainer':
    params = vars(args)
    return cls(**params)

def __parse_gpu_ids(self, gpus):
    """Parse GPUs id.

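Taken together, these hooks let a training script expose every Trainer default on the command line. A minimal usage sketch of how they are meant to be combined; the script layout, the `main()` wrapper, and the example flag are illustrative assumptions, not part of this PR:

# hypothetical train.py using the new argparse hooks
from argparse import ArgumentParser

from pytorch_lightning import Trainer


def main():
    # the parent parser must disable its own -h so the child parser built
    # inside add_argparse_args can register help without a conflict
    parser = ArgumentParser(add_help=False)
    # registers one --<attribute> flag per Trainer default, e.g. --max_epochs, --gpus
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()  # e.g. python train.py --max_epochs 5

    # equivalent to Trainer(**vars(args))
    trainer = Trainer.from_argparse_args(args)
    # trainer.fit(model) would follow, with `model` a LightningModule


if __name__ == '__main__':
    main()

Because `add_argparse_args` registers each flag without an explicit `type`, values coming from a real command line arrive as strings, which is part of what the typing discussion further down is about. The hunks that follow update the tests to cover the new classmethods.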
@@ -1,10 +1,11 @@
import math
import os

import pytest
import torch
import argparse

import tests.models.utils as tutils
from unittest import mock
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    EarlyStopping,

@@ -856,6 +857,26 @@ def test_end(self, outputs):
    Trainer().test(model)

@mock.patch('argparse.ArgumentParser.parse_args',
            return_value=argparse.Namespace(**Trainer.default_attributes()))
def test_default_args(mock_argparse, tmpdir):
    """Tests default argument parser for Trainer"""
    tutils.reset_seed()

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    parser = argparse.ArgumentParser(add_help=False)
    args = parser.parse_args()
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5


def test_trainer_callback_system(tmpdir):
    """Test the callback system."""

these should be auto-generated given a trainer class. we don't want to start making sure this list maintains parity. @Borda

Also, could we add it to another mixin? This file is already very long.

we can regenerate it from class attributes, we may also get types (not sure how the typing would cooperate), but I'm not sure how to pass help strings and single-letter shortcuts for an argument... @PyTorchLightning/core-contributors any idea about passing help strings? otherwise I would go for the automatically generated version, good idea

I've drafted a solution to this that automatically parses __init__'s docstring for the help fields and the function's signature to populate the argument parser.
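For reference, a rough sketch of the auto-generation idea discussed in this thread, built on `inspect.signature` over `Trainer.__init__`; the helper name, the docstring regex, and the type handling are assumptions for illustration, not necessarily what the drafted solution looks like:

import inspect
import re
from argparse import ArgumentParser


def add_args_from_init(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    """Illustrative helper: derive --flags from cls.__init__ instead of a hand-written list."""
    parser = ArgumentParser(parents=[parent_parser])
    signature = inspect.signature(cls.__init__)
    doc = inspect.getdoc(cls.__init__) or ''

    for name, param in signature.parameters.items():
        if name == 'self' or param.kind == inspect.Parameter.VAR_KEYWORD:
            continue
        if param.default is inspect.Parameter.empty:
            continue
        # naive help lookup: take the text after "name:" in the docstring, if present
        match = re.search(r'%s:\s*(.+)' % name, doc)
        help_text = match.group(1).strip() if match else None
        # only plain builtin annotations translate directly to an argparse type;
        # bool and typing generics would need extra handling
        arg_type = param.annotation if param.annotation in (int, float, str) else None
        parser.add_argument('--{0}'.format(name),
                            default=param.default,
                            type=arg_type,
                            help=help_text,
                            dest=name)
    return parser

`Trainer.add_argparse_args` could then delegate to something like this instead of looping over `vars(cls())`, which is also one way to get the help strings the thread asks about.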