From bdce18f27bf23401c4913a464937e4c57fe6c058 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 19:44:52 -0800 Subject: [PATCH 01/49] Update model_checkpoint.py --- .../callbacks/model_checkpoint.py | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f457e9de7d0fa8..f37d85b236d673 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -38,6 +38,8 @@ log = logging.getLogger(__name__) warning_cache = WarningCache() +log = logging.getLogger(__name__) + class ModelCheckpoint(Callback): r""" @@ -156,6 +158,8 @@ def __init__( save_weights_only: bool = False, mode: str = "min", period: int = 1, + every_n_epochs: int = 1, + every_n_batches: int = -1, ): super().__init__() self.monitor = monitor @@ -163,7 +167,9 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period + self.period = every_n_epochs + self.every_n_epochs = every_n_epochs + self.every_n_batches = every_n_batches self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -174,6 +180,13 @@ def __init__( self.save_function = None self.warned_result_obj = False + if period is not None: + rank_zero_warn( + 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_epochs` instead.', DeprecationWarning + ) + self.every_n_epochs = period + self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__validate_init_configuration() @@ -185,11 +198,24 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + if skip_step: + return + self.save_checkpoint(trainer, pl_module) + def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - self.save_checkpoint(trainer) + if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 0: + return + epoch = trainer.current_epoch + if (epoch + 1) % self.every_n_epochs == 0: + self.save_checkpoint(trainer, pl_module) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { @@ -204,7 +230,15 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None): + def _should_skip_saving_checkpoint(self, trainer) -> bool: + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.running_sanity_check # don't save anything during sanity check + or self.save_top_k == 0 # no models are saved + or self._last_global_step_saved == global_step # already saved at the last step + ) + + def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. 
        This method runs on all ranks, it is the responsibility of `self.save_function`
@@ -219,6 +253,7 @@ def save_checkpoint(self, trainer, unused: Optional = None):
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
+<<<<<<< HEAD
         from pytorch_lightning.trainer.states import TrainerState
         if (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or self._last_global_step_saved == global_step  # already saved at the last step
         ):
             return
 
+=======
+>>>>>>> Update model_checkpoint.py
         self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
 
@@ -251,6 +288,10 @@ def save_checkpoint(self, trainer, unused: Optional = None):
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self.every_n_epochs == 0 or self.every_n_epochs < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
+        if self.every_n_batches == 0 or self.every_n_batches < -1:
+            raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
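Patch 01 sets the shape of the feature: `on_train_batch_end` saves on a step interval, `on_validation_end` saves on an epoch interval, and `_should_skip_saving_checkpoint` centralizes the common guards. A pure-Python sketch of that cadence arithmetic follows -- an illustration only, not the library code, and it uses the `< 1` epoch guard the series only settles on in patch 10:

# Standalone sketch of the save cadence introduced above (illustrative only).

def should_save_on_batch(global_step: int, every_n_batches: int) -> bool:
    # Mirrors `on_train_batch_end`: save when (step + 1) is a multiple of the
    # interval; an interval < 1 disables batch-level saves.
    return every_n_batches >= 1 and (global_step + 1) % every_n_batches == 0

def should_save_on_epoch(epoch: int, every_n_epochs: int) -> bool:
    # Mirrors `on_validation_end` with the `< 1` guard from patch 10: save when
    # (epoch + 1) is a multiple of the interval.
    return every_n_epochs >= 1 and (epoch + 1) % every_n_epochs == 0

if __name__ == "__main__":
    print([s for s in range(128) if should_save_on_batch(s, 16)])
    # -> [15, 31, 47, 63, 79, 95, 111, 127], the step filenames the tests below assert
    print([e for e in range(5) if should_save_on_epoch(e, 3)])
    # -> [2]

Note the deliberate `(step + 1)` offset: both hooks fire after a step or epoch completes, so with an interval of 16 the first save lands on step index 15.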
From 016a30c3599990360ecf49d489d9d3cbc761044d Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 22 Feb 2021 20:11:05 -0800
Subject: [PATCH 02/49] add tests

---
 CHANGELOG.md                                  |  1 +
 .../callbacks/model_checkpoint.py             | 42 ++++++++---
 tests/checkpointing/test_model_checkpoint.py  | 71 +++++++++++++++++++
 3 files changed, 103 insertions(+), 11 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 384a218c81305f..3ca5d76729f843 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
 
+- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index f37d85b236d673..911dab1d030dec 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -95,7 +95,20 @@ class ModelCheckpoint(Callback):
         save_weights_only: if ``True``, then only the model's weights will be saved
             (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``).
+        every_n_epochs: Interval (number of epochs) between checkpoints.
+        every_n_batches: Interval (number of batches) between checkpoints.
         period: Interval (number of epochs) between checkpoints.
+<<<<<<< HEAD
+=======
+
+            .. warning::
+               This argument has been deprecated in v1.3 and will be removed in v1.5.
+               Use ``every_n_epochs`` instead.
+        prefix: A string to put at the beginning of checkpoint filename.
+
+            .. warning::
+               This argument has been deprecated in v1.1 and will be removed in v1.3
+>>>>>>> add tests
 
     Note:
         For extra customization, ModelCheckpoint includes the following attributes:
@@ -157,9 +170,9 @@ def __init__(
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
         mode: str = "min",
-        period: int = 1,
         every_n_epochs: int = 1,
         every_n_batches: int = -1,
+        period: Optional[int] = None,
     ):
         super().__init__()
         self.monitor = monitor
@@ -199,6 +212,9 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.save_function = trainer.save_checkpoint
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+        """
+        Save a checkpoint during the training loop if configured to do so.
+        """
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
@@ -230,14 +246,6 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
 
-    def _should_skip_saving_checkpoint(self, trainer) -> bool:
-        return (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or trainer.running_sanity_check  # don't save anything during sanity check
-            or self.save_top_k == 0  # no models are saved
-            or self._last_global_step_saved == global_step  # already saved at the last step
-        )
-
     def save_checkpoint(self, trainer, pl_module):
         """
         Performs the main logic around saving a checkpoint.
@@ -285,13 +293,25 @@ def save_checkpoint(self, trainer, pl_module):
         # Mode 3: save last checkpoints
         self._save_last_checkpoint(trainer, monitor_candidates)
 
+    def _should_skip_saving_checkpoint(self, trainer) -> bool:
+        return (
+            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            or trainer.running_sanity_check  # don't save anything during sanity check
+            or self.save_top_k == 0  # no models are saved
+            or self._last_global_step_saved == global_step  # already saved at the last step
+        )
+
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.every_n_epochs == 0 or self.every_n_epochs < -1:
-            raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1')
+            raise MisconfigurationException(
+                f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1'
+            )
         if self.every_n_batches == 0 or self.every_n_batches < -1:
-            raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1')
+            raise MisconfigurationException(
+                f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1'
+            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 1b33123d6d3f6e..baac46d3d7763e 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -498,6 +498,36 @@ def test_none_monitor_top_k(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
     ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
 
+def test_invalid_every_n_epoch(tmpdir):
+    """ Test that an exception is raised for every_n_epochs = 0 or < -1.
""" + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + +def test_invalid_every_n_batches(tmpdir): + """ Test that an exception is raised for every_n_batches = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. """ @@ -563,6 +593,47 @@ def test_model_checkpoint_period(tmpdir, period): assert set(os.listdir(tmpdir)) == set(expected) +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=0.1, + limit_val_batches=0.1, + val_check_interval=1.0, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=0.1, + limit_val_batches=0.1, + val_check_interval=1.0, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. 
""" model = LogInTwoMethods() From 8d1b8602089172d09bf12ac672c9d3d90b061d1b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:12:24 -0800 Subject: [PATCH 03/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 911dab1d030dec..3f6bffbebb8294 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -298,7 +298,7 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.running_sanity_check # don't save anything during sanity check or self.save_top_k == 0 # no models are saved - or self._last_global_step_saved == global_step # already saved at the last step + or self._last_global_step_saved == trainer.global_step # already saved at the last step ) def __validate_init_configuration(self): From 4a9caa5cc88d4ddfb05ccf3645f9880d2a3db8f4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:16:58 -0800 Subject: [PATCH 04/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 67 ++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index baac46d3d7763e..5bd6578dfabe38 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -633,6 +633,73 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) +def test_ckpt_every_n_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. """ + + model = LogInTwoMethods() + + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=2, + max_epochs=2, + progress_bar_refresh_rate=0, + checkpoint_callback=ModelCheckpoint( + filename="{step}", + every_n_epochs=-1, + every_n_steps=16, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), + logger=False, + ) + + trainer.fit(model) + self.assertCountEqual( + os.listdir(tmpdir), + [ + "step=15.ckpt", + "step=31.ckpt", + "step=47.ckpt", + "step=63.ckpt", + "step=79.ckpt", + "step=95.ckpt", + "step=111.ckpt", + "step=127.ckpt", + ], + ) + +def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): + """ Tests that checkpoints are taken every 30 steps and every epochs """ + model = LogInTwoMethods() + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=2, + max_epochs=2, + progress_bar_refresh_rate=0, + checkpoint_callback=ModelCheckpoint( + every_n_epochs=1, + every_n_steps=30, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), + logger=False, + ) + trainer.fit(model) + self.assertCountEqual( + os.listdir(tmpdir), + [ + "epoch=0-step=29.ckpt", + "epoch=0-step=59.ckpt", + "epoch=0-step=63.ckpt", + "epoch=1-step=89.ckpt", + "epoch=1-step=119.ckpt", + "epoch=1-step=127.ckpt", + "epoch=1-step=127-v0.ckpt", + ], + ) + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. 
""" From b4cd96ecfb954a9a3215d936d4b69eb6f8d51ae1 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:24:51 -0800 Subject: [PATCH 05/49] fix tests --- .../callbacks/model_checkpoint.py | 21 ------ tests/checkpointing/test_model_checkpoint.py | 65 +++++++++++-------- 2 files changed, 38 insertions(+), 48 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3f6bffbebb8294..8309a86c99b782 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -98,17 +98,10 @@ class ModelCheckpoint(Callback): every_n_epochs: Interval (number of epochs) between checkpoints. every_n_batches: Interval (number of batches) between checkpoints. period: Interval (number of epochs) between checkpoints. -<<<<<<< HEAD -======= .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. Use ``every_n_epochs`` instead. - prefix: A string to put at the beginning of checkpoint filename. - - .. warning:: - This argument has been deprecated in v1.1 and will be removed in v1.3 ->>>>>>> add tests Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -261,20 +254,6 @@ def save_checkpoint(self, trainer, pl_module): epoch = trainer.current_epoch global_step = trainer.global_step -<<<<<<< HEAD - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit - or trainer.sanity_checking # don't save anything during sanity check - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or self._last_global_step_saved == global_step # already saved at the last step - ): - return - -======= ->>>>>>> Update model_checkpoint.py self._add_backward_monitor_support(trainer) self._validate_monitor_key(trainer) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 5bd6578dfabe38..758d87605eca09 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -498,6 +498,7 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ with pytest.raises( @@ -513,6 +514,7 @@ def test_invalid_every_n_epoch(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. 
""" with pytest.raises( @@ -597,7 +599,12 @@ def test_model_checkpoint_period(tmpdir, period): def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_epochs=every_n_epochs + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -613,11 +620,18 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) + @pytest.mark.parametrize("every_n_epochs", list(range(4))) def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_epochs=every_n_epochs, + period=None + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -633,6 +647,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) + def test_ckpt_every_n_steps(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ @@ -655,19 +670,17 @@ def test_ckpt_every_n_steps(tmpdir): ) trainer.fit(model) - self.assertCountEqual( - os.listdir(tmpdir), - [ - "step=15.ckpt", - "step=31.ckpt", - "step=47.ckpt", - "step=63.ckpt", - "step=79.ckpt", - "step=95.ckpt", - "step=111.ckpt", - "step=127.ckpt", - ], - ) + expected = [ + "step=15.ckpt", + "step=31.ckpt", + "step=47.ckpt", + "step=63.ckpt", + "step=79.ckpt", + "step=95.ckpt", + "step=111.ckpt", + "step=127.ckpt", + ] + assert set(os.listdir(tmpdir)) == set(expected) def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ @@ -687,18 +700,16 @@ def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): logger=False, ) trainer.fit(model) - self.assertCountEqual( - os.listdir(tmpdir), - [ - "epoch=0-step=29.ckpt", - "epoch=0-step=59.ckpt", - "epoch=0-step=63.ckpt", - "epoch=1-step=89.ckpt", - "epoch=1-step=119.ckpt", - "epoch=1-step=127.ckpt", - "epoch=1-step=127-v0.ckpt", - ], - ) + expected = [ + "epoch=0-step=29.ckpt", + "epoch=0-step=59.ckpt", + "epoch=0-step=63.ckpt", + "epoch=1-step=89.ckpt", + "epoch=1-step=119.ckpt", + "epoch=1-step=127.ckpt", + "epoch=1-step=127-v0.ckpt", + ] + assert set(os.listdir(tmpdir)) == set(expected) def test_model_checkpoint_topk_zero(tmpdir): From 767cc943f3f2f6d8dfd633e34db7456c23c46f91 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 21:25:23 -0800 Subject: [PATCH 06/49] every_n_batches --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8309a86c99b782..3a4dc5df28799e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -211,7 +211,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) if skip_step: return self.save_checkpoint(trainer, pl_module) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 758d87605eca09..b9c7c82219a39c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -648,7 +648,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_steps(tmpdir): +def test_ckpt_every_n_batches(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() @@ -661,7 +661,7 @@ def test_ckpt_every_n_steps(tmpdir): checkpoint_callback=ModelCheckpoint( filename="{step}", every_n_epochs=-1, - every_n_steps=16, + every_n_batches=16, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -682,7 +682,7 @@ def test_ckpt_every_n_steps(tmpdir): ] assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): +def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() trainer = Trainer( @@ -692,7 +692,7 @@ def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( every_n_epochs=1, - every_n_steps=30, + every_n_batches=30, dirpath=tmpdir, save_top_k=-1, save_last=False, From 5754303f7d040916f568132fd05b1bddb7012626 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:18:40 -0800 Subject: [PATCH 07/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b9c7c82219a39c..584d7854270b69 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -504,15 +504,15 @@ def test_invalid_every_n_epoch(tmpdir): with pytest.raises( MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' ): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None) with pytest.raises( MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' ): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, period=None) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3, period=None) def test_invalid_every_n_batches(tmpdir): From ea101c89b993d49134bb1ac8145d499a1e7c1cea Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:30:25 -0800 Subject: [PATCH 08/49] defaults --- .../callbacks/model_checkpoint.py | 1 + tests/checkpointing/test_model_checkpoint.py | 28 +++++-------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git 
a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3a4dc5df28799e..30780480be3cb5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -192,6 +192,7 @@ def __init__( ' Please use `every_n_epochs` instead.', DeprecationWarning ) self.every_n_epochs = period + self.period = period self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 584d7854270b69..9f788d5ff91283 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -501,13 +501,9 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None) # These should not fail @@ -517,13 +513,9 @@ def test_invalid_every_n_epoch(tmpdir): def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail @@ -600,10 +592,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - filename='{epoch}', - save_top_k=-1, - every_n_epochs=every_n_epochs + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -626,11 +615,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - filename='{epoch}', - save_top_k=-1, - every_n_epochs=every_n_epochs, - period=None + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None ) trainer = Trainer( default_root_dir=tmpdir, @@ -682,6 +667,7 @@ def test_ckpt_every_n_batches(tmpdir): ] assert set(os.listdir(tmpdir)) == set(expected) + def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() From ddaa2e090723be1e6253fa3d37449ab7ad9c0186 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:38:04 -0800 Subject: [PATCH 09/49] rm tests --- .../callbacks/model_checkpoint.py | 8 ------- tests/checkpointing/test_model_checkpoint.py | 24 ------------------- 2 files changed, 32 deletions(-) diff --git 
a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 30780480be3cb5..0f2ba76ad85df7 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -284,14 +284,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
-        if self.every_n_epochs == 0 or self.every_n_epochs < -1:
-            raise MisconfigurationException(
-                f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1'
-            )
-        if self.every_n_batches == 0 or self.every_n_batches < -1:
-            raise MisconfigurationException(
-                f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1'
-            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 9f788d5ff91283..a28333f4c1fa1d 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -499,30 +499,6 @@ def test_none_monitor_top_k(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
 
-def test_invalid_every_n_epoch(tmpdir):
-    """ Test that an exception is raised for every_n_epochs = 0 or < -1. """
-    with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None)
-    with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None)
-
-    # These should not fail
-    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, period=None)
-    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3, period=None)
-
-
-def test_invalid_every_n_batches(tmpdir):
-    """ Test that an exception is raised for every_n_batches = 0 or < -1. """
-    with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_batches=0)
-    with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'):
-        ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2)
-
-    # These should not fail
-    ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1)
-    ModelCheckpoint(dirpath=tmpdir, every_n_batches=3)
-
-
 def test_none_monitor_save_last(tmpdir):
     """ Test that a warning appears for save_last=True with monitor=None. """
     with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'):
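With the init-time range validation pruned in patch 09, the `< 0` guard in `on_validation_end` is the only thing standing between `every_n_epochs=0` and the modulo -- and it lets 0 through. The next patch tightens it to `< 1`; a small standalone illustration of why (plain Python, hypothetical values, not the library code):

# Why the guard must be `< 1` rather than `< 0` once the range check is gone.
epoch, every_n_epochs = 0, 0
if not every_n_epochs < 0:  # the old guard: 0 slips through
    try:
        (epoch + 1) % every_n_epochs
    except ZeroDivisionError as err:
        print(f"guard too loose: {err}")
if not every_n_epochs < 1:  # the guard patch 10 switches to: 0 is skipped
    raise AssertionError("unreachable for every_n_epochs == 0")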
""" with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'): From a1fdb707d0e040d58818a006334af7e82a3cf461 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 23 Feb 2021 09:38:39 -0800 Subject: [PATCH 10/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0f2ba76ad85df7..3f60653b711b2b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -221,7 +221,7 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 0: + if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1: return epoch = trainer.current_epoch if (epoch + 1) % self.every_n_epochs == 0: From 55aca5edd3e309d73bf0f80ff5164ae3de3c705a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 23 Feb 2021 23:33:31 -0800 Subject: [PATCH 11/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a28333f4c1fa1d..c7f12f4fa17f64 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -669,7 +669,6 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): "epoch=1-step=89.ckpt", "epoch=1-step=119.ckpt", "epoch=1-step=127.ckpt", - "epoch=1-step=127-v0.ckpt", ] assert set(os.listdir(tmpdir)) == set(expected) From 4ed5b820f63a1b7ee32644d2dc3e7e0ae345de4b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 12:09:01 +0100 Subject: [PATCH 12/49] Prune deprecated metrics for 1.3 (#6161) * prune deprecated metrics for 1.3 * isort / yapf --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ca5d76729f843..cad054bcb39d8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) +- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) From c7f25a40c134bbb5918c589b778a14b50bd28c35 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 19:44:52 -0800 Subject: [PATCH 13/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3f60653b711b2b..8bd43a3565d09f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -206,13 +206,10 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.save_function = trainer.save_checkpoint def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: - """ - Save a checkpoint during the training loop if configured to do so. - """ if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_step: return self.save_checkpoint(trainer, pl_module) @@ -240,6 +237,14 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] + def _should_skip_saving_checkpoint(self, trainer) -> bool: + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.running_sanity_check # don't save anything during sanity check + or self.save_top_k == 0 # no models are saved + or self._last_global_step_saved == global_step # already saved at the last step + ) + def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. @@ -284,6 +289,10 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self.every_n_epochs == 0 or self.every_n_epochs < -1: + raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + if self.every_n_batches == 0 or self.every_n_batches < -1: + raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. 
Must be positive or -1')
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):

From b89abbb1d6a19dfed6a0d7a43f933eaf0d9cce31 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 22 Feb 2021 20:11:05 -0800
Subject: [PATCH 14/49] add tests

---
 CHANGELOG.md                                 |  1 +
 tests/checkpointing/test_model_checkpoint.py | 30 ++++++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cad054bcb39d8e..21d1d6cf3d9ac4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
+- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
 
 - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index c7f12f4fa17f64..89849ee62be51e 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -498,6 +498,36 @@ def test_none_monitor_top_k(tmpdir):
     ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
     ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
 
+def test_invalid_every_n_epoch(tmpdir):
+    """ Test that an exception is raised for every_n_epochs = 0 or < -1. """
+    with pytest.raises(
+        MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'
+    ):
+        ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0)
+    with pytest.raises(
+        MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'
+    ):
+        ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2)
+
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3)
+
+def test_invalid_every_n_batches(tmpdir):
+    """ Test that an exception is raised for every_n_batches = 0 or < -1. """
+    with pytest.raises(
+        MisconfigurationException, match=r'Invalid value for every_n_batches=0*'
+    ):
+        ModelCheckpoint(dirpath=tmpdir, every_n_batches=0)
+    with pytest.raises(
+        MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'
+    ):
+        ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2)
+
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_batches=3)
+
 def test_none_monitor_save_last(tmpdir):
     """ Test that a warning appears for save_last=True with monitor=None. """
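The re-added tests lean on `pytest.raises(..., match=...)`, and `match` is applied with `re.search` against the stringified exception -- which is why a loose pattern like `r'Invalid value for every_n_epochs=0*'` passes even though `0*` means "zero or more zeros". A self-contained illustration of that matching behaviour (standalone sketch, using `ValueError` in place of the project's `MisconfigurationException`):

import re

import pytest

def raise_invalid(value: int) -> None:
    raise ValueError(f"Invalid value for every_n_epochs={value}. Must be positive or -1")

def test_match_is_a_regex_search() -> None:
    # A prefix pattern is enough: `match` searches, it does not fullmatch.
    with pytest.raises(ValueError, match=r"Invalid value for every_n_epochs=-2"):
        raise_invalid(-2)
    # The equivalent check by hand:
    msg = "Invalid value for every_n_epochs=-2. Must be positive or -1"
    assert re.search(r"Invalid value for every_n_epochs=-2", msg)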
""" From 3bc2bc733ac352f8cc3eaaf463b485becec8fdb4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:30:25 -0800 Subject: [PATCH 15/49] defaults --- tests/checkpointing/test_model_checkpoint.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 89849ee62be51e..a1afd2cd79c331 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -515,13 +515,9 @@ def test_invalid_every_n_epoch(tmpdir): def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail From 93d59e8cb338a13e0f662eb1091ccace185e9960 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 24 Feb 2021 09:30:28 -0800 Subject: [PATCH 16/49] Update CHANGELOG.md --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21d1d6cf3d9ac4..3ca5d76729f843 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) -- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) -- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) From acddd6a7b63f779ef6fa041ed3d332abae14590e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 24 Feb 2021 09:50:30 -0800 Subject: [PATCH 17/49] pre-commit --- pytorch_lightning/callbacks/model_checkpoint.py | 12 ++++++++---- tests/checkpointing/test_model_checkpoint.py | 10 ++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8bd43a3565d09f..54ce4ce6204bcd 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -239,9 +239,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): def _should_skip_saving_checkpoint(self, trainer) -> bool: return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run + trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved + or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == global_step # already saved at the 
last step ) @@ -290,9 +290,13 @@ def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') if self.every_n_epochs == 0 or self.every_n_epochs < -1: - raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + ) if self.every_n_batches == 0 or self.every_n_batches < -1: - raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a1afd2cd79c331..4a2169d582cb7b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -498,21 +498,19 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): From 8a6539dd8ffbf5354695a8ad539deed40edcf601 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:41:06 -0800 Subject: [PATCH 18/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 54ce4ce6204bcd..cd2f611c7002fd 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -237,14 +237,6 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def _should_skip_saving_checkpoint(self, trainer) -> bool: - return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved - or self._last_global_step_saved == global_step # already saved at the last step - ) - def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. 
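Patch 18 ends the shuffling: the duplicated `_should_skip_saving_checkpoint` at its old position is gone, leaving the single copy after `save_checkpoint`. A standalone restatement of that consolidated predicate, with a stub in place of the live `Trainer` (an illustration, not the library's method):

from dataclasses import dataclass

@dataclass
class StubTrainer:
    fast_dev_run: bool = False
    running_sanity_check: bool = False
    global_step: int = 0

def should_skip(trainer: StubTrainer, save_top_k: int, last_saved_step: int) -> bool:
    return (
        trainer.fast_dev_run  # checkpointing disabled for smoke runs
        or trainer.running_sanity_check  # never save during the sanity check
        or save_top_k == 0  # saving turned off entirely
        or last_saved_step == trainer.global_step  # already saved at this step
    )

assert should_skip(StubTrainer(global_step=7), save_top_k=1, last_saved_step=7)
assert not should_skip(StubTrainer(global_step=8), save_top_k=1, last_saved_step=7)

The last clause is the deduplication that makes it safe for `on_train_batch_end` and `on_validation_end` to both fire around the same step.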
From 07e5cec722c5e8c61c97710eda639d82561b5ed8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:44:52 -0800 Subject: [PATCH 19/49] update defaults --- pytorch_lightning/callbacks/model_checkpoint.py | 10 +++++----- tests/checkpointing/test_model_checkpoint.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cd2f611c7002fd..af19fd34c2631a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -164,7 +164,7 @@ def __init__( save_weights_only: bool = False, mode: str = "min", every_n_epochs: int = 1, - every_n_batches: int = -1, + every_n_batches: int = 0, period: Optional[int] = None, ): super().__init__() @@ -281,13 +281,13 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs == 0 or self.every_n_epochs < -1: + if self.every_n_epochs <= -1: raise MisconfigurationException( - f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' ) - if self.every_n_batches == 0 or self.every_n_batches < -1: + if self.every_n_batches <= -1: raise MisconfigurationException( - f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4a2169d582cb7b..d07311a6126a6b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -501,25 +501,25 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-1*'): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. 
""" - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-1*'): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) @@ -645,7 +645,7 @@ def test_ckpt_every_n_batches(tmpdir): progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( filename="{step}", - every_n_epochs=-1, + every_n_epochs=0, every_n_batches=16, dirpath=tmpdir, save_top_k=-1, From 763476c9ab4bdaa77d182fe7d56ca8fd2b363510 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:45:57 -0800 Subject: [PATCH 20/49] Update test_remove_1-5.py --- tests/deprecated_api/test_remove_1-5.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7d8c7d2adeea10..f205f124f4e862 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -78,6 +78,7 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +<<<<<<< HEAD def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -104,3 +105,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) +======= +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1) +>>>>>>> Update test_remove_1-5.py From 74b0aeeaad7aae51a42cb235966255802bbdba14 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 22:00:37 -0800 Subject: [PATCH 21/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index af19fd34c2631a..74222c3891c3f7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -101,6 +101,7 @@ class ModelCheckpoint(Callback): .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. + Use ``every_n_epochs`` instead. 
Note: From 1a1c1c6c4bdd69da904b72512503959672d8339d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 22:13:45 -0800 Subject: [PATCH 22/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 74222c3891c3f7..fb2ef5cfae1f2d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -210,8 +210,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) - if skip_step: + skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + if skip_batch: return self.save_checkpoint(trainer, pl_module) From f6625ccd13d9b7c2a1d2a4f83cab5f399b90214c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:52:58 -0800 Subject: [PATCH 23/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index fb2ef5cfae1f2d..9ed8b31790a248 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -219,11 +219,14 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1: - return epoch = trainer.current_epoch - if (epoch + 1) % self.every_n_epochs == 0: - self.save_checkpoint(trainer, pl_module) + skip = ( + self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 + or (epoch + 1) % self.every_n_epochs != 0 + ) + if skip: + return + self.save_checkpoint(trainer, pl_module) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { From 7572947e2abae7e636187eb8d74e093d54253823 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:53:51 -0800 Subject: [PATCH 24/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9ed8b31790a248..05f4988f867ef3 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -219,10 +219,9 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - epoch = trainer.current_epoch skip = ( self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 - or (epoch + 1) % self.every_n_epochs != 0 + or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) if skip: return From 67700d7024c9067f742e9d3ee9f00f9e0de22276 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:56:59 -0800 Subject: [PATCH 25/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 05f4988f867ef3..9b6a58e61e96c5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py 
+++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -211,6 +211,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data return step = trainer.global_step skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + log.warning(f"in on_train_batch_end at step {step}, every_n_batches={self.every_n_batches}, going to skip batch? {skip_batch}") if skip_batch: return self.save_checkpoint(trainer, pl_module) From 5b56212b5dbacfc5910ec8ce74a8c678845baef2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 2 Mar 2021 00:25:53 -0800 Subject: [PATCH 26/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9b6a58e61e96c5..67765516b995ee 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -210,8 +210,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.warning(f"in on_train_batch_end at step {step}, every_n_batches={self.every_n_batches}, going to skip batch? {skip_batch}") + skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_batch: return self.save_checkpoint(trainer, pl_module) From bd3d10fe17b1dcfde60ae88488e5a4af1e176869 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:56:50 -0800 Subject: [PATCH 27/49] fix tests --- CHANGELOG.md | 2 ++ .../callbacks/model_checkpoint.py | 8 ++---- tests/checkpointing/test_model_checkpoint.py | 28 +++++-------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ca5d76729f843..4e1ddeb48ca81b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 67765516b995ee..4d374856922eb6 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -174,8 +174,8 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = every_n_epochs - self.every_n_epochs = every_n_epochs + self.every_n_epochs = period or every_n_epochs + self.period = self.every_n_epochs self.every_n_batches = every_n_batches self._last_global_step_saved = -1 self.current_score = None @@ -192,8 +192,6 @@ def __init__( 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' 
' Please use `every_n_epochs` instead.', DeprecationWarning ) - self.every_n_epochs = period - self.period = period self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) @@ -209,7 +207,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): return - step = trainer.global_step + step = trainer.total_batch_idx skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_batch: return diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d07311a6126a6b..d57f1b423c5aff 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -598,9 +598,8 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -621,9 +620,8 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -637,16 +635,15 @@ def test_ckpt_every_n_batches(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() - + every_n_batches = 16 trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( filename="{step}", every_n_epochs=0, - every_n_batches=16, + every_n_batches=every_n_batches, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -655,16 +652,7 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected = [ - "step=15.ckpt", - "step=31.ckpt", - "step=47.ckpt", - "step=63.ckpt", - "step=79.ckpt", - "step=95.ckpt", - "step=111.ckpt", - "step=127.ckpt", - ] + expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] assert set(os.listdir(tmpdir)) == set(expected) @@ -673,9 +661,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): model = LogInTwoMethods() trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, - progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( every_n_epochs=1, every_n_batches=30, From bf2831f58d08fc0687c3b5d81ea4a84be65a2e62 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:57:37 -0800 Subject: [PATCH 28/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d57f1b423c5aff..e8a40f094d6c36 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -652,7 +652,7 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] + expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] assert set(os.listdir(tmpdir)) == set(expected) From 512f030370665c49e058ee1f4385fb1cc6e144a5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:59:41 
-0800 Subject: [PATCH 29/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4d374856922eb6..ce8672e067d0b0 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -282,11 +282,11 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs <= -1: + if self.every_n_epochs < 0: raise MisconfigurationException( f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' ) - if self.every_n_batches <= -1: + if self.every_n_batches < 0: raise MisconfigurationException( f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' ) From 23a1960658a9a722e123f72f3ffe2de5c2010fb6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 22:18:23 -0800 Subject: [PATCH 30/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ce8672e067d0b0..164f0de3af3f80 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -208,7 +208,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.total_batch_idx - skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) if skip_batch: return self.save_checkpoint(trainer, pl_module) From 6a2d67bf07110b21539ee51ce1e2864249454182 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 23:44:15 -0800 Subject: [PATCH 31/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 164f0de3af3f80..a3120f8b9010c1 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -38,8 +38,6 @@ log = logging.getLogger(__name__) warning_cache = WarningCache() -log = logging.getLogger(__name__) - class ModelCheckpoint(Callback): r""" @@ -206,9 +204,11 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): + log.critical("in train batch end, not saving checkpoint after trainer check") return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + log.critical(f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? 
{skip_batch}") if skip_batch: return self.save_checkpoint(trainer, pl_module) @@ -282,14 +282,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs < 0: - raise MisconfigurationException( - f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' - ) - if self.every_n_batches < 0: - raise MisconfigurationException( - f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' - ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): From 70fb0a6ded0dda16efc00fec7417b72824df6cf1 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 23:46:22 -0800 Subject: [PATCH 32/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e8a40f094d6c36..82349954617f89 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -500,27 +500,17 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): - """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-1*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) - - # These should not fail + """ Test different configurations for every_n_epochs. """ ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, every_n_batches=1) def test_invalid_every_n_batches(tmpdir): - """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-1*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) - - # These should not fail + """ Test different configurations for every_n_batches. 
""" ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1, every_n_epochs=2) def test_none_monitor_save_last(tmpdir): From 4d58406139b4fa9fb07aabb11f0ba3c57ac30007 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:12:22 -0800 Subject: [PATCH 33/49] ckpt-callback --- .../callbacks/model_checkpoint.py | 4 ++- tests/checkpointing/test_model_checkpoint.py | 32 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a3120f8b9010c1..8dcd7fef0134ee 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -208,7 +208,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.critical(f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}") + log.critical( + f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}" + ) if skip_batch: return self.save_checkpoint(trainer, pl_module) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 82349954617f89..3afa210e0da664 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -626,18 +626,19 @@ def test_ckpt_every_n_batches(tmpdir): model = LogInTwoMethods() every_n_batches = 16 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_epochs=0, + every_n_batches=every_n_batches, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, progress_bar_refresh_rate=0, - checkpoint_callback=ModelCheckpoint( - filename="{step}", - every_n_epochs=0, - every_n_batches=every_n_batches, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - ), + callbacks=[checkpoint_callback], logger=False, ) @@ -649,16 +650,17 @@ def test_ckpt_every_n_batches(tmpdir): def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() + checkpoint_callback = ModelCheckpoint( + every_n_epochs=1, + every_n_batches=30, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - checkpoint_callback=ModelCheckpoint( - every_n_epochs=1, - every_n_batches=30, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - ), + callbacks=[checkpoint_callback], logger=False, ) trainer.fit(model) From be47d800bf3f485219732beb3a649bdee6e91bd4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:31:16 -0800 Subject: [PATCH 34/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 3afa210e0da664..ec479c291d8d2c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -656,7 +656,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): dirpath=tmpdir, save_top_k=-1, save_last=False, - ), + ) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, From 
e9702d21ba3f891f65193c73628058b95e4fb0ee Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:31:59 -0800 Subject: [PATCH 35/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8dcd7fef0134ee..f7db4821cf00f6 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -94,7 +94,7 @@ class ModelCheckpoint(Callback): saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). every_n_epochs: Interval (number of epochs) between checkpoints. - every_n_batches: Interval (number of batches) between checkpoints. + every_n_batches: Interval (number of training batches) between checkpoints. period: Interval (number of epochs) between checkpoints. .. warning:: From ddad26ac0b7bdb5382cde843d1ceb811c56af320 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:49:52 -0800 Subject: [PATCH 36/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f7db4821cf00f6..514f591319fcd0 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -204,13 +204,9 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): - log.critical("in train batch end, not saving checkpoint after trainer check") return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.critical( - f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}" - ) if skip_batch: return self.save_checkpoint(trainer, pl_module) @@ -223,6 +219,9 @@ def on_validation_end(self, trainer, pl_module): self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) + log.critical( + f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, step={step}, skip? {skip}" + ) if skip: return self.save_checkpoint(trainer, pl_module) From 5bb06fc5c99d0aaab549ce14acb6345f06e29bb3 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 01:06:07 -0800 Subject: [PATCH 37/49] validation-end --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 514f591319fcd0..aaed823061b325 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -220,7 +220,7 @@ def on_validation_end(self, trainer, pl_module): or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) log.critical( - f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, step={step}, skip? {skip}" + f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, epoch={trainer.current_epoch}, skip? 
{skip}" ) if skip: return diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ec479c291d8d2c..68b048a7d9f30a 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -565,9 +565,8 @@ def test_model_checkpoint_period(tmpdir, period): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) From 73fe8fac4ab2fc18f0f38377fed2fcc63951b643 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 01:24:16 -0800 Subject: [PATCH 38/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index aaed823061b325..e1570db2d58b6c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -172,7 +172,7 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.every_n_epochs = period or every_n_epochs + self.every_n_epochs = period if period is not None else every_n_epochs self.period = self.every_n_epochs self.every_n_batches = every_n_batches self._last_global_step_saved = -1 @@ -219,9 +219,6 @@ def on_validation_end(self, trainer, pl_module): self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) - log.critical( - f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, epoch={trainer.current_epoch}, skip? {skip}" - ) if skip: return self.save_checkpoint(trainer, pl_module) From 3ce322a0c94df089990a67e275360db5cb4df595 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 09:59:30 -0800 Subject: [PATCH 39/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 42 +++++++------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 68b048a7d9f30a..395105c22adb1d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -499,20 +499,6 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=0) -def test_invalid_every_n_epoch(tmpdir): - """ Test different configurations for every_n_epochs. """ - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, every_n_batches=1) - - -def test_invalid_every_n_batches(tmpdir): - """ Test different configurations for every_n_batches. """ - ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1, every_n_epochs=2) - - def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. 
""" with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'): @@ -599,11 +585,12 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): @pytest.mark.parametrize("every_n_epochs", list(range(4))) -def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): +def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): + """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """ model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -646,32 +633,33 @@ def test_ckpt_every_n_batches(tmpdir): assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): +@pytest.mark.parametrize("every_n_epochs", [1, 3]) +def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() checkpoint_callback = ModelCheckpoint( - every_n_epochs=1, + every_n_epochs=every_n_epochs, every_n_batches=30, dirpath=tmpdir, save_top_k=-1, save_last=False, + filename="{step}", ) + max_epochs = 3 + epoch_step_length = 64 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=2, + max_epochs=max_epochs, callbacks=[checkpoint_callback], logger=False, ) trainer.fit(model) - expected = [ - "epoch=0-step=29.ckpt", - "epoch=0-step=59.ckpt", - "epoch=0-step=63.ckpt", - "epoch=1-step=89.ckpt", - "epoch=1-step=119.ckpt", - "epoch=1-step=127.ckpt", + expected_steps_for_ckpt = [ + i for i in range(epoch_step_length * max_epochs) + if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) ] - assert set(os.listdir(tmpdir)) == set(expected) + expected_ckpt_files = [f"step={step}.ckpt" for step in step] + assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) def test_model_checkpoint_topk_zero(tmpdir): From 55ce3f4b4cdf3b4f123908cb566e58a9ee73fcf0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:02:22 -0800 Subject: [PATCH 40/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 395105c22adb1d..77a5bf6c6c81e3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -637,9 +637,10 @@ def test_ckpt_every_n_batches(tmpdir): def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() + every_n_batches = 30 checkpoint_callback = ModelCheckpoint( every_n_epochs=every_n_epochs, - every_n_batches=30, + every_n_batches=every_n_batches, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -658,7 +659,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): i for i in range(epoch_step_length * max_epochs) if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) ] - expected_ckpt_files = [f"step={step}.ckpt" for step in step] + expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) From 
16986047f55a2b096da88f5f6626356c5525ed8c Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:24:51 -0800 Subject: [PATCH 41/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 77a5bf6c6c81e3..93a8bf54aee723 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -657,7 +657,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) + if ((i+1) % every_n_batches) == 0 or (i+1) % (every_n_epochs * epoch_step_length) == 0) ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) From 3e46c883cbb80940f1557d56f834c3577e13e39d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:27:16 -0800 Subject: [PATCH 42/49] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 93a8bf54aee723..f6a60c9bec0543 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -657,7 +657,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if ((i+1) % every_n_batches) == 0 or (i+1) % (every_n_epochs * epoch_step_length) == 0) + if ((i + 1) % every_n_batches) == 0 or (i + 1) % (every_n_epochs * epoch_step_length) == 0 ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) From 71cf574fe0a1501aa8b7c9e5ac8da5f67c785a86 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:00:58 -0800 Subject: [PATCH 43/49] clarify-names - Make names explicit as to which hooks they apply to - Use step instead of batch for consistency with global step --- CHANGELOG.md | 2 +- .../callbacks/model_checkpoint.py | 47 ++++++++----- tests/checkpointing/test_model_checkpoint.py | 66 +++++++++++++------ tests/deprecated_api/test_remove_1-5.py | 5 +- 4 files changed, 78 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e1ddeb48ca81b..28ec368a4d831d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,7 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Deprecated -- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e1570db2d58b6c..5fdec6afad56a6 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -93,14 +93,18 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). - every_n_epochs: Interval (number of epochs) between checkpoints. - every_n_batches: Interval (number of training batches) between checkpoints. + every_n_train_steps: Number of training steps between checkpoints. + To disable, set ``every_n_train_steps = 0``. This value must be non-negative. + every_n_val_epochs: Number of validation epochs between checkpoints. + To disable, set ``every_n_val_epochs = 0``. This value must be non-negative. + This is not mutually exclusive with ``every_n_val_epochs``. If both are set, pay extreme caution if also setting ``monitor`` + as the ``monitor`` value must be available in both training and validation. This can have unintended consequences with tracking the top K models. period: Interval (number of epochs) between checkpoints. .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. - Use ``every_n_epochs`` instead. + Use ``every_n_val_epochs`` instead. Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -162,8 +166,8 @@ def __init__( save_top_k: Optional[int] = None, save_weights_only: bool = False, mode: str = "min", - every_n_epochs: int = 1, - every_n_batches: int = 0, + every_n_train_steps: int = 0, + every_n_val_epochs: int = 1, period: Optional[int] = None, ): super().__init__() @@ -172,9 +176,9 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.every_n_epochs = period if period is not None else every_n_epochs - self.period = self.every_n_epochs - self.every_n_batches = every_n_batches + self.every_n_val_epochs = period if period is not None else every_n_val_epochs + self.period = self.every_n_val_epochs + self.every_n_train_steps = every_n_train_steps self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -188,7 +192,7 @@ def __init__( if period is not None: rank_zero_warn( 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' 
-            ' Please use `every_n_epochs` instead.', DeprecationWarning
+            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
         )
 
         self.__init_monitor_mode(monitor, mode)
@@ -202,11 +206,12 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
+    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
+        """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
         if self._should_skip_saving_checkpoint(trainer):
             return
-        step = trainer.total_batch_idx
-        skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0)
+        step = trainer.global_step
+        skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
         if skip_batch:
             return
         self.save_checkpoint(trainer, pl_module)
@@ -216,8 +221,8 @@ def on_validation_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the val loop
         """
         skip = (
-            self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1
-            or (trainer.current_epoch + 1) % self.every_n_epochs != 0
+            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
         )
         if skip:
             return
@@ -236,7 +241,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
 
-    def save_checkpoint(self, trainer, pl_module):
+    def save_checkpoint(self, trainer, unused):
         """
         Performs the main logic around saving a checkpoint.
         This method runs on all ranks, it is the responsibility of `self.save_function`
@@ -280,6 +285,14 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
+        if self.every_n_train_steps < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
+            )
+        if self.every_n_val_epochs < 0:
+            raise MisconfigurationException(
+                f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
+            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
@@ -583,9 +596,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
             self._save_model(trainer, filepath)
 
         if (
-            self.save_top_k is None
-            and self.best_model_path
-            and self.best_model_path != filepath
+            self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
             and trainer.is_global_zero
         ):
             self._del_model(self.best_model_path)
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f6a60c9bec0543..fc2bb8235102a9 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -508,6 +508,26 @@ def test_none_monitor_save_last(tmpdir):
 
 
+def test_invalid_every_n_val_epochs(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)
+
+
+def test_invalid_every_n_train_steps(tmpdir):
+    """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
+        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
+    # These should not fail
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
+    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)
+
+
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     seed_everything()
@@ -562,12 +582,12 @@ def test_model_checkpoint_period(tmpdir, period):
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
-@pytest.mark.parametrize("every_n_epochs", list(range(4)))
-def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
     model = LogInTwoMethods()
     epochs = 5
     checkpoint_callback = ModelCheckpoint(
-        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs
+        dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
     )
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -580,17 +600,22 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
     trainer.fit(model)
 
     # check that the correct ckpts were created
-    expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
+    expected = [f'epoch={e}.ckpt' for e in range(epochs)
+                if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
 
 
-@pytest.mark.parametrize("every_n_epochs", list(range(4)))
-def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs):
-    """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """
+@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
+def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
+    """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility.
""" model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -603,19 +628,20 @@ def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_batches(tmpdir): +def test_ckpt_every_n_train_steps(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() - every_n_batches = 16 + every_n_train_steps = 16 checkpoint_callback = ModelCheckpoint( filename="{step}", - every_n_epochs=0, - every_n_batches=every_n_batches, + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -629,18 +655,18 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] + expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)] assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_epochs", [1, 3]) -def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): +@pytest.mark.parametrize("every_n_val_epochs", [1, 3]) +def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() - every_n_batches = 30 + every_n_train_steps = 30 checkpoint_callback = ModelCheckpoint( - every_n_epochs=every_n_epochs, - every_n_batches=every_n_batches, + every_n_val_epochs=every_n_val_epochs, + every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -657,7 +683,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if ((i + 1) % every_n_batches) == 0 or (i + 1) % (every_n_epochs * epoch_step_length) == 0 + if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0 ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f205f124f4e862..e65ebbab254de1 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -78,7 +78,6 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) -<<<<<<< HEAD def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -105,10 +104,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) -======= + + def test_v1_5_0_model_checkpoint_period(tmpdir): with no_warning_call(DeprecationWarning): ModelCheckpoint(dirpath=tmpdir) with 
pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): ModelCheckpoint(dirpath=tmpdir, period=1) ->>>>>>> Update test_remove_1-5.py From 2cd6c3d0828f1b6075f14b9a6a09c69b406ae3cb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:03:58 -0800 Subject: [PATCH 44/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5fdec6afad56a6..d3467bad340f53 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -241,7 +241,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused): + def save_checkpoint(self, trainer, unused: Optional = None): """ Performs the main logic around saving a checkpoint. This method runs on all ranks, it is the responsibility of `self.save_function` @@ -277,7 +277,8 @@ def save_checkpoint(self, trainer, unused): def _should_skip_saving_checkpoint(self, trainer) -> bool: return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.running_sanity_check # don't save anything during sanity check + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == trainer.global_step # already saved at the last step ) From ca07b581c1524983de589d5f380fe9a79f9661c2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:12:24 -0800 Subject: [PATCH 45/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d3467bad340f53..aa74428841a252 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -275,6 +275,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): self._save_last_checkpoint(trainer, monitor_candidates) def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.state != TrainerState.FITTING # don't save anything during non-fit From 90d57fd98a4b9964a8d1f898c3ece1c91191a016 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:15:26 -0800 Subject: [PATCH 46/49] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index aa74428841a252..e828cd437a3d0a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -253,7 +253,6 @@ def save_checkpoint(self, trainer, unused: Optional = None): " has been removed. 
Support for the old signature will be removed in v1.5", DeprecationWarning
             )
 
-        epoch = trainer.current_epoch
         global_step = trainer.global_step
 
         self._add_backward_monitor_support(trainer)

From c8948cc5dbc2d64e96007db1d1fb362aef01bd1b Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 8 Mar 2021 23:31:47 -0800
Subject: [PATCH 47/49] Update model_checkpoint.py

---
 pytorch_lightning/callbacks/model_checkpoint.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e828cd437a3d0a..ff93d3ca393655 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -97,8 +97,10 @@ class ModelCheckpoint(Callback):
         To disable, set ``every_n_train_steps = 0``. This value must be non-negative.
         every_n_val_epochs: Number of validation epochs between checkpoints.
             To disable, set ``every_n_val_epochs = 0``. This value must be non-negative.
-            This is not mutually exclusive with ``every_n_val_epochs``. If both are set, pay extreme caution if also setting ``monitor``
-            as the ``monitor`` value must be available in both training and validation. This can have unintended consequences with tracking the top K models.
+            This is not mutually exclusive with ``every_n_train_steps``.
+            If both are set, use extreme caution if also setting ``monitor``,
+            as the ``monitor`` value must be available in both training and validation.
+            This can have unintended consequences with tracking the top k models.
         period: Interval (number of epochs) between checkpoints.
 
         .. warning::
@@ -277,7 +279,7 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerState
         return (
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
-            or trainer.state != TrainerState.FITTING # don't save anything during non-fit
+            or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self.save_top_k == 0  # no models are saved
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step

From fe1600e284a83ee4ef37f5339f6a475c5c0760f7 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 8 Mar 2021 23:54:08 -0800
Subject: [PATCH 48/49] Update model_checkpoint.py

---
 pytorch_lightning/callbacks/model_checkpoint.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index ff93d3ca393655..fbef267f933829 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -281,7 +281,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
             trainer.fast_dev_run  # disable checkpointing with fast_dev_run
             or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
-            or self.save_top_k == 0  # no models are saved
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
         )

From 81385f2dc8d02ec5ef2c15b0aa96e5bca9bc8326 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 9 Mar 2021 14:26:34 +0100
Subject: [PATCH 49/49] Update CHANGELOG.md

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28ec368a4d831d..49db8ee4360c70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ The
format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
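
Taken together, patches 43-48 leave `ModelCheckpoint` saving on train batch end every
`every_n_train_steps` global steps and on validation end every `every_n_val_epochs` epochs,
with `period` kept only as a deprecated alias. A minimal usage sketch of that end state
(the directory, filename, and `Trainer` settings below are illustrative placeholders
borrowed from the tests, not part of the patches):

    # Sketch only: exercises the constructor arguments introduced in this series.
    # every_n_train_steps=16 saves on train batch end every 16 global steps;
    # every_n_val_epochs=0 disables the validation-end checkpoints.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(
        dirpath="checkpoints",   # placeholder directory
        filename="{step}",
        save_top_k=-1,
        every_n_train_steps=16,
        every_n_val_epochs=0,
    )
    trainer = Trainer(max_epochs=2, callbacks=[checkpoint_callback], logger=False)
    # trainer.fit(model)  # `model` is any LightningModule, e.g. LogInTwoMethods in the tests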
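The interaction of the two intervals is easiest to see as plain arithmetic. Assuming both
intervals are positive (the callback short-circuits when either is < 1) and a fixed number
of batches per epoch, the checkpointed steps are the union of two divisibility conditions.
A standalone sketch, using the numbers from test_ckpt_every_n_train_steps_and_every_n_val_epochs:

    # A step checkpoint lands when (step + 1) % every_n_train_steps == 0; a validation-end
    # checkpoint lands on the last step of every qualifying epoch.
    every_n_train_steps = 30
    every_n_val_epochs = 1
    epoch_step_length = 64
    max_epochs = 3

    expected_steps = [
        step for step in range(epoch_step_length * max_epochs)
        if (step + 1) % every_n_train_steps == 0
        or (step + 1) % (every_n_val_epochs * epoch_step_length) == 0
    ]
    print([f"step={s}.ckpt" for s in expected_steps])
    # -> steps 29, 59, 63, 89, 119, 127, 149, 179, 191 for these values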
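Patch 38 pins down the backwards-compatibility rule for `period`: the check is
`period is not None` rather than `period or ...`, so an explicit `period=0` still takes
precedence over `every_n_val_epochs`. A minimal stand-in for that resolution logic — the
helper name is hypothetical, and plain `warnings.warn` stands in for the library's
`rank_zero_warn`:

    import warnings
    from typing import Optional

    def resolve_every_n_val_epochs(every_n_val_epochs: int = 1, period: Optional[int] = None) -> int:
        # `period is not None` (not truthiness) so that period=0 still wins,
        # matching the fix in patch 38.
        if period is not None:
            warnings.warn(
                "Argument `period` is deprecated in v1.3 and will be removed in v1.5."
                " Please use `every_n_val_epochs` instead.",
                DeprecationWarning,
            )
            return period
        return every_n_val_epochs

    assert resolve_every_n_val_epochs(every_n_val_epochs=4, period=0) == 0  # period wins
    assert resolve_every_n_val_epochs(every_n_val_epochs=4) == 4            # new arg used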
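The validation-end decision itself reduces to a single predicate. Extracted here as a free
function for illustration only — in the callback it is inlined alongside the
`_should_skip_saving_checkpoint` guard rather than existing as a separate method:

    def should_save_on_validation_end(current_epoch: int, every_n_val_epochs: int) -> bool:
        # Disabled when the interval is < 1; otherwise save on every Nth epoch,
        # counting epochs from 1 (hence the `+ 1`).
        return every_n_val_epochs >= 1 and (current_epoch + 1) % every_n_val_epochs == 0

    assert [e for e in range(6) if should_save_on_validation_end(e, 2)] == [1, 3, 5]
    assert not any(should_save_on_validation_end(e, 0) for e in range(6))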