diff --git a/CHANGELOG.md b/CHANGELOG.md index f6ca9dffda46c9..643449b480b036 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) ### Changed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 90ac0f12f7a9b2..89da66dedd2d5f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -165,13 +165,9 @@ def __init__( every_n_epochs: int = 1, every_n_batches: int = -1, mode: str = "min", -<<<<<<< HEAD - period: Optional[int] = None, -======= - period: int = 1, every_n_epochs: int = 1, every_n_batches: int = -1, ->>>>>>> Update model_checkpoint.py + period: Optional[int] = None, ): super().__init__() self.monitor = monitor diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 18fabf771556b4..0af2464469da1d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -499,6 +499,36 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) +def test_invalid_every_n_epoch(tmpdir): + """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + +def test_invalid_every_n_batches(tmpdir): + """ Test that an exception is raised for every_n_batches = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. """