Skepticleo trainer argparser #1023

Merged
32 commits merged on Mar 3, 2020
Changes from all commits (32)
ec07d11
Added default parser for trainer and class method to construct traine…
Feb 25, 2020
78b3552
Removed print statement
Feb 25, 2020
77a1fc3
Added test for constructing Trainer from command line args
Feb 26, 2020
a68baeb
Removed extra line
Feb 26, 2020
541e2ec
Merge branch 'master' into trainerArgparser
mtnwni Feb 26, 2020
679c9a7
Removed redundant imports, removed whitespace from empty lines
Feb 26, 2020
a017312
Fixed typo
Feb 26, 2020
1c7fccc
Updated default parser creation to get class attributes automatically
Mar 2, 2020
c3942cd
Updated default parser creation to get class attributes automatically
Mar 2, 2020
b31295f
Merge branch 'trainerArgparser' of https://github.com/skepticleo/pyto…
Mar 2, 2020
fc66c0c
Added method to get default args for trainer
Mar 2, 2020
7eb4149
Trimmed trainer get default args method
Mar 2, 2020
7cf6a7c
Updated from argparse method to not return trainer with static arguments
Mar 3, 2020
0ee7128
Update trainer get default args to classmethod
Mar 3, 2020
b417be7
adjustment
Borda Mar 3, 2020
e456cbc
fix
Borda Mar 3, 2020
df0a487
Fixed variable name
Mar 3, 2020
9a8d05a
Update trainer.py
williamFalcon Mar 3, 2020
baf5219
Update test_trainer.py
williamFalcon Mar 3, 2020
b7e4e88
added checkpoint defaults
williamFalcon Mar 3, 2020
2325f7a
added checkpoint defaults
Mar 3, 2020
aa74241
Merge branch 'skepticleo-trainerArgparser' of https://github.com/PyTo…
williamFalcon Mar 3, 2020
c361ca8
Update trainer.py
williamFalcon Mar 3, 2020
daec5e7
Update tests/trainer/test_trainer.py
williamFalcon Mar 3, 2020
726a4af
Update trainer.py
williamFalcon Mar 3, 2020
2482159
Update test_trainer.py
williamFalcon Mar 3, 2020
104aeaf
Update trainer.py
williamFalcon Mar 3, 2020
ba294c3
Update test_trainer.py
williamFalcon Mar 3, 2020
16b4b97
Update tests/trainer/test_trainer.py
williamFalcon Mar 3, 2020
d74b5b0
Update pytorch_lightning/trainer/trainer.py
williamFalcon Mar 3, 2020
e12e3c6
Update trainer.py
williamFalcon Mar 3, 2020
627cf41
Update test_trainer.py
williamFalcon Mar 3, 2020
30 changes: 29 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -3,6 +3,7 @@
 import warnings
 import logging as log
 from typing import Union, Optional, List, Dict, Tuple, Iterable
+from argparse import ArgumentParser
 
 import torch
 import torch.distributed as dist
@@ -116,6 +117,7 @@ def __init__(
             profiler: Optional[BaseProfiler] = None,
             benchmark: bool = False,
             reload_dataloaders_every_epoch: bool = False,
+            **kwargs
     ):
         r"""
@@ -627,6 +629,7 @@ def on_train_end(self):
 
         # Transfer params
         # Backward compatibility
+        self.num_nodes = num_nodes
         if nb_gpu_nodes is not None:
             warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
                           " and this method will be removed in v0.8.0", DeprecationWarning)
@@ -747,10 +750,12 @@ def on_train_end(self):
         self.weights_save_path = weights_save_path
 
         # accumulated grads
+        self.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
         # allow int, string and gpu list
-        self.data_parallel_device_ids = parse_gpu_ids(gpus)
+        self.gpus = gpus
+        self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
         self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)
 
         # tpu state flags
@@ -797,6 +802,7 @@ def on_train_end(self):
         self.row_log_interval = row_log_interval
 
         # how much of the data to use
+        self.overfit_pct = overfit_pct
         self.determine_data_use_amount(train_percent_check, val_percent_check,
                                        test_percent_check, overfit_pct)
 
@@ -822,6 +828,28 @@ def slurm_job_id(self) -> int:
             job_id = None
         return job_id
 
+    @classmethod
+    def default_attributes(cls):
+        return vars(cls())
+
+    @classmethod
+    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
+        """Extend existing argparse by default `Trainer` attributes."""
+        parser = ArgumentParser(parents=[parent_parser])
+
+        trainer_default_params = Trainer.default_attributes()
+
+        for arg in trainer_default_params:
+            parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)
+
+        return parser
+
+    @classmethod
+    def from_argparse_args(cls, args):
+
+        params = vars(args)
+        return cls(**params)
+
     def __parse_gpu_ids(self, gpus):
         """Parse GPUs id.

Review thread on the blank line inside `from_argparse_args`:

Contributor: shouldn't there be documentation where the blank line is? did it get lost?

Member: yeah it should...
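Taken together, the two new classmethods let a training script expose every `Trainer` constructor default as a command-line flag and build a `Trainer` straight from the parsed namespace. A minimal sketch of the intended workflow, assuming a hypothetical script `train.py` with one script-level flag of its own:

# train.py: sketch of the argparse workflow this PR enables (illustrative only)
from argparse import ArgumentParser

from pytorch_lightning import Trainer

# add_help=False avoids a duplicate -h flag when add_argparse_args wraps
# this parser via ArgumentParser(parents=[...])
parser = ArgumentParser(add_help=False)
parser.add_argument('--data_dir', default='./data')  # hypothetical script flag

# registers one '--<attribute>' flag per entry in vars(Trainer())
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()

# equivalent to Trainer(**vars(args)); the new **kwargs on __init__ is what
# lets script-level extras such as data_dir pass through without a TypeError
trainer = Trainer.from_argparse_args(args)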
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
@@ -52,8 +52,6 @@ def print_nan_gradients(self):
                 log.info(param, param.grad)
 
     def configure_accumulated_gradients(self, accumulate_grad_batches):
-        self.accumulate_grad_batches = None
-
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
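The deletion pairs with the `trainer.py` change above: `__init__` now stores the raw `accumulate_grad_batches` value before calling `configure_accumulated_gradients`, so resetting it to `None` here would clobber the attribute that `default_attributes()` reads back via `vars(cls())`. A one-line sketch of the invariant this preserves, assuming the stock default of `accumulate_grad_batches=1`:

from pytorch_lightning import Trainer

# after this PR the raw constructor value survives on the instance
assert Trainer.default_attributes()['accumulate_grad_batches'] == 1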
22 changes: 21 additions & 1 deletion tests/trainer/test_trainer.py
@@ -1,10 +1,11 @@
 import math
 import os
 
 import pytest
 import torch
+import argparse
 
 import tests.models.utils as tutils
+from unittest import mock
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import (
     EarlyStopping,
@@ -600,3 +601,22 @@ def test_end(self, outputs):
 
     model = LightningTestModel(hparams)
     Trainer().test(model)

+@mock.patch('argparse.ArgumentParser.parse_args',
+            return_value=argparse.Namespace(**Trainer.default_attributes()))
+def test_default_args(tmpdir):
+    """Tests default argument parser for Trainer"""
+    tutils.reset_seed()
+
+    # logger file to get meta
+    logger = tutils.get_test_tube_logger(tmpdir, False)
+
+    parser = argparse.ArgumentParser(add_help=False)
+    args = parser.parse_args()
+    args.logger = logger
+
+    args.max_epochs = 5
+    trainer = Trainer.from_argparse_args(args)
+
+    assert isinstance(trainer, Trainer)
+    assert trainer.max_epochs == 5
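Worth noting why the test patches `parse_args` instead of parsing a real argv: `add_argparse_args` registers each flag without a `type=`, so values typed on an actual command line arrive as strings, whereas the mocked namespace (plus the manual `args.max_epochs = 5`) keeps typed values. An illustrative sketch of that distinction, not part of the test suite:

import argparse

from pytorch_lightning import Trainer

parser = Trainer.add_argparse_args(argparse.ArgumentParser(add_help=False))

# flags supplied on a real command line come back as strings...
args = parser.parse_args(['--max_epochs', '5'])
assert args.max_epochs == '5'

# ...while untouched flags keep the typed defaults taken from vars(Trainer())
args = parser.parse_args([])
assert isinstance(args.max_epochs, int)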