From 2549ca40e65ffb9dddbe5df923e35110503f19d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 21 Oct 2020 21:12:48 +0200 Subject: [PATCH] Clean up optimizer code (#3587) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update optimizer code * Update CHANGELOG * Fix tuple of one list case * Update docs * Fix pep issue * Minor typo [skip-ci] * Use minimal match Co-authored-by: Adrian Wälchli * Apply suggestions from code review Co-authored-by: Rohit Gupta Co-authored-by: Adrian Wälchli Co-authored-by: Rohit Gupta --- CHANGELOG.md | 1 + docs/source/optimizers.rst | 30 ++- .../trainer/connectors/optimizer_connector.py | 54 ++---- pytorch_lightning/trainer/optimizers.py | 133 +++++++------- tests/base/model_optimizers.py | 5 - tests/trainer/test_optimizers.py | 172 ++++++++++++------ 6 files changed, 223 insertions(+), 172 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae339289c3b93..83605440f90b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587)) ### Deprecated diff --git a/docs/source/optimizers.rst b/docs/source/optimizers.rst index ce4e057542c08..2b8025959c9ca 100644 --- a/docs/source/optimizers.rst +++ b/docs/source/optimizers.rst @@ -101,26 +101,46 @@ Every optimizer you use can be paired with any `LearningRateScheduler 1 optimizers from :meth:`pytorch_lightning.c # Two optimizers, one scheduler for adam only def configure_optimizers(self): - return [Adam(...), SGD(...)], [ReduceLROnPlateau()] + return [Adam(...), SGD(...)], {'scheduler': ReduceLROnPlateau(), 'monitor': 'metric_to_track'} Lightning will call each optimizer sequentially: diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 4938e6673b06a..3987b1b64ac25 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -16,7 +16,6 @@ class OptimizerConnector: - def __init__(self, trainer): self.trainer = trainer @@ -41,21 +40,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): # Take step if call to update_learning_rates matches the interval key and # the current step modulo the schedulers frequency is zero if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0: - # If instance of ReduceLROnPlateau, we need to pass validation loss + # If instance of ReduceLROnPlateau, we need a monitor + monitor_key, monitor_val = None, None if lr_scheduler['reduce_on_plateau']: - try: - monitor_key = lr_scheduler['monitor'] - except KeyError as e: - m = "ReduceLROnPlateau requires returning a dict from configure_optimizers with the keyword " \ - "monitor=. 
For example:" \ - "return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'your_loss'}" - raise MisconfigurationException(m) - - if monitor_metrics is not None: - monitor_val = monitor_metrics.get(monitor_key) - else: - monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key) - + monitor_key = lr_scheduler['monitor'] + monitor_val = ( + monitor_metrics.get(monitor_key) + if monitor_metrics is not None + else self.trainer.logger_connector.callback_metrics.get(monitor_key) + ) if monitor_val is None: if lr_scheduler.get('strict', True): avail_metrics = self.trainer.logger_connector.callback_metrics.keys() @@ -71,30 +64,15 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): RuntimeWarning, ) continue - # update LR - old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + # update LR + old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) - new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - - if self.trainer.dev_debugger.enabled: - self.trainer.dev_debugger.track_lr_schedulers_update( - self.trainer.batch_idx, - interval, - scheduler_idx, - old_lr, - new_lr, - monitor_key, - ) else: - # update LR - old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] lr_scheduler['scheduler'].step() - new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] - if self.trainer.dev_debugger.enabled: - self.trainer.dev_debugger.track_lr_schedulers_update( - self.trainer.batch_idx, - interval, - scheduler_idx, - old_lr, new_lr - ) + if self.trainer.dev_debugger.enabled: + self.trainer.dev_debugger.track_lr_schedulers_update( + self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key + ) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 56404ab391a70..a6b63002dcc40 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -21,111 +21,107 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerOptimizersMixin(ABC): - - def init_optimizers( - self, - model: LightningModule - ) -> Tuple[List, List, List]: + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: optim_conf = model.configure_optimizers() - if optim_conf is None: - rank_zero_warn('`LightningModule.configure_optimizers` returned `None`, ' - 'this fit will run with no optimizer', UserWarning) + rank_zero_warn( + '`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer', + UserWarning, + ) optim_conf = _MockOptimizer() + optimizers, lr_schedulers, optimizer_frequencies = [], [], [] + monitor = None + # single output, single optimizer if isinstance(optim_conf, Optimizer): - return [optim_conf], [], [] - + optimizers = [optim_conf] # two lists, optimizer + lr schedulers - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ - and isinstance(optim_conf[0], list): - optimizers, lr_schedulers = optim_conf - lr_schedulers = self.configure_schedulers(lr_schedulers) - return optimizers, lr_schedulers, [] - + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): + opt, sch = optim_conf + 
optimizers = opt + lr_schedulers = sch if isinstance(sch, list) else [sch] # single dictionary elif isinstance(optim_conf, dict): - optimizer = optim_conf["optimizer"] + optimizers = [optim_conf["optimizer"]] monitor = optim_conf.get('monitor', None) - lr_scheduler = optim_conf.get("lr_scheduler", []) - if lr_scheduler: - lr_schedulers = self.configure_schedulers([lr_scheduler], monitor) - else: - lr_schedulers = [] - return [optimizer], lr_schedulers, [] - + lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else [] # multiple dictionaries - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf): optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - # take only lr wif exists and ot they are defined - not None - lr_schedulers = [ - opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler") - ] - # take only freq wif exists and ot they are defined - not None + lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if "lr_scheduler" in opt_dict] optimizer_frequencies = [ - opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") is not None + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None ] - - # clean scheduler list - if lr_schedulers: - lr_schedulers = self.configure_schedulers(lr_schedulers) # assert that if frequencies are present, they are given for all optimizers if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): raise ValueError("A frequency must be given to each optimizer.") - return optimizers, lr_schedulers, optimizer_frequencies - # single list or tuple, multiple optimizer elif isinstance(optim_conf, (list, tuple)): - return list(optim_conf), [], [] - + optimizers = list(optim_conf) # unknown configuration else: - raise ValueError( + raise MisconfigurationException( 'Unknown configuration for model optimizers.' 
- ' Output from `model.configure_optimizers()` should either be:' - ' * single output, single `torch.optim.Optimizer`' - ' * single output, list of `torch.optim.Optimizer`' - ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' - ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' - ' * two outputs, first being a list of `torch.optim.Optimizer` second being' - ' a list of `torch.optim.lr_scheduler`' - ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') + ' Output from `model.configure_optimizers()` should either be:\n' + ' * `torch.optim.Optimizer`\n' + ' * [`torch.optim.Optimizer`]\n' + ' * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n' + ' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n' + ' * A list of the previously described dict format, with an optional "frequency" key (int)' + ) + lr_schedulers = self.configure_schedulers(lr_schedulers, monitor=monitor) + + return optimizers, lr_schedulers, optimizer_frequencies def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None): # Convert each scheduler into dict structure with relevant information lr_schedulers = [] default_config = { - 'interval': 'epoch', # default every epoch - 'frequency': 1, # default every epoch/batch - 'reduce_on_plateau': False - } # most often not ReduceLROnPlateau scheduler - - if monitor is not None: - default_config['monitor'] = monitor - + 'scheduler': None, + 'interval': 'epoch', # after epoch is over + 'frequency': 1, # every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': monitor, # value to monitor for ReduceLROnPlateau + 'strict': True, # enforce that the monitor exists for ReduceLROnPlateau + } for scheduler in schedulers: if isinstance(scheduler, dict): + # check provided keys + extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()] + if extra_keys: + rank_zero_warn(f'Found unsupported keys in the lr scheduler dict: {extra_keys}', RuntimeWarning) if 'scheduler' not in scheduler: - raise ValueError('Lr scheduler should have key `scheduler`', - ' with item being a lr scheduler') + raise MisconfigurationException( + 'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler' + ) scheduler['reduce_on_plateau'] = isinstance( - scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) - + scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau + ) + if scheduler['reduce_on_plateau'] and scheduler.get('monitor', None) is None: + raise MisconfigurationException( + 'The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used.' + ' For example: {"optimizer": optimizer, "lr_scheduler":' + ' {"scheduler": scheduler, "monitor": "your_loss"}}' + ) lr_schedulers.append({**default_config, **scheduler}) - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): - lr_schedulers.append({**default_config, 'scheduler': scheduler, - 'reduce_on_plateau': True}) - + if monitor is None: + raise MisconfigurationException( + '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.' 
+ ' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' + ) + lr_schedulers.append( + {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor} + ) elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): lr_schedulers.append({**default_config, 'scheduler': scheduler}) else: - raise ValueError(f'Input {scheduler} to lr schedulers ' - 'is a invalid input.') + raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid') return lr_schedulers def reinit_scheduler_properties(self, optimizers: list, schedulers: list): @@ -138,10 +134,7 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list): if scheduler.optimizer == optimizer: # Find the mro belonging to the base lr scheduler class for i, mro in enumerate(scheduler.__class__.__mro__): - if ( - mro == optim.lr_scheduler._LRScheduler - or mro == optim.lr_scheduler.ReduceLROnPlateau - ): + if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau): idx = i state = scheduler.state_dict() else: diff --git a/tests/base/model_optimizers.py b/tests/base/model_optimizers.py index 194c6ce6baa70..e4b8d489f872d 100644 --- a/tests/base/model_optimizers.py +++ b/tests/base/model_optimizers.py @@ -81,11 +81,6 @@ def configure_optimizers__mixed_scheduling(self): return [optimizer1, optimizer2], \ [{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2] - def configure_optimizers__reduce_lr_on_plateau(self): - optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) - lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) - return [optimizer], [lr_scheduler] - def configure_optimizers__param_groups(self): param_groups = [ {'params': list(self.parameters())[:2], 'lr': self.learning_rate * 0.1}, diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 78dc3e9122c69..c28e626f2eec0 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -15,8 +15,8 @@ import torch from pytorch_lightning import Trainer, Callback -from tests.base import EvalModelTemplate from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.base import EvalModelTemplate from tests.base.boring_model import BoringModel @@ -126,55 +126,60 @@ def test_multi_optimizer_with_scheduling_stepping(tmpdir): 'lr for optimizer 2 not adjusted correctly' -def test_reduce_lr_on_plateau_scheduling_missing_monitor(tmpdir): - - hparams = EvalModelTemplate.get_default_hparams() - model = EvalModelTemplate(**hparams) - model.configure_optimizers = model.configure_optimizers__reduce_lr_on_plateau - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) - - m = '.*ReduceLROnPlateau requires returning a dict from configure_optimizers.*' - with pytest.raises(MisconfigurationException, match=m): +def test_reducelronplateau_with_no_monitor_raises(tmpdir): + """ + Test exception when a ReduceLROnPlateau is used with no monitor + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: ([optimizer], [torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)]) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises( + MisconfigurationException, match='`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`' + ): trainer.fit(model) -def test_reduce_lr_on_plateau_scheduling(tmpdir): - hparams = 
EvalModelTemplate.get_default_hparams() - - class TestModel(EvalModelTemplate): - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) - lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer) - return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'monitor': 'early_stop_on'} +def test_reducelronplateau_with_no_monitor_in_lr_scheduler_dict_raises(tmpdir): + """ + Test exception when lr_scheduler dict has a ReduceLROnPlateau with no monitor + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match='must include a monitor when a `ReduceLROnPlateau`'): + trainer.fit(model) - model = TestModel(**hparams) - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=0.1, - limit_train_batches=0.2, - ) +def test_reducelronplateau_scheduling(tmpdir): + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': 'early_stop_on', + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) results = trainer.fit(model) assert results == 1 - - assert trainer.lr_schedulers[0] == \ - dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='early_stop_on', - interval='epoch', frequency=1, reduce_on_plateau=True), \ - 'lr schduler was not correctly converted to dict' + lr_scheduler = trainer.lr_schedulers[0] + assert lr_scheduler == dict( + scheduler=lr_scheduler['scheduler'], + monitor='early_stop_on', + interval='epoch', + frequency=1, + reduce_on_plateau=True, + strict=True, + ), 'lr scheduler was not correctly converted to dict' def test_optimizer_return_options(): - trainer = Trainer() model = EvalModelTemplate() @@ -187,35 +192,49 @@ def test_optimizer_return_options(): # single optimizer model.configure_optimizers = lambda: opt_a optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0 + assert len(optim) == 1 and len(lr_sched) == len(freq) == 0 # opt tuple model.configure_optimizers = lambda: (opt_a, opt_b) optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 and len(freq) == 0 + assert optim == [opt_a, opt_b] + assert len(lr_sched) == len(freq) == 0 # opt list model.configure_optimizers = lambda: [opt_a, opt_b] optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 and len(freq) == 0 + assert optim == [opt_a, opt_b] + assert len(lr_sched) == len(freq) == 0 # opt tuple of 2 lists model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == dict( + scheduler=scheduler_a, interval='epoch', frequency=1, reduce_on_plateau=False, monitor=None, strict=True + ) + + # opt tuple of 1 list + model.configure_optimizers 
= lambda: ([opt_a], scheduler_a) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False) + assert lr_sched[0] == dict( + scheduler=scheduler_a, interval='epoch', frequency=1, reduce_on_plateau=False, monitor=None, strict=True + ) # opt single dictionary model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert len(optim) == len(lr_sched) == 1 + assert len(freq) == 0 assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False) + assert lr_sched[0] == dict( + scheduler=scheduler_a, interval='epoch', frequency=1, reduce_on_plateau=False, monitor=None, strict=True + ) # opt multiple dictionaries with frequencies model.configure_optimizers = lambda: ( @@ -223,10 +242,11 @@ def test_optimizer_return_options(): {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, ) optim, lr_sched, freq = trainer.init_optimizers(model) - assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2 + assert len(optim) == len(lr_sched) == len(freq) == 2 assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False) + assert lr_sched[0] == dict( + scheduler=scheduler_a, interval='epoch', frequency=1, reduce_on_plateau=False, monitor=None, strict=True + ) assert freq == [1, 5] @@ -393,3 +413,47 @@ def test_lr_scheduler_strict(tmpdir): RuntimeWarning, match=r'ReduceLROnPlateau conditioned on metric .* which is not available but strict' ): assert trainer.fit(model) + + +def test_unknown_configure_optimizers_raises(tmpdir): + """ + Test exception with an unsupported configure_optimizers return + """ + model = EvalModelTemplate() + model.configure_optimizers = lambda: 1 + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match="Unknown configuration for model optimizers"): + trainer.fit(model) + + +def test_lr_scheduler_with_extra_keys_warns(tmpdir): + """ + Test warning when lr_scheduler dict has extra keys + """ + model = EvalModelTemplate() + optimizer = torch.optim.Adam(model.parameters()) + model.configure_optimizers = lambda: { + 'optimizer': optimizer, + 'lr_scheduler': { + 'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1), + 'foo': 1, + 'bar': 2, + }, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.warns(RuntimeWarning, match=r'Found unsupported keys in the lr scheduler dict: \[.+\]'): + trainer.fit(model) + + +def test_lr_scheduler_with_no_actual_scheduler_raises(tmpdir): + """ + Test exception when lr_scheduler dict has no scheduler + """ + model = EvalModelTemplate() + model.configure_optimizers = lambda: { + 'optimizer': torch.optim.Adam(model.parameters()), + 'lr_scheduler': {}, + } + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + with pytest.raises(MisconfigurationException, match='The lr scheduler dict must have the key "scheduler"'): + trainer.fit(model)
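
For reference, a minimal sketch (not part of the patch itself) of the `configure_optimizers` return shapes that the rewritten `init_optimizers` accepts, mirroring the new `MisconfigurationException` message; the module, optimizer, and scheduler objects below are placeholders:

    from torch import nn, optim

    # Placeholder setup; any nn.Module / optimizer / scheduler works here.
    model = nn.Linear(4, 2)
    opt_a = optim.Adam(model.parameters())
    opt_b = optim.SGD(model.parameters(), lr=0.01)
    scheduler_a = optim.lr_scheduler.StepLR(opt_a, step_size=1)

    # * `torch.optim.Optimizer`
    def configure_optimizers_single():
        return opt_a

    # * [`torch.optim.Optimizer`]
    def configure_optimizers_list():
        return [opt_a, opt_b]

    # * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])
    #   The second element may also be a bare scheduler; it gets wrapped in a list
    #   (the "tuple of one list" case fixed by this patch).
    def configure_optimizers_two_lists():
        return [opt_a, opt_b], [scheduler_a]

    # * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}
    def configure_optimizers_dict():
        return {"optimizer": opt_a, "lr_scheduler": scheduler_a}

    # * A list/tuple of the dict format above, with an optional "frequency" key (int);
    #   if any entry declares a frequency, every entry must declare one.
    def configure_optimizers_multiple_dicts():
        return (
            {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
            {"optimizer": opt_b, "frequency": 5},
        )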
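
Similarly, a hedged sketch of the lr scheduler dict that `configure_schedulers` now normalizes every scheduler into: unknown keys trigger a RuntimeWarning, a missing "scheduler" key raises MisconfigurationException, and a `ReduceLROnPlateau` without a "monitor" also raises. The metric name "val_loss" is a placeholder:

    from torch import nn, optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters())

    def configure_optimizers_plateau():
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                # Required key; keys outside this set are warned about.
                "scheduler": optim.lr_scheduler.ReduceLROnPlateau(optimizer),
                "interval": "epoch",    # default: step after each epoch
                "frequency": 1,         # default: step on every matching interval
                "monitor": "val_loss",  # required for ReduceLROnPlateau; placeholder metric name
                "strict": True,         # default: error if the monitored metric is not available
                # "reduce_on_plateau" is filled in automatically via isinstance().
            },
        }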