Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthsub committed Mar 2, 2021
1 parent 02ec2d2 commit 78b2562
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
12 changes: 8 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):

def _should_skip_saving_checkpoint(self, trainer) -> bool:
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.running_sanity_check # don't save anything during sanity check
or self.save_top_k == 0 # no models are saved
or self.save_top_k == 0 # no models are saved
or self._last_global_step_saved == global_step # already saved at the last step
)

Expand Down Expand Up @@ -282,9 +282,13 @@ def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self.every_n_epochs == 0 or self.every_n_epochs < -1:
raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
raise MisconfigurationException(
f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1'
)
if self.every_n_batches == 0 or self.every_n_batches < -1:
raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
raise MisconfigurationException(
f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1'
)
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in [None, -1, 0]:
Expand Down
10 changes: 4 additions & 6 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,21 +499,19 @@ def test_none_monitor_top_k(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
ModelCheckpoint(dirpath=tmpdir, save_top_k=0)


def test_invalid_every_n_epoch(tmpdir):
""" Test that an exception is raised for every_n_epochs = 0 or < -1. """
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'
):
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'):
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
with pytest.raises(
MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'
):
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'):
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2)

# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1)
ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3)


def test_invalid_every_n_batches(tmpdir):
""" Test that an exception is raised for every_n_batches = 0 or < -1. """
with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'):
Expand Down

0 comments on commit 78b2562

Please sign in to comment.