Default argparser for Trainer #952

Closed · wants to merge 22 commits
Changes from 15 commits
31 changes: 30 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -2,6 +2,7 @@
import sys
import warnings
import logging as log
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable

import torch
@@ -116,6 +117,7 @@ def __init__(
profiler: Optional[BaseProfiler] = None,
benchmark: bool = False,
reload_dataloaders_every_epoch: bool = False,
**kwargs
):
r"""

@@ -626,6 +628,7 @@ def on_train_end(self):

# Transfer params
# Backward compatibility
self.num_nodes = num_nodes
if nb_gpu_nodes is not None:
warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
@@ -746,10 +749,12 @@ def on_train_end(self):
self.weights_save_path = weights_save_path

# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

# allow int, string and gpu list
self.data_parallel_device_ids = parse_gpu_ids(gpus)
self.gpus = gpus
self.data_parallel_device_ids = parse_gpu_ids(self.gpus)
self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids)

# tpu state flags
@@ -796,6 +801,7 @@ def on_train_end(self):
self.row_log_interval = row_log_interval

# how much of the data to use
self.overfit_pct = overfit_pct
self.determine_data_use_amount(train_percent_check, val_percent_check,
test_percent_check, overfit_pct)

@@ -818,6 +824,29 @@ def slurm_job_id(self) -> int:
job_id = None
return job_id

    @property
    @classmethod
    def default_attributes(cls):
        return vars(cls())

    @classmethod
    def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:

        parser = ArgumentParser(parents=[parent_parser])

Contributor:
These should be auto-generated given a Trainer class; we don't want to have to keep this list in parity by hand.

@Borda
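
As an illustration of that suggestion, a minimal sketch of what signature-driven generation could look like (not code from this PR; the helper name `add_args_from_signature` is made up, and bool/Union defaults would still need special handling):

```python
import inspect
from argparse import ArgumentParser


def add_args_from_signature(cls, parent_parser: ArgumentParser) -> ArgumentParser:
    """Add one --flag per keyword argument of cls.__init__, using its default value."""
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == 'self' or param.default is inspect.Parameter.empty:
            continue  # skip `self`, *args/**kwargs and required positionals
        # infer the CLI type from the default; None defaults fall back to str, and bools
        # would need a custom parser (argparse's type=bool is truthy for any non-empty string)
        arg_type = str if param.default is None else type(param.default)
        parser.add_argument('--{}'.format(name), dest=name,
                            default=param.default, type=arg_type)
    return parser
```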

Contributor:
Also, could we add it to another mixin? This file is already very long.

Member:
We can regenerate it from the class attributes, and we may also get the types (not sure how the typing would cooperate), but I'm not sure how to pass help strings and single-letter shortcuts for an argument.

@PyTorchLightning/core-contributors any idea about passing the help strings? Otherwise I would go for the automatic generation; good idea.
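
On the typing question, one way it could cooperate is to read the annotations off `Trainer.__init__` directly (an exploratory sketch only; the Union/Optional hints are exactly the awkward cases, and this still gives no help strings):

```python
import typing

from pytorch_lightning import Trainer

# Collect the annotations declared on Trainer.__init__, e.g. {'max_epochs': int, ...}.
hints = typing.get_type_hints(Trainer.__init__)
for name, hint in hints.items():
    # Plain builtins (int, float, str) could be handed to add_argument(type=...) directly;
    # Union/Optional hints such as the one on `gpus` would still need a small custom
    # parsing function, and there is still no obvious place to pull a help string from.
    print(name, hint)
```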

Contributor (@XDynames, Mar 2, 2020):
I've drafted a solution to this that automatically parses `__init__`'s docstring for the help fields and the function's signature to populate the argument parser.
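
Such docstring parsing could look roughly like this (a simplified illustration, not the drafted solution itself; it assumes a Google-style `Args:` section with `name: description` lines):

```python
import inspect
import re


def parse_arg_help(cls):
    """Map argument name -> first line of its description from cls.__init__'s docstring."""
    doc = inspect.getdoc(cls.__init__) or ''
    help_strings = {}
    for line in doc.splitlines():
        # a real implementation would restrict this to the "Args:" section only
        match = re.match(r'\s*(\w+):\s+(.*)', line)
        if match:
            help_strings[match.group(1)] = match.group(2)
    return help_strings
```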


        trainer_default_params = Trainer.default_attributes

        for arg in trainer_default_params:
            parser.add_argument('--{0}'.format(arg), default=trainer_default_params[arg], dest=arg)

        return parser

    @classmethod
    def from_argparse_args(cls, args) -> Trainer:

Contributor (author):
Throwing a `NameError: name 'Trainer' is not defined` while running the tests.
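
For context, the error comes from the `-> Trainer` return annotation being evaluated while the class body is still executing, i.e. before the name `Trainer` is bound. A toy stand-in (not the real class) showing the usual workaround of quoting the annotation as a forward reference:

```python
class Trainer:  # toy stand-in, not pytorch_lightning.Trainer
    def __init__(self, **kwargs):
        self.hparams = kwargs

    @classmethod
    def from_argparse_args(cls, args) -> 'Trainer':  # quoted forward reference, no NameError
        return cls(**vars(args))

    # An unquoted `-> Trainer` is evaluated at `def` time, while the class body is still
    # executing and the name `Trainer` is not yet bound, hence the NameError in the tests.
```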


        params = vars(args)
        return cls(**params)

def __parse_gpu_ids(self, gpus):
"""Parse GPUs id.

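
For reference, this is how the two new classmethods are meant to compose in a user's training script (a sketch based on the test added below; `--experiment_name` is a made-up user flag and the model line is commented out):

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser(add_help=False)
parser.add_argument('--experiment_name', type=str, default='demo')  # user-defined flags first
parser = Trainer.add_argparse_args(parser)  # then every Trainer flag, with its default
args = parser.parse_args()  # e.g. `python train.py --max_epochs 5`

trainer = Trainer.from_argparse_args(args)
# trainer.fit(model)  # `model` would be your LightningModule
```
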
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
@@ -30,8 +30,6 @@ def print_nan_gradients(self):
log.info(param, param.grad)

def configure_accumulated_gradients(self, accumulate_grad_batches):
self.accumulate_grad_batches = None

if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
23 changes: 22 additions & 1 deletion tests/test_trainer.py
@@ -1,10 +1,11 @@
import math
import os

import pytest
import torch
import argparse

import tests.models.utils as tutils
from unittest import mock
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
EarlyStopping,
@@ -856,6 +857,26 @@ def test_end(self, outputs):
Trainer().test(model)


@mock.patch('argparse.ArgumentParser.parse_args',
            return_value=argparse.Namespace(**(Trainer.get_default_args())))
def test_default_args(tmpdir):
    """Tests default argument parser for Trainer"""
    tutils.reset_seed()

    # logger file to get meta
    logger = tutils.get_test_tube_logger(tmpdir, False)

    parser = argparse.ArgumentParser(add_help=False)
    args = parser.parse_args()
    args.logger = logger

    args.max_epochs = 5
    trainer = Trainer.from_argparse_args(args)

    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5


def test_trainer_callback_system(tmpdir):
"""Test the callback system."""
