Fix lr finder for optimizers with states (#3897)
* fix lr finder

* changelog

* add test
SkafteNicki authored Oct 6, 2020
1 parent 04303b3 commit 3ab43dd
Showing 4 changed files with 50 additions and 26 deletions.
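In short: `lr_find` previously built the optimizer itself via `trainer.init_optimizers(model)` and handed it to `_get_new_optimizer`. The fix instead wraps the user's `configure_optimizers` with `_exchange_scheduler`, so the originally configured optimizer is created through the normal path and only the learning-rate scheduler is swapped in for the search ramp. The sketch below is plain PyTorch, not Lightning code; it only illustrates the per-parameter state that optimizers such as Adagrad carry, which is the "states" the commit title refers to.

```python
import torch
from torch.optim import Adagrad

# A tiny model and one training step, enough for Adagrad to populate its
# per-parameter state (the running sum of squared gradients).
model = torch.nn.Linear(4, 1)
optimizer = Adagrad(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Each parameter has an associated state entry kept inside the optimizer.
for param, state in optimizer.state.items():
    print(param.shape, state["sum"].shape)
```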
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -105,6 +105,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))

+- Fixed learning rate finder for optimizers with internal state ([#3897](https://github.com/PyTorchLightning/pytorch-lightning/pull/3897))
+
## [0.9.0] - YYYY-MM-DD

### Added
65 changes: 40 additions & 25 deletions pytorch_lightning/tuner/lr_finder.py
@@ -13,10 +13,12 @@
# limitations under the License.
import importlib
import os
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Union, Callable
+from functools import wraps

import numpy as np
import torch
+from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

@@ -165,13 +167,7 @@ def lr_find(
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
-    optimizers, _, _ = trainer.init_optimizers(model)
-
-    if len(optimizers) != 1:
-        raise MisconfigurationException(
-            f'`model.configure_optimizers()` returned {len(optimizers)}, but'
-            ' learning rate finder only works with single optimizer')
-    model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
+    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.fit(model,
@@ -261,28 +257,47 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
        self.results = {}
        self._total_batch_idx = 0  # for debug purpose

-    def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
-        """ Construct a new `configure_optimizers()` method, that has a optimizer
-        with initial lr set to lr_min and a scheduler that will either
-        linearly or exponentially increase the lr to lr_max in num_training steps.
-
-        Args:
-            optimizer: instance of `torch.optim.Optimizer`
-
-        """
-        new_lrs = [self.lr_min] * len(optimizer.param_groups)
-        for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
-            param_group["lr"] = new_lr
-            param_group["initial_lr"] = new_lr
-
-        args = (optimizer, self.lr_max, self.num_training)
-        scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
-
-        def configure_optimizers():
-            return [optimizer], [{'scheduler': scheduler,
-                                  'interval': 'step'}]
-
-        return configure_optimizers
+    def _exchange_scheduler(self, configure_optimizers: Callable):
+        """ Decorate configure_optimizers methods such that it returns the users
+        originally specified optimizer together with a new scheduler that
+        that takes care of the learning rate search.
+        """
+        @wraps(configure_optimizers)
+        def func():
+            # Decide the structure of the output from configure_optimizers
+            # Same logic as method `init_optimizers` in trainer/optimizers.py
+            optim_conf = configure_optimizers()
+            if isinstance(optim_conf, Optimizer):
+                optimizers = [optim_conf]
+            elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
+                    and isinstance(optim_conf[0], list):
+                optimizers, _ = optim_conf
+            elif isinstance(optim_conf, dict):
+                optimizers = [optim_conf["optimizer"]]
+            elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
+                optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
+            elif isinstance(optim_conf, (list, tuple)):
+                optimizers = [optim_conf]
+
+            if len(optimizers) != 1:
+                raise MisconfigurationException(
+                    f'`model.configure_optimizers()` returned {len(optimizers)}, but'
+                    ' learning rate finder only works with single optimizer')
+
+            optimizer = optimizers[0]
+
+            new_lrs = [self.lr_min] * len(optimizer.param_groups)
+            for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
+                param_group["lr"] = new_lr
+                param_group["initial_lr"] = new_lr
+
+            args = (optimizer, self.lr_max, self.num_training)
+            scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
+
+            return [optimizer], [{'scheduler': scheduler,
+                                  'interval': 'step'}]
+
+        return func

    def plot(self, suggest: bool = False, show: bool = False):
        """ Plot results from lr_find run
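The exchanged scheduler (`_LinearLR` or `_ExponentialLR`, both defined elsewhere in this module) sweeps the learning rate from `lr_min` to `lr_max` over `num_training` steps, which is why the decorated `configure_optimizers` returns it with `'interval': 'step'`. A rough stand-in built on `torch.optim.lr_scheduler.LambdaLR`, for intuition only; the helper name `exponential_ramp` is made up and is not the class the commit uses.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def exponential_ramp(optimizer, lr_min, lr_max, num_training):
    # The optimizer's lr is assumed to start at lr_min (as _exchange_scheduler
    # arranges); multiplying by (lr_max / lr_min) ** (step / num_training)
    # sweeps it up to lr_max over num_training scheduler steps.
    ratio = lr_max / lr_min
    return LambdaLR(optimizer, lambda step: ratio ** (step / num_training))

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-8)
scheduler = exponential_ramp(optimizer, lr_min=1e-8, lr_max=1.0, num_training=100)
# scheduler.step() is then called once per batch, matching 'interval': 'step'.
```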
4 changes: 4 additions & 0 deletions tests/base/model_optimizers.py
@@ -24,6 +24,10 @@ def configure_optimizers__lbfgs(self):
        optimizer = optim.LBFGS(self.parameters(), lr=self.learning_rate)
        return optimizer

+    def configure_optimizers__adagrad(self):
+        optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate)
+        return optimizer
+
    def configure_optimizers__multiple_optimizers(self):
        """
        return whatever optimizers we want here.
5 changes: 4 additions & 1 deletion tests/trainer/test_lr_finder.py
@@ -131,11 +131,14 @@ def test_trainer_arg_str(tmpdir, use_hparams):
        'Learning rate was not altered after running learning rate finder'


-def test_call_to_trainer_method(tmpdir):
+@pytest.mark.parametrize('optimizer', ['Adam', 'Adagrad'])
+def test_call_to_trainer_method(tmpdir, optimizer):
    """ Test that directly calling the trainer method works """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)
+    if optimizer == 'adagrad':
+        model.configure_optimizers = model.configure_optimizers__adagrad

    before_lr = hparams.get('learning_rate')
    # logger file to get meta
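A sketch of how this parametrized test exercises the finder with a stateful optimizer. The entry point has moved between releases (`trainer.lr_find` in this era, `trainer.tuner.lr_find` later), so the call below is indicative rather than exact.

```python
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.configure_optimizers = model.configure_optimizers__adagrad  # stateful optimizer

trainer = Trainer(default_root_dir='/tmp/lr_find_example', max_epochs=1)
lr_finder = trainer.lr_find(model)            # entry point assumed for this release
model.learning_rate = lr_finder.suggestion()  # lr at the steepest loss descent
```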
