Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: passing wrong strings for scheduler interval doesn't throw an error #5923

Merged
Merged
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,12 @@ def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if 'interval' in scheduler and scheduler['interval'] not in ('step', 'epoch'):
raise MisconfigurationException(
f'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)

scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau
)
Expand Down
18 changes: 18 additions & 0 deletions tests/trainer/optimization/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,24 @@ def test_unknown_configure_optimizers_raises(tmpdir):
trainer.fit(model)


def test_lr_scheduler_with_unknown_interval_raises(tmpdir):
"""
Test exception when lr_scheduler dict has unknown interval param value
"""
model = BoringModel()
optimizer = torch.optim.Adam(model.parameters())
model.configure_optimizers = lambda: {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': torch.optim.lr_scheduler.StepLR(optimizer, 1),
'interval': "incorrect_unknown_value"
},
}
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.raises(MisconfigurationException, match=r'The "interval" key in lr scheduler dict must be'):
trainer.fit(model)


def test_lr_scheduler_with_extra_keys_warns(tmpdir):
"""
Test warning when lr_scheduler dict has extra keys
Expand Down