Skip to content

Commit

Permalink
Fixed configure optimizer from dict without "scheduler" key (#1443)
Browse files Browse the repository at this point in the history
* `configure_optimizer` from dict with only "optimizer" key. bug fixed

* autopep8

* pep8speaks suggested fixes

* CHANGELOG.md upd
  • Loading branch information
alexeykarnachev authored Apr 10, 2020
1 parent 7857a73 commit 4c34d16
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
- Fixed default `DistributedSampler` for DDP training ([#1425](https://github.com/PyTorchLightning/pytorch-lightning/pull/1425))
- Fixed workers warning not on windows ([#1430](https://github.com/PyTorchLightning/pytorch-lightning/pull/1430))
- Fixed returning tuple from `run_training_batch` ([#1431](https://github.com/PyTorchLightning/pytorch-lightning/pull/1431))
Expand Down
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def init_optimizers(
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
else:
lr_schedulers = []
return [optimizer], lr_schedulers, []

# multiple dictionaries
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/supporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class TensorRunningAccum(object):
>>> accum.last(), accum.mean(), accum.min(), accum.max()
(tensor(12.), tensor(10.), tensor(8.), tensor(12.))
"""

def __init__(self, window_length: int):
self.window_length = window_length
self.memory = torch.Tensor(self.window_length)
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,8 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
if at[0] not in depr_arg_names):
for allowed_type in (at for at in allowed_types if at in arg_types):
if isinstance(allowed_type, bool):
allowed_type = lambda x: bool(distutils.util.strtobool(x))
def allowed_type(x):
return bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,
Expand Down
23 changes: 23 additions & 0 deletions tests/trainer/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,26 @@ class CurrentTestModel(

# verify training completed
assert result == 1


def test_configure_optimizer_from_dict(tmpdir):
"""Tests if `configure_optimizer` method could return a dictionary with
`optimizer` field only.
"""

class CurrentTestModel(LightTrainDataloader, TestModelBase):
def configure_optimizers(self):
config = {
'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)
}
return config

hparams = tutils.get_default_hparams()
model = CurrentTestModel(hparams)

trainer_options = dict(default_save_path=tmpdir, max_epochs=1)

# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
assert result == 1

0 comments on commit 4c34d16

Please sign in to comment.