From 82c19e144484c3d9b0c0f00169e9c3c5431ba9e3 Mon Sep 17 00:00:00 2001 From: Hemil Desai Date: Tue, 4 May 2021 15:07:40 +0530 Subject: [PATCH] =?UTF-8?q?Update=20LR=20schedulers=20only=20when=20their?= =?UTF-8?q?=20corresponding=20Optimizer=20is=20being=E2=80=A6=20(#4868)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update LR schedulers only when their corresponding Optimizer is being used. In the case when optimizer frequencies are specified, the LR scheduler corresponding to a particular optimizer is updated only when that optimizer is being used in the training loop or epoch. * pep8speak fixes * Fix failing tests * Add docs * PR Feedback * Apply suggestions from code review Co-authored-by: Rohit Gupta * formatting fix * PR Feedback - part 2 * More PR feedback * Apply suggestions from code review Co-authored-by: Rohit Gupta * Add typing imports * Stronger tests and fixes related to that * Add more tests plus PR feedback * Make optimizer_freq_cumsum a cached property @cached_property is only available after Python 3.8 so had to do it manually. * Fix tests * Apply suggestions from code review Co-authored-by: Carlos MocholĂ­ * Avoid mutable defaults * Parametrize lr scheduling tests * PR feedback * Apply suggestions from code review * spell * Apply suggestions from code review * flake8 Co-authored-by: Rohit Gupta Co-authored-by: chaton Co-authored-by: Carlos MocholĂ­ Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 + pytorch_lightning/core/lightning.py | 29 +- .../trainer/connectors/optimizer_connector.py | 12 +- pytorch_lightning/trainer/optimizers.py | 20 +- pytorch_lightning/trainer/trainer.py | 10 +- pytorch_lightning/trainer/training_loop.py | 25 +- tests/base/model_optimizers.py | 11 - tests/trainer/optimization/test_optimizers.py | 329 +++++++++++------- 8 files changed, 288 insertions(+), 151 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 48211a6363598..0250507d61f4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -364,6 +364,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed bug where `BaseFinetuning.flatten_modules()` was duplicating leaf node parameters ([#6879](https://github.com/PyTorchLightning/pytorch-lightning/pull/6879)) +- Fixed bug where the learning rate schedulers did not follow the optimizer frequencies ([#4868](https://github.com/PyTorchLightning/pytorch-lightning/pull/4868)) + + - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c6151d96b52dd..ecab75ddb7e42 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1109,8 +1109,33 @@ def configure_optimizers(self): - **None** - Fit will run without any optimizer. Note: - The lr_dict is a dictionary which contains the scheduler and its associated configuration. The default - configuration is shown below. + The ``frequency`` value specified in a dict along with the ``optimizer`` key is an int corresponding + to the number of sequential batches optimized with the specific optimizer. + It should be given to none or to all of the optimizers. 
+ There is a difference between passing multiple optimizers in a list, + and passing multiple optimizers in dictionaries with a frequency of 1: + In the former case, all optimizers will operate on the given batch in each optimization step. + In the latter, only one optimizer will operate on the given batch at every step. + This is different from the ``frequency`` value specified in the lr_dict mentioned below. + + .. code-block:: python + + def configure_optimizers(self): + optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01) + optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01) + return [ + {'optimizer': optimizer_one, 'frequency': 5}, + {'optimizer': optimizer_two, 'frequency': 10}, + ] + + In this example, the first optimizer will be used for the first 5 steps, + the second optimizer for the next 10 steps and that cycle will continue. + If an LR scheduler is specified for an optimizer using the ``lr_scheduler`` key in the above dict, + the scheduler will only be updated when its optimizer is being used. + + Note: + The lr_dict is a dictionary which contains the scheduler and its associated configuration. + The default configuration is shown below. .. code-block:: python diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 62803757aa2fe..d45cbad927936 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional + from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -25,7 +27,9 @@ def on_trainer_init(self): self.trainer.optimizers = [] self.trainer.optimizer_frequencies = [] - def update_learning_rates(self, interval: str, monitor_metrics=None): + def update_learning_rates( + self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None + ): """Update learning rates. 
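The ``configure_optimizers`` documentation extended above notes that an LR scheduler given via the ``lr_scheduler`` key is only updated while its optimizer is in use. A minimal sketch combining ``frequency`` with per-optimizer schedulers, in the same style as the docstring example (the ``self.model`` attribute, learning rates and ``StepLR`` settings are illustrative assumptions, not part of this patch):

.. code-block:: python

    def configure_optimizers(self):
        optimizer_one = torch.optim.SGD(self.model.parameters(), lr=0.01)
        optimizer_two = torch.optim.SGD(self.model.parameters(), lr=0.01)
        return [
            {
                'optimizer': optimizer_one,
                'frequency': 5,
                # stepped only on the batches handled by optimizer_one
                'lr_scheduler': torch.optim.lr_scheduler.StepLR(optimizer_one, step_size=1),
            },
            {
                'optimizer': optimizer_two,
                'frequency': 10,
                # stepped only on the batches handled by optimizer_two
                'lr_scheduler': torch.optim.lr_scheduler.StepLR(optimizer_two, step_size=1),
            },
        ]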
Args: @@ -35,7 +39,13 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization: return + if opt_indices is None: + opt_indices = [] + for scheduler_idx, lr_scheduler in enumerate(self.trainer.lr_schedulers): + if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices: + continue + current_idx = self.trainer.batch_idx if interval == 'step' else self.trainer.current_epoch current_idx += 1 # account for both batch and epoch starts from 0 # Take step if call to update_learning_rates matches the interval key and diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 5a7873232b394..b5afe7bf75168 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -46,7 +46,10 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: if isinstance(optim_conf, Optimizer): optimizers = [optim_conf] # two lists, optimizer + lr schedulers - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): + elif ( + isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list) + and all(isinstance(opt, Optimizer) for opt in optim_conf[0]) + ): opt, sch = optim_conf optimizers = opt lr_schedulers = sch if isinstance(sch, list) else [sch] @@ -58,7 +61,17 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: # multiple dictionaries elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict] + scheduler_dict = ( + lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx) if isinstance(scheduler, dict) else { + 'scheduler': scheduler, + 'opt_idx': opt_idx + } + ) + + lr_schedulers = [ + scheduler_dict(opt_dict["lr_scheduler"], opt_idx) for opt_idx, opt_dict in enumerate(optim_conf) + if "lr_scheduler" in opt_dict + ] optimizer_frequencies = [ opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None ] @@ -66,7 +79,7 @@ def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): raise ValueError("A frequency must be given to each optimizer.") # single list or tuple, multiple optimizer - elif isinstance(optim_conf, (list, tuple)): + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf): optimizers = list(optim_conf) # unknown configuration else: @@ -207,4 +220,5 @@ def _get_default_scheduler_config() -> Dict[str, Any]: 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler 'monitor': None, # value to monitor for ReduceLROnPlateau 'strict': True, # enforce that the monitor exists for ReduceLROnPlateau + 'opt_idx': None, # necessary to store opt_idx when optimizer frequencies are specified } diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 620d0ccfab043..14264804b34f6 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -985,7 +985,15 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT: # update epoch-level lr_schedulers if on_epoch: - 
self.optimizer_connector.update_learning_rates(interval='epoch') + self.optimizer_connector.update_learning_rates( + interval='epoch', + opt_indices=[ + opt_idx + for opt_idx, _ in self.train_loop.get_optimizers_iterable(batch_idx=( + self.total_batch_idx - 1 + )) # Select the optimizers which were used in the last batch of the epoch + ], + ) # hook self.evaluation_loop.on_evaluation_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f96c17a0686ce..4059235dd84d5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -47,6 +47,7 @@ def __init__(self, trainer, multiple_trainloader_mode: str): self._multiple_trainloader_mode = multiple_trainloader_mode self._skip_backward = False self.trainer._multiple_trainloader_mode = multiple_trainloader_mode + self._optimizer_freq_cumsum = None def on_trainer_init( self, @@ -83,6 +84,12 @@ def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable()) return num_optimizers + @property + def optimizer_freq_cumsum(self): + if self._optimizer_freq_cumsum is None: + self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) + return self._optimizer_freq_cumsum + def should_skip_training(self): should_by_max_steps = self.trainer.max_steps is not None and self.trainer.global_step >= self.trainer.max_steps should_by_epoch = self.trainer.max_epochs is not None and self.trainer.current_epoch >= self.trainer.max_epochs @@ -211,7 +218,7 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): epoch_output[opt_idx].append(opt_outputs) - def get_optimizers_iterable(self): + def get_optimizers_iterable(self, batch_idx=None): """ Generates an iterable with (idx, optimizer) for each optimizer. 
""" @@ -219,12 +226,14 @@ def get_optimizers_iterable(self): # call training_step once per optimizer return list(enumerate(self.trainer.optimizers)) - optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies) - optimizers_loop_length = optimizer_freq_cumsum[-1] - current_place_in_loop = self.trainer.total_batch_idx % optimizers_loop_length + if batch_idx is None: + batch_idx = self.trainer.total_batch_idx + + optimizers_loop_length = self.optimizer_freq_cumsum[-1] + current_place_in_loop = batch_idx % optimizers_loop_length # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop) + opt_idx = np.argmax(self.optimizer_freq_cumsum > current_place_in_loop) return [[opt_idx, self.trainer.optimizers[opt_idx]]] def on_after_backward(self, training_step_output, batch_idx, untouched_loss): @@ -801,7 +810,11 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None): if num_accumulated_batches_reached or num_training_batches_reached: # update lr - self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics) + self.trainer.optimizer_connector.update_learning_rates( + interval="step", + monitor_metrics=monitor_metrics, + opt_indices=[opt_idx for opt_idx, _ in self.get_optimizers_iterable()], + ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 39e67748f0a90..5d18382edb0a9 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -26,9 +26,6 @@ def configure_optimizers(self): optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer - def configure_optimizers__empty(self): - return None - def configure_optimizers__lbfgs(self): """ return whatever optimizers we want here. @@ -41,14 +38,6 @@ def configure_optimizers__adagrad(self): optimizer = optim.Adagrad(self.parameters(), lr=self.learning_rate) return optimizer - def configure_optimizers__multiple_optimizers_frequency(self): - optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate) - optimizer2 = optim.Adam(self.parameters(), lr=self.learning_rate) - return [ - dict(optimizer=optimizer1, frequency=1), - dict(optimizer=optimizer2, frequency=5), - ] - def configure_optimizers__single_scheduler(self): optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 71ef6e49385aa..f078f744d9ab3 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -13,6 +13,7 @@ # limitations under the License. 
import pytest import torch +from torch import optim from pytorch_lightning import Callback, Trainer from pytorch_lightning.trainer.states import TrainerState @@ -25,11 +26,7 @@ def test_optimizer_with_scheduling(tmpdir): """ Verify that learning rate scheduling is working """ - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - model.configure_optimizers = model.configure_optimizers__single_scheduler - - # fit model + model = BoringModel() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -40,27 +37,32 @@ def test_optimizer_with_scheduling(tmpdir): trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - init_lr = hparams.get('learning_rate') + init_lr = 0.1 adjusted_lr = [pg['lr'] for pg in trainer.optimizers[0].param_groups] - assert len(trainer.lr_schedulers) == 1, \ - 'lr scheduler not initialized properly, it has %i elements instread of 1' % len(trainer.lr_schedulers) + assert len(trainer.lr_schedulers) == 1 + assert all(a == adjusted_lr[0] for a in adjusted_lr) + assert init_lr * 0.1 == adjusted_lr[0] - assert all(a == adjusted_lr[0] for a in adjusted_lr), \ - 'Lr not equally adjusted for all param groups' - adjusted_lr = adjusted_lr[0] - assert init_lr * 0.1 == adjusted_lr, \ - 'Lr not adjusted correctly, expected %f but got %f' % (init_lr * 0.1, adjusted_lr) +def test_multi_optimizer_with_scheduling(tmpdir): + """ Verify that learning rate scheduling is working """ + class TestModel(BoringModel): + init_lr = 5e-4 -def test_multi_optimizer_with_scheduling_stepping(tmpdir): + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - model.configure_optimizers = model.configure_optimizers__multiple_schedulers + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=self.init_lr) + optimizer2 = optim.Adam(self.parameters(), lr=self.init_lr) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, step_size=1) + return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] - # fit model + model = TestModel() + model.training_epoch_end = None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -70,24 +72,14 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir): trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" - init_lr = hparams.get('learning_rate') adjusted_lr1 = [pg['lr'] for pg in trainer.optimizers[0].param_groups] adjusted_lr2 = [pg['lr'] for pg in trainer.optimizers[1].param_groups] - assert len(trainer.lr_schedulers) == 2, 'all lr scheduler not initialized properly' - - assert all(a == adjusted_lr1[0] for a in adjusted_lr1), \ - 'lr not equally adjusted for all param groups for optimizer 1' - adjusted_lr1 = adjusted_lr1[0] - - assert all(a == adjusted_lr2[0] for a in adjusted_lr2), \ - 'lr not equally adjusted for all param groups for optimizer 2' - adjusted_lr2 = adjusted_lr2[0] - - # Called ones after end of epoch - assert init_lr * 0.1 == adjusted_lr1, 'lr for optimizer 1 not adjusted correctly' - # Called every 3 steps, meaning for 1 epoch of 11 batches, it is called 3 times - assert init_lr * 0.1 == adjusted_lr2, 'lr for optimizer 2 not adjusted correctly' + assert len(trainer.lr_schedulers) == 2 + assert all(a == adjusted_lr1[0] for a in adjusted_lr1) + assert all(a == 
adjusted_lr2[0] for a in adjusted_lr2) + assert model.init_lr * 0.1 == adjusted_lr1[0] + assert model.init_lr * 0.1 == adjusted_lr2[0] def test_reducelronplateau_with_no_monitor_raises(tmpdir): @@ -95,8 +87,8 @@ def test_reducelronplateau_with_no_monitor_raises(tmpdir): Test exception when a ReduceLROnPlateau is used with no monitor """ model = EvalModelTemplate() - optimizer = torch.optim.Adam(model.parameters()) - model.configure_optimizers = lambda: ([optimizer], [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) + optimizer = optim.Adam(model.parameters()) + model.configure_optimizers = lambda: ([optimizer], [optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.raises( MisconfigurationException, match='`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`' @@ -109,11 +101,11 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor """ model = EvalModelTemplate() - optimizer = torch.optim.Adam(model.parameters()) + optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { 'optimizer': optimizer, 'lr_scheduler': { - 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), }, } trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -122,53 +114,63 @@ def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): def test_reducelronplateau_scheduling(tmpdir): - model = EvalModelTemplate() - optimizer = torch.optim.Adam(model.parameters()) - model.configure_optimizers = lambda: { - 'optimizer': optimizer, - 'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), - 'monitor': 'val_acc', - } + + class TestModel(BoringModel): + + def training_step(self, batch, batch_idx): + self.log("foo", batch_idx) + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = optim.Adam(self.parameters()) + return { + 'optimizer': optimizer, + 'lr_scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': 'foo', + } + + model = TestModel() trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + lr_scheduler = trainer.lr_schedulers[0] assert lr_scheduler == dict( scheduler=lr_scheduler['scheduler'], - monitor='val_acc', + monitor='foo', interval='epoch', frequency=1, reduce_on_plateau=True, strict=True, + opt_idx=None, name=None, - ), 'lr scheduler was not correctly converted to dict' + ) -def test_optimizer_return_options(): - trainer = Trainer() - model = EvalModelTemplate() +def test_optimizer_return_options(tmpdir): + trainer = Trainer(default_root_dir=tmpdir) + model = BoringModel() # single optimizer - opt_a = torch.optim.Adam(model.parameters(), lr=0.002) - opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) - scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) + opt_a = optim.Adam(model.parameters(), lr=0.002) + opt_b = optim.SGD(model.parameters(), lr=0.002) + scheduler_a = optim.lr_scheduler.StepLR(opt_a, 10) + scheduler_b = optim.lr_scheduler.StepLR(opt_b, 10) # single optimizer model.configure_optimizers = lambda: opt_a - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == len(freq) == 
0 + opt, lr_sched, freq = trainer.init_optimizers(model) + assert len(opt) == 1 and len(lr_sched) == len(freq) == 0 # opt tuple model.configure_optimizers = lambda: (opt_a, opt_b) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert optim == [opt_a, opt_b] + opt, lr_sched, freq = trainer.init_optimizers(model) + assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 # opt list model.configure_optimizers = lambda: [opt_a, opt_b] - optim, lr_sched, freq = trainer.init_optimizers(model) - assert optim == [opt_a, opt_b] + opt, lr_sched, freq = trainer.init_optimizers(model) + assert opt == [opt_a, opt_b] assert len(lr_sched) == len(freq) == 0 ref_lr_sched = dict( @@ -179,30 +181,31 @@ def test_optimizer_return_options(): monitor=None, strict=True, name=None, + opt_idx=None, ) # opt tuple of 2 lists model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == len(lr_sched) == 1 + opt, lr_sched, freq = trainer.init_optimizers(model) + assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 - assert optim[0] == opt_a + assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched # opt tuple of 1 list model.configure_optimizers = lambda: ([opt_a], scheduler_a) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == len(lr_sched) == 1 + opt, lr_sched, freq = trainer.init_optimizers(model) + assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 - assert optim[0] == opt_a + assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched # opt single dictionary model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == len(lr_sched) == 1 + opt, lr_sched, freq = trainer.init_optimizers(model) + assert len(opt) == len(lr_sched) == 1 assert len(freq) == 0 - assert optim[0] == opt_a + assert opt[0] == opt_a assert lr_sched[0] == ref_lr_sched # opt multiple dictionaries with frequencies @@ -218,74 +221,130 @@ def test_optimizer_return_options(): "frequency": 5 }, ) - optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == len(lr_sched) == len(freq) == 2 - assert optim[0] == opt_a + opt, lr_sched, freq = trainer.init_optimizers(model) + assert len(opt) == len(lr_sched) == len(freq) == 2 + assert opt[0] == opt_a + ref_lr_sched["opt_idx"] = 0 assert lr_sched[0] == ref_lr_sched + ref_lr_sched["scheduler"] = scheduler_b + ref_lr_sched["opt_idx"] = 1 + assert lr_sched[1] == ref_lr_sched assert freq == [1, 5] -def test_none_optimizer_warning(): - - trainer = Trainer() - - model = EvalModelTemplate() - model.configure_optimizers = model.configure_optimizers__empty - - with pytest.warns(UserWarning, match='will run with no optimizer'): - _, __, ___ = trainer.init_optimizers(model) - - def test_none_optimizer(tmpdir): - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - model.configure_optimizers = model.configure_optimizers__empty - - # fit model + model = BoringModel() + model.configure_optimizers = lambda: None trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, ) - trainer.fit(model) - - # verify training completed + with pytest.warns(UserWarning, match='will run with no optimizer'): + trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" def test_configure_optimizer_from_dict(tmpdir): """Tests if 
`configure_optimizer` method could return a dictionary with `optimizer` field only.""" - class CurrentModel(EvalModelTemplate): + class TestModel(BoringModel): def configure_optimizers(self): - config = {'optimizer': torch.optim.SGD(params=self.parameters(), lr=1e-03)} + config = {'optimizer': optim.SGD(params=self.parameters(), lr=1e-03)} return config - hparams = EvalModelTemplate.get_default_hparams() - model = CurrentModel(**hparams) - - # fit model + model = TestModel() trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1, + fast_dev_run=True, ) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" -def test_configure_optimizers_with_frequency(tmpdir): +@pytest.mark.parametrize( + 'schedulers, kwargs, intervals, frequencies, expected_steps, max_epochs', + [ + ( + (optim.lr_scheduler.OneCycleLR, optim.lr_scheduler.OneCycleLR), + (dict(max_lr=0.01, total_steps=3), dict(max_lr=0.01, total_steps=2)), + ('step', 'step'), + (3, 2), + (4, 3), + 1, + ), + ( + (optim.lr_scheduler.OneCycleLR, optim.lr_scheduler.OneCycleLR), + (dict(max_lr=0.01, total_steps=5), dict(max_lr=0.01, total_steps=5)), + ('step', 'step'), + (None, None), + (6, 6), + 1, + ), + ( + (optim.lr_scheduler.StepLR, optim.lr_scheduler.CosineAnnealingLR), + (dict(step_size=5), dict(T_max=2)), + ('epoch', 'epoch'), + (5, 10), + (2, 3), + 3, + ), + ], +) +def test_step_scheduling_for_multiple_optimizers_with_frequency( + tmpdir, schedulers, kwargs, intervals, frequencies, expected_steps, max_epochs +): """ - Test that multiple optimizers work when corresponding frequency is set. + Test that step LR schedulers for multiple optimizers follow + the optimizer frequencies when corresponding frequency is set. """ - model = EvalModelTemplate() - model.configure_optimizers = model.configure_optimizers__multiple_optimizers_frequency - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + class DummyModel(BoringModel): + + def training_step(self, batch, batch_idx, optimizer_idx): + return super().training_step(batch, batch_idx) + + def training_epoch_end(self, outputs) -> None: + pass + + def configure_optimizers(self): + optimizer1 = optim.Adam(self.parameters(), lr=0.01) + optimizer2 = optim.Adam(self.parameters(), lr=0.01) + + lr_dict_1 = { + 'scheduler': schedulers[0](optimizer1, **kwargs[0]), + 'interval': intervals[0], + } + lr_dict_2 = { + 'scheduler': schedulers[1](optimizer2, **kwargs[1]), + 'interval': intervals[1], + } + + return [ + { + 'optimizer': optimizer1, + 'frequency': frequencies[0], + 'lr_scheduler': lr_dict_1 + }, + { + 'optimizer': optimizer2, + 'frequency': frequencies[1], + 'lr_scheduler': lr_dict_2 + }, + ] + + model = DummyModel() + + trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=1, limit_train_batches=5, max_epochs=max_epochs) trainer.fit(model) - assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + assert trainer.lr_schedulers[0]['opt_idx'] == 0 + assert trainer.lr_schedulers[1]['opt_idx'] == 1 + # Step count is 1 greater than the expected value because scheduler.step() is called once during initialization + assert trainer.lr_schedulers[0]['scheduler']._step_count == expected_steps[0] + assert trainer.lr_schedulers[1]['scheduler']._step_count == expected_steps[1] @pytest.mark.parametrize("fn", ("validate", "test")) @@ -297,10 +356,10 @@ def test_init_optimizers_during_evaluation(tmpdir, fn): class TestModel(BoringModel): def configure_optimizers(self): - optimizer1 = 
torch.optim.Adam(self.parameters(), lr=0.1) - optimizer2 = torch.optim.Adam(self.parameters(), lr=0.1) - lr_scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer1, step_size=1) - lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer2, step_size=1) + optimizer1 = optim.Adam(self.parameters(), lr=0.1) + optimizer2 = optim.Adam(self.parameters(), lr=0.1) + lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, step_size=1) + lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, step_size=1) return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2] trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=10, limit_test_batches=10) @@ -344,8 +403,8 @@ def training_step(self, batch, batch_idx, optimizer_idx): return acc def configure_optimizers(self): - a = torch.optim.RMSprop(self.layer_1.parameters(), 1e-2) - b = torch.optim.RMSprop(self.layer_2.parameters(), 1e-2) + a = optim.RMSprop(self.layer_1.parameters(), 1e-2) + b = optim.RMSprop(self.layer_2.parameters(), 1e-2) return a, b model = TestModel() @@ -366,8 +425,8 @@ def test_lr_scheduler_strict(tmpdir): Test "strict" support in lr_scheduler dict """ model = EvalModelTemplate() - optimizer = torch.optim.Adam(model.parameters()) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) + optimizer = optim.Adam(model.parameters()) + scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) model.configure_optimizers = lambda: { @@ -402,7 +461,7 @@ def test_unknown_configure_optimizers_raises(tmpdir): """ Test exception with an unsupported configure_optimizers return """ - model = EvalModelTemplate() + model = BoringModel() model.configure_optimizers = lambda: 1 trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) with pytest.raises(MisconfigurationException, match="Unknown configuration for model optimizers"): @@ -414,11 +473,11 @@ def test_lr_scheduler_with_unknown_interval_raises(tmpdir): Test exception when lr_scheduler dict has unknown interval param value """ model = BoringModel() - optimizer = torch.optim.Adam(model.parameters()) + optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { 'optimizer': optimizer, 'lr_scheduler': { - 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'scheduler': optim.lr_scheduler.StepLR(optimizer, 1), 'interval': "incorrect_unknown_value" }, } @@ -431,12 +490,12 @@ def test_lr_scheduler_with_extra_keys_warns(tmpdir): """ Test warning when lr_scheduler dict has extra keys """ - model = EvalModelTemplate() - optimizer = torch.optim.Adam(model.parameters()) + model = BoringModel() + optimizer = optim.Adam(model.parameters()) model.configure_optimizers = lambda: { 'optimizer': optimizer, 'lr_scheduler': { - 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'scheduler': optim.lr_scheduler.StepLR(optimizer, 1), 'foo': 1, 'bar': 2, }, @@ -450,9 +509,9 @@ def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir): """ Test exception when lr_scheduler dict has no scheduler """ - model = EvalModelTemplate() + model = BoringModel() model.configure_optimizers = lambda: { - 'optimizer': torch.optim.Adam(model.parameters()), + 'optimizer': optim.Adam(model.parameters()), 'lr_scheduler': {}, } trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) @@ -468,9 +527,9 @@ def test_invalid_optimizer_in_scheduler(tmpdir): class InvalidOptimizerModel(BoringModel): def configure_optimizers(self): - opt1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - opt2 = 
torch.optim.SGD(self.layer.parameters(), lr=0.1) - lr_scheduler = torch.optim.lr_scheduler.StepLR(opt2, step_size=1) + opt1 = optim.SGD(self.layer.parameters(), lr=0.1) + opt2 = optim.SGD(self.layer.parameters(), lr=0.1) + lr_scheduler = optim.lr_scheduler.StepLR(opt2, step_size=1) return [opt1], [lr_scheduler] model = InvalidOptimizerModel() @@ -479,6 +538,22 @@ def configure_optimizers(self): trainer.fit(model) +def test_invalid_optimizer_dict_raises(tmpdir): + """ + Test exception when lr_scheduler dict has no scheduler + """ + + class DummyModel(BoringModel): + + def configure_optimizers(self): + return [{'optimizer': optim.Adam(self.parameters())}, optim.Adam(self.parameters())] + + model = DummyModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match='Unknown configuration for model optimizers'): + trainer.fit(model) + + def test_warn_invalid_scheduler_key_in_manual_optimization(tmpdir): """ Test warning when invalid scheduler keys are provided in manual optimization. @@ -491,8 +566,8 @@ def __init__(self): self.automatic_optimization = False def configure_optimizers(self): - opt = torch.optim.SGD(self.layer.parameters(), lr=0.1) - sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1) + opt = optim.SGD(self.layer.parameters(), lr=0.1) + sch = optim.lr_scheduler.StepLR(opt, step_size=1) return [opt], [{"scheduler": sch, "interval": "epoch"}] model = TestModel() @@ -505,7 +580,7 @@ class TestModel(BoringModel): def configure_optimizers(self): # Adagrad creates state tensors immediately, model is not yet on GPU. - return torch.optim.Adagrad(self.parameters()) + return optim.Adagrad(self.parameters()) def on_train_start(self, *args, **kwargs): opt = self.optimizers()
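The ``expected_steps`` values in the parametrized frequency test above can be reasoned about as follows: with ``'step'``-interval schedulers, each scheduler is stepped once per batch handled by its optimizer, plus the single ``step()`` that every torch ``_LRScheduler`` performs on construction (as the test comment notes). A rough illustrative helper for that arithmetic, assuming no gradient accumulation and covering only the ``'step'`` interval case:

.. code-block:: python

    def expected_step_counts(frequencies, num_batches):
        # every torch _LRScheduler calls step() once during __init__
        counts = [1] * len(frequencies)
        cycle = sum(frequencies)
        for batch_idx in range(num_batches):
            place = batch_idx % cycle
            running = 0
            for idx, freq in enumerate(frequencies):
                running += freq
                if place < running:
                    counts[idx] += 1  # this optimizer owns the batch, so its scheduler steps
                    break
        return counts

    # frequencies (3, 2) with 5 training batches -> step counts (4, 3), matching the first parametrization
    assert expected_step_counts((3, 2), num_batches=5) == [4, 3]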