diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 26d48c8b03e6f0..7190fccfa88dec 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -238,9 +238,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): def _should_skip_saving_checkpoint(self, trainer) -> bool: return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run + trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved + or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == global_step # already saved at the last step ) @@ -282,9 +282,13 @@ def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') if self.every_n_epochs == 0 or self.every_n_epochs < -1: - raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + ) if self.every_n_batches == 0 or self.every_n_batches < -1: - raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in [None, -1, 0]: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index aacb5b73c64bce..bb7f4df8e3d707 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -499,21 +499,19 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'):