diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 6772dcc645e3b..6793a370fdc35 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -117,6 +117,12 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
                     raise MisconfigurationException(
                         'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                     )
+                if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
+                    raise MisconfigurationException(
+                        f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
+                        f' but is "{scheduler["interval"]}"'
+                    )
+
                 scheduler['reduce_on_plateau'] = isinstance(
                     scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
                 )
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index c9a9250995dd0..7172b2dca76da 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
         trainer.fit(model)
 
 
+def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
+    """
+    Test exception when lr_scheduler dict has unknown interval param value
+    """
+    model = BoringModel()
+    optimizer = torch.optim.Adam(model.parameters())
+    model.configure_optimizers = lambda: {
+        'optimizer': optimizer,
+        'lr_scheduler': {
+            'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
+            'interval': "incorrect_unknown_value"
+        },
+    }
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
+        trainer.fit(model)
+
+
 def test_lr_scheduler_with_extra_keys_warns(tmpdir):
     """
     Test warning when lr_scheduler dict has extra keys
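
For reference, a minimal sketch (not part of the patch) of the lr scheduler dict
format this check validates; the enclosing LightningModule subclass is hypothetical:

    # Sketch: configure_optimizers on some LightningModule subclass.
    # After this patch, 'interval' must be 'step' or 'epoch'; any other
    # value raises MisconfigurationException when the Trainer parses
    # the returned scheduler dict.
    import torch

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',  # valid; e.g. 'minibatch' would now raise
            },
        }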