diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0adee16eff0bb..521f99810ddc8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
+- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
 - Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
 - Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 395f2f531ec8c..e9d6ab864e5c4 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -758,6 +758,15 @@ def configure_optimizers(self):
                 discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
                 return [generator_opt, discriminator_opt], [discriminator_sched]

+            # example with step-based learning rate schedulers
+            def configure_optimizers(self):
+                gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
+                dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
+                gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
+                             'interval': 'step'}  # called after each training step
+                dis_sched = CosineAnnealing(dis_opt, T_max=10)  # called after each epoch
+                return [gen_opt, dis_opt], [gen_sched, dis_sched]
+
         .. note:: Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.

         .. note:: If you use 16-bit precision (use_amp=True), Lightning will automatically
@@ -773,6 +782,8 @@ def configure_optimizers(self):
         .. note:: If you need to control how often those optimizers step or override the
            default .step() schedule, override the `optimizer_step` hook.
+        .. note:: If you only want to call a learning rate scheduler every `x` steps or epochs,
+           you can pass this as the 'frequency' key: dict(scheduler=lr_scheduler, interval='step' or 'epoch', frequency=x)

         """
         return Adam(self.parameters(), lr=1e-3)
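To make the note above concrete, here is a minimal sketch of a `configure_optimizers` using the new 'frequency' key; `Adam` and `ExponentialLR` come from `torch.optim` as in the docstring examples, while `self.model_gen` is a hypothetical attribute:

```python
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    # 'interval' selects when the scheduler is considered ('step' or 'epoch');
    # 'frequency' thins that out: here scheduler.step() fires every 10 training steps
    gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
                 'interval': 'step',
                 'frequency': 10}
    return [gen_opt], [gen_sched]
```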
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0558819f4936e..dae535306344b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -6,6 +6,7 @@
 from argparse import ArgumentParser

 import torch
+from torch import optim
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
@@ -743,8 +744,6 @@ def on_train_end(self):
         # creates a default one if none passed in
         self.configure_early_stopping(early_stop_callback)

-        self.reduce_lr_on_plateau_scheduler = None
-
         # configure checkpoint callback
         self.checkpoint_callback = checkpoint_callback
         self.weights_save_path = weights_save_path
@@ -1079,26 +1078,56 @@ def init_optimizers(
             optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
     ) -> Tuple[List, List]:

-        # single optimizer
+        # single output, single optimizer
         if isinstance(optimizers, Optimizer):
             return [optimizers], []

-        # two lists
-        if len(optimizers) == 2 and isinstance(optimizers[0], list):
+        # two lists, optimizers + lr schedulers
+        elif len(optimizers) == 2 and isinstance(optimizers[0], list):
             optimizers, lr_schedulers = optimizers
-            lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
+            lr_schedulers = self.configure_schedulers(lr_schedulers)
             return optimizers, lr_schedulers

-        # single list or tuple
-        if isinstance(optimizers, (list, tuple)):
+        # single list or tuple, multiple optimizers
+        elif isinstance(optimizers, (list, tuple)):
             return optimizers, []

+        # unknown configuration
+        else:
+            raise ValueError('Unknown configuration for model optimizers. Output '
+                             'from model.configure_optimizers() should either be:\n'
+                             '* single output, single torch.optim.Optimizer\n'
+                             '* single output, list of torch.optim.Optimizer\n'
+                             '* two outputs, first being a list of torch.optim.Optimizer, '
+                             'second being a list of torch.optim.lr_scheduler')
+
     def configure_schedulers(self, schedulers: list):
-        for i, scheduler in enumerate(schedulers):
-            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
-                reduce_lr_on_plateau_scheduler = schedulers.pop(i)
-                return schedulers, reduce_lr_on_plateau_scheduler
-        return schedulers, None
+        # Convert each scheduler into a dict structure with the relevant information
+        lr_schedulers = []
+        default_config = {'interval': 'epoch',  # default every epoch
+                          'frequency': 1,  # default every epoch/batch
+                          'reduce_on_plateau': False,  # most often not a ReduceLROnPlateau scheduler
+                          'monitor': 'val_loss'}  # default value to monitor for ReduceLROnPlateau
+        for scheduler in schedulers:
+            if isinstance(scheduler, dict):
+                if 'scheduler' not in scheduler:
+                    raise ValueError('Lr scheduler should have key `scheduler` '
+                                     'with item being an lr scheduler')
+                scheduler['reduce_on_plateau'] = \
+                    isinstance(scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)
+
+                lr_schedulers.append({**default_config, **scheduler})
+
+            elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                lr_schedulers.append({**default_config, 'scheduler': scheduler,
+                                      'reduce_on_plateau': True})
+
+            elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
+                lr_schedulers.append({**default_config, 'scheduler': scheduler})
+            else:
+                raise ValueError(f'Input {scheduler} to lr schedulers '
+                                 'is an invalid input.')
+        return lr_schedulers

     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
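For reference, `configure_schedulers` normalizes every accepted input into the same dict shape. A minimal sketch (plain `torch`, toy model) of what passing a bare scheduler is equivalent to, using the defaults defined above:

```python
import torch
from torch import optim

net = torch.nn.Linear(2, 2)  # toy model for illustration
opt = optim.Adam(net.parameters(), lr=0.01)
sched = optim.lr_scheduler.StepLR(opt, 10)

# passing `sched` bare is equivalent to passing this fully specified dict:
normalized = {'scheduler': sched,
              'interval': 'epoch',         # stepped at epoch end
              'frequency': 1,              # on every epoch
              'reduce_on_plateau': False,  # not a ReduceLROnPlateau instance
              'monitor': 'val_loss'}       # only consulted for ReduceLROnPlateau
```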
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 3f0cc4ed92307..c5ade16a84a07 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -1,3 +1,94 @@
+"""
+Lightning can automate saving and loading checkpoints
+=====================================================
+
+Checkpointing is enabled by default to the current working directory.
+To change the checkpoint path pass in::
+
+    Trainer(default_save_path='/your/path/to/save/checkpoints')
+
+
+To modify the behavior of checkpointing pass in your own callback.
+
+.. code-block:: python
+
+    from pytorch_lightning.callbacks import ModelCheckpoint
+
+    # DEFAULTS used by the Trainer
+    checkpoint_callback = ModelCheckpoint(
+        filepath=os.getcwd(),
+        save_best_only=True,
+        verbose=True,
+        monitor='val_loss',
+        mode='min',
+        prefix=''
+    )
+
+    trainer = Trainer(checkpoint_callback=checkpoint_callback)
+
+
+Restoring training session
+--------------------------
+
+You might want to not only load a model but also continue training it. Use this method to
+restore the trainer state as well. This will continue from the epoch and global step you last
+left off. However, the dataloaders will start from the first batch again (if you shuffled,
+it shouldn't matter).
+
+Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
+
+.. code-block:: python
+
+    from pytorch_lightning import Trainer
+    from pytorch_lightning.loggers import TestTubeLogger
+
+    logger = TestTubeLogger(
+        save_dir='./savepath',
+        version=1  # An existing version with a saved checkpoint
+    )
+    trainer = Trainer(
+        logger=logger,
+        default_save_path='./savepath'
+    )
+
+    # this fit call loads model weights and trainer state
+    # the trainer continues seamlessly from where you left off
+    # without having to do anything else.
+    trainer.fit(model)
+
+
+The trainer restores:
+
+- global_step
+- current_epoch
+- All optimizers
+- All lr_schedulers
+- Model weights
+
+You can even change the logic of your model as long as the weights and "architecture" of
+the system isn't different. If you add a layer, for instance, it might not work.
+
+At a rough level, here's what happens inside Trainer :py:mod:`pytorch_lightning.base_module.model_saving.py`:
+
+.. code-block:: python
+
+    self.global_step = checkpoint['global_step']
+    self.current_epoch = checkpoint['epoch']
+
+    # restore the optimizers
+    optimizer_states = checkpoint['optimizer_states']
+    for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+        optimizer.load_state_dict(opt_state)
+
+    # restore the lr schedulers
+    lr_schedulers = checkpoint['lr_schedulers']
+    for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
+        scheduler['scheduler'].load_state_dict(lrs_state)
+
+    # uses the model you passed into trainer
+    model.load_state_dict(checkpoint['state_dict'])
+
+"""
+
 import logging as log
 import os
 import re
@@ -228,8 +319,8 @@ def dump_checkpoint(self):

         # save lr schedulers
         lr_schedulers = []
-        for i, scheduler in enumerate(self.lr_schedulers):
-            lr_schedulers.append(scheduler.state_dict())
+        for scheduler in self.lr_schedulers:
+            lr_schedulers.append(scheduler['scheduler'].state_dict())

         checkpoint['lr_schedulers'] = lr_schedulers

@@ -320,7 +411,7 @@ def restore_training_state(self, checkpoint):
         # restore the lr schedulers
         lr_schedulers = checkpoint['lr_schedulers']
         for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
-            scheduler.load_state_dict(lrs_state)
+            scheduler['scheduler'].load_state_dict(lrs_state)

     # ----------------------------------
     # PRIVATE OPS
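As a quick aside, the checkpoint keys shown above can be inspected directly with plain `torch`; a minimal sketch (the checkpoint path is hypothetical), assuming a checkpoint written by the code above:

```python
import torch

# inspect what Lightning stored for the schedulers (path is hypothetical)
checkpoint = torch.load('/your/path/to/save/checkpoints/some.ckpt', map_location='cpu')
print(checkpoint['epoch'], checkpoint['global_step'])
for lrs_state in checkpoint['lr_schedulers']:
    print(lrs_state)  # each entry is one scheduler.state_dict()
```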
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 82b4ae14cae31..fb76f3ac9b4c6 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -361,17 +361,7 @@ def train(self):
                 self.run_training_epoch()

                 # update LR schedulers
-                if self.lr_schedulers is not None:
-                    for lr_scheduler in self.lr_schedulers:
-                        lr_scheduler.step()
-                if self.reduce_lr_on_plateau_scheduler is not None:
-                    val_loss = self.callback_metrics.get('val_loss')
-                    if val_loss is None:
-                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
-                        m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
-                            f'which is not available. Available metrics are: {avail_metrics}'
-                        raise MisconfigurationException(m)
-                    self.reduce_lr_on_plateau_scheduler.step(val_loss)
+                self.update_learning_rates(interval='epoch')

                 if self.max_steps and self.max_steps == self.global_step:
                     self.run_training_teardown()
@@ -444,6 +434,9 @@ def run_training_epoch(self):
             # when returning -1 from train_step, we end epoch early
             early_stop_epoch = batch_result == -1

+            # update lr
+            self.update_learning_rates(interval='step')
+
             # ---------------
             # RUN VAL STEP
             # ---------------
@@ -716,6 +709,34 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):

         return output

+    def update_learning_rates(self, interval):
+        ''' Update learning rates
+        Args:
+            interval (str): either 'epoch' or 'step'.
+        '''
+        if not self.lr_schedulers:
+            return
+
+        for lr_scheduler in self.lr_schedulers:
+            current_idx = self.batch_idx if interval == 'step' else self.current_epoch
+            current_idx += 1  # account for both batch_idx and current_epoch starting from 0
+            # Take a step if this call to update_learning_rates matches the scheduler's
+            # interval key and the current index modulo the scheduler's frequency is zero
+            if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
+                # If instance of ReduceLROnPlateau, we need to pass the monitored metric
+                if lr_scheduler['reduce_on_plateau']:
+                    monitor_key = lr_scheduler['monitor']
+                    monitor_val = self.callback_metrics.get(monitor_key)
+                    if monitor_val is None:
+                        avail_metrics = ','.join(list(self.callback_metrics.keys()))
+                        m = f'ReduceLROnPlateau conditioned on metric {monitor_key} ' \
+                            f'which is not available. Available metrics are: {avail_metrics}. ' \
+                            'Condition can be set using the `monitor` key in the lr scheduler dict'
+                        raise MisconfigurationException(m)
+                    lr_scheduler['scheduler'].step(monitor_val)
+                else:
+                    lr_scheduler['scheduler'].step()
+
     def call_checkpoint_callback(self):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
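To make the interval/frequency rule above concrete, here is a minimal standalone sketch (plain `torch`, toy model; the loop stands in for training batches) of how a 'step'-interval scheduler with frequency 3 fires:

```python
import torch
from torch import optim

net = torch.nn.Linear(2, 2)  # toy model
opt = optim.SGD(net.parameters(), lr=0.1)
sched = {'scheduler': optim.lr_scheduler.StepLR(opt, 1, gamma=0.5),
         'interval': 'step', 'frequency': 3}

for batch_idx in range(9):  # stand-in for 9 training batches
    # ... forward/backward and opt.step() would happen here ...
    current_idx = batch_idx + 1  # batch_idx starts at 0
    if sched['interval'] == 'step' and current_idx % sched['frequency'] == 0:
        sched['scheduler'].step()  # fires on batches 3, 6 and 9

print(opt.param_groups[0]['lr'])  # 0.1 * 0.5 ** 3 = 0.0125
```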
diff --git a/tests/models/__init__.py b/tests/models/__init__.py
index a10d01d1cebf3..f48b6ab4db351 100644
--- a/tests/models/__init__.py
+++ b/tests/models/__init__.py
@@ -19,6 +19,9 @@
     LightValStepFitMultipleDataloadersMixin,
     LightTrainDataloader,
     LightTestDataloader,
+    LightTestOptimizerWithSchedulingMixin,
+    LightTestMultipleOptimizersWithSchedulingMixin,
+    LightTestOptimizersWithMixedSchedulingMixin
 )

diff --git a/tests/models/base.py b/tests/models/base.py
index 29fc2177b11ba..8f7f54927a3a1 100644
--- a/tests/models/base.py
+++ b/tests/models/base.py
@@ -130,7 +130,7 @@ def loss(self, labels, logits):
         nll = F.nll_loss(logits, labels)
         return nll

-    def training_step(self, batch, batch_idx):
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
         """
         Lightning calls this inside the training loop
         :param batch:
diff --git a/tests/models/mixins.py b/tests/models/mixins.py
index 69c0d235761d6..6c3d8f908bfa5 100644
--- a/tests/models/mixins.py
+++ b/tests/models/mixins.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict

 import torch
-
+from torch import optim
 from pytorch_lightning.core.decorators import data_loader

@@ -598,6 +598,45 @@ def test_end(self, outputs):
         return result


+class LightTestOptimizerWithSchedulingMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
+        return [optimizer], [lr_scheduler]
+
+
+class LightTestMultipleOptimizersWithSchedulingMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+            optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+            optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
+        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
+
+        return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]
+
+
+class LightTestOptimizersWithMixedSchedulingMixin:
+    def configure_optimizers(self):
+        if self.hparams.optimizer_name == 'lbfgs':
+            optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+            optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
+        else:
+            optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+            optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+        lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
+        lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
+
+        return [optimizer1, optimizer2], \
+            [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]
+
+
 def _get_output_metric(output, name):
     if isinstance(output, dict):
         val = output[name]
diff --git a/tests/models/utils.py b/tests/models/utils.py
index 8d17984b94bde..2f971162fcf5b 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -82,7 +82,7 @@ def run_model_test(trainer_options, model, on_gpu=True):
     if trainer.use_ddp or trainer.use_ddp2:
        # on hpc this would work fine...
        # but need to hack it for the purpose of the test
        trainer.model = pretrained_model
-        trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
+        trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())

     # test HPC loading / saving
     trainer.hpc_save(save_dir, logger)
diff --git a/tests/test_gpu_models.py b/tests/test_gpu_models.py
index 1cdc91fff7cd7..47cd69b521cc2 100644
--- a/tests/test_gpu_models.py
+++ b/tests/test_gpu_models.py
@@ -116,10 +116,14 @@ def test_optimizer_return_options():
     assert len(lr_sched) == 0

     # opt tuple of lists
-    opts = ([opt_a], ['lr_scheduler'])
+    scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
+    opts = ([opt_a], [scheduler])
     optim, lr_sched = trainer.init_optimizers(opts)
     assert len(optim) == 1 and len(lr_sched) == 1
-    assert optim[0] == opts[0][0] and lr_sched[0] == 'lr_scheduler'
+    assert optim[0] == opts[0][0] and \
+        lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
+                            frequency=1, reduce_on_plateau=False,
+                            monitor='val_loss')


 def test_cpu_slurm_save_load(tmpdir):
diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py
new file mode 100644
index 0000000000000..bc5dde5f75013
--- /dev/null
+++ b/tests/trainer/test_optimizers.py
@@ -0,0 +1,146 @@
+import math
+import os
+
+import pytest
+import torch
+
+import tests.models.utils as tutils
+from pytorch_lightning import Trainer
+
+from tests.models import (
+    TestModelBase,
+    LightTrainDataloader,
+    LightTestOptimizerWithSchedulingMixin,
+    LightTestMultipleOptimizersWithSchedulingMixin,
+    LightTestOptimizersWithMixedSchedulingMixin
+)
+
+
+def test_optimizer_with_scheduling(tmpdir):
+    """ Verify that learning rate scheduling is working """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+            LightTestOptimizerWithSchedulingMixin,
+            LightTrainDataloader,
+            TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    init_lr = hparams.learning_rate
+    adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
+
+    assert len(trainer.lr_schedulers) == 1, \
+        'lr scheduler not initialized properly, it has %i elements instead of 1' % len(trainer.lr_schedulers)
+
+    assert all(a == adjusted_lr[0] for a in adjusted_lr), \
+        'Lr not equally adjusted for all param groups'
+    adjusted_lr = adjusted_lr[0]
+
+    assert init_lr * 0.1 == adjusted_lr, \
+        'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr)
+
+
+def test_multi_optimizer_with_scheduling(tmpdir):
+    """ Verify that learning rate scheduling is working """
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+            LightTestMultipleOptimizersWithSchedulingMixin,
+            LightTrainDataloader,
+            TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    init_lr = hparams.learning_rate
+    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
+    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]
+
+    assert len(trainer.lr_schedulers) == 2, \
+        'all lr schedulers not initialized properly, it has %i elements instead of 2' % len(trainer.lr_schedulers)
+
+    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
+        'Lr not equally adjusted for all param groups for optimizer 1'
+    adjusted_lr1 = adjusted_lr1[0]
+
+    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
+        'Lr not equally adjusted for all param groups for optimizer 2'
+    adjusted_lr2 = adjusted_lr2[0]
+
+    assert init_lr * 0.1 == adjusted_lr1 and init_lr * 0.1 == adjusted_lr2, \
+        'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr1)
+
+
+def test_multi_optimizer_with_scheduling_stepping(tmpdir):
+    tutils.reset_seed()
+
+    class CurrentTestModel(
+            LightTestOptimizersWithMixedSchedulingMixin,
+            LightTrainDataloader,
+            TestModelBase):
+        pass
+
+    hparams = tutils.get_hparams()
+    model = CurrentTestModel(hparams)
+
+    # logger file to get meta
+    trainer_options = dict(
+        default_save_path=tmpdir,
+        max_epochs=1,
+        val_percent_check=0.1,
+        train_percent_check=0.2
+    )
+
+    # fit model
+    trainer = Trainer(**trainer_options)
+    results = trainer.fit(model)
+
+    init_lr = hparams.learning_rate
+    adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups]
+    adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups]
+
+    assert len(trainer.lr_schedulers) == 2, \
+        'all lr schedulers not initialized properly'
+
+    assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \
+        'lr not equally adjusted for all param groups for optimizer 1'
+    adjusted_lr1 = adjusted_lr1[0]
+
+    assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \
+        'lr not equally adjusted for all param groups for optimizer 2'
+    adjusted_lr2 = adjusted_lr2[0]
+
+    # scheduler 1 has interval 'step': called after every training step, so StepLR(4) decays three times
+    assert init_lr * (0.1) ** 3 == adjusted_lr1, \
+        'lr for optimizer 1 not adjusted correctly'
+    # scheduler 2 keeps the default 'epoch' interval: called once after the single epoch
+    assert init_lr * 0.1 == adjusted_lr2, \
+        'lr for optimizer 2 not adjusted correctly'
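To double-check the `0.1 ** 3` expectation in the last test, here is a minimal standalone sketch of the StepLR arithmetic (plain `torch`; the 12 scheduler steps over the epoch are an assumption for illustration):

```python
import torch
from torch import optim

# StepLR(step_size=4, gamma=0.1): lr is multiplied by 0.1 after every 4 calls to step()
lin = torch.nn.Linear(2, 2)  # toy parameters
opt = optim.SGD(lin.parameters(), lr=1.0)
sched = optim.lr_scheduler.StepLR(opt, 4, gamma=0.1)

for _ in range(12):  # decays fire at steps 4, 8 and 12
    sched.step()

assert abs(opt.param_groups[0]['lr'] - 1.0 * 0.1 ** 3) < 1e-12  # three decays total
```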