From 727591c4132f68d37bceee1ae1ec1da0a987e169 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 19:44:52 -0800 Subject: [PATCH 01/54] Update model_checkpoint.py --- .../callbacks/model_checkpoint.py | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f05a10a41996b..3be2fc4d5ca60 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -38,6 +38,8 @@ log = logging.getLogger(__name__) warning_cache = WarningCache() +log = logging.getLogger(__name__) + class ModelCheckpoint(Callback): r""" @@ -167,6 +169,8 @@ def __init__( mode: str = "min", period: int = 1, auto_insert_metric_name: bool = True + every_n_epochs: int = 1, + every_n_batches: int = -1, ): super().__init__() self.monitor = monitor @@ -174,8 +178,10 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period self.auto_insert_metric_name = auto_insert_metric_name + self.period = every_n_epochs + self.every_n_epochs = every_n_epochs + self.every_n_batches = every_n_batches self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -186,6 +192,13 @@ def __init__( self.save_function = None self.warned_result_obj = False + if period is not None: + rank_zero_warn( + 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_epochs` instead.', DeprecationWarning + ) + self.every_n_epochs = period + self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__validate_init_configuration() @@ -197,11 +210,24 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + if skip_step: + return + self.save_checkpoint(trainer, pl_module) + def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - self.save_checkpoint(trainer) + if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 0: + return + epoch = trainer.current_epoch + if (epoch + 1) % self.every_n_epochs == 0: + self.save_checkpoint(trainer, pl_module) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { @@ -216,7 +242,15 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused: Optional = None): + def _should_skip_saving_checkpoint(self, trainer) -> bool: + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.running_sanity_check # don't save anything during sanity check + or self.save_top_k == 0 # no models are saved + or self._last_global_step_saved == global_step # already saved at the last step + ) + + def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. 
This method runs on all ranks, it is the responsibility of `self.save_function` @@ -231,17 +265,6 @@ def save_checkpoint(self, trainer, unused: Optional = None): epoch = trainer.current_epoch global_step = trainer.global_step - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit - or trainer.sanity_checking # don't save anything during sanity check - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or self._last_global_step_saved == global_step # already saved at the last step - ): - return - self._add_backward_monitor_support(trainer) self._validate_monitor_key(trainer) @@ -263,6 +286,10 @@ def save_checkpoint(self, trainer, unused: Optional = None): def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self.every_n_epochs == 0 or self.every_n_epochs < -1: + raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + if self.every_n_batches == 0 or self.every_n_batches < -1: + raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1') if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): From eeeffd854d880a7b55f3a032d3f20c72a8971ba6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:11:05 -0800 Subject: [PATCH 02/54] add tests --- CHANGELOG.md | 1 + .../callbacks/model_checkpoint.py | 42 ++++++++--- tests/checkpointing/test_model_checkpoint.py | 71 +++++++++++++++++++ 3 files changed, 103 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f721b263668f..047de38707c76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3be2fc4d5ca60..f92cd42d1ee09 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -95,7 +95,20 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). + every_n_epochs: Interval (number of epochs) between checkpoints. + every_n_batches: Interval (number of batches) between checkpoints. period: Interval (number of epochs) between checkpoints. +<<<<<<< HEAD +======= + + .. warning:: + This argument has been deprecated in v1.3 and will be removed in v1.5. + Use ``every_n_epochs`` instead. + prefix: A string to put at the beginning of checkpoint filename. + + .. 
warning:: + This argument has been deprecated in v1.1 and will be removed in v1.3 +>>>>>>> add tests Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -167,10 +180,10 @@ def __init__( save_top_k: Optional[int] = None, save_weights_only: bool = False, mode: str = "min", - period: int = 1, auto_insert_metric_name: bool = True every_n_epochs: int = 1, every_n_batches: int = -1, + period: Optional[int] = None, ): super().__init__() self.monitor = monitor @@ -211,6 +224,9 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.save_function = trainer.save_checkpoint def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: + """ + Save a checkpoint during the training loop if configured to do so. + """ if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step @@ -242,14 +258,6 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def _should_skip_saving_checkpoint(self, trainer) -> bool: - return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved - or self._last_global_step_saved == global_step # already saved at the last step - ) - def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. @@ -283,13 +291,25 @@ def save_checkpoint(self, trainer, pl_module): # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) + def _should_skip_saving_checkpoint(self, trainer) -> bool: + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.running_sanity_check # don't save anything during sanity check + or self.save_top_k == 0 # no models are saved + or self._last_global_step_saved == global_step # already saved at the last step + ) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') if self.every_n_epochs == 0 or self.every_n_epochs < -1: - raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + ) if self.every_n_batches == 0 or self.every_n_batches < -1: - raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4a8088070f041..34a184f34988c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -514,6 +514,36 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) +def test_invalid_every_n_epoch(tmpdir): + """ Test that an exception is raised for every_n_epochs = 0 or < -1. 
""" + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + +def test_invalid_every_n_batches(tmpdir): + """ Test that an exception is raised for every_n_batches = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. """ @@ -579,6 +609,47 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=0.1, + limit_val_batches=0.1, + val_check_interval=1.0, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + +@pytest.mark.parametrize("every_n_epochs", list(range(4))) +def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=0.1, + limit_val_batches=0.1, + val_check_interval=1.0, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. 
""" model = LogInTwoMethods() From be36e8623d2ee2114917866056c2a0b4e3a3b9c7 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:12:24 -0800 Subject: [PATCH 03/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f92cd42d1ee09..6739cce3b3dff 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -296,7 +296,7 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.running_sanity_check # don't save anything during sanity check or self.save_top_k == 0 # no models are saved - or self._last_global_step_saved == global_step # already saved at the last step + or self._last_global_step_saved == trainer.global_step # already saved at the last step ) def __validate_init_configuration(self): From 218737f912463d1f61597681c36869566a481736 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:16:58 -0800 Subject: [PATCH 04/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 67 ++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 34a184f34988c..ee1456c990865 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -649,6 +649,73 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) +def test_ckpt_every_n_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. """ + + model = LogInTwoMethods() + + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=2, + max_epochs=2, + progress_bar_refresh_rate=0, + checkpoint_callback=ModelCheckpoint( + filename="{step}", + every_n_epochs=-1, + every_n_steps=16, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), + logger=False, + ) + + trainer.fit(model) + self.assertCountEqual( + os.listdir(tmpdir), + [ + "step=15.ckpt", + "step=31.ckpt", + "step=47.ckpt", + "step=63.ckpt", + "step=79.ckpt", + "step=95.ckpt", + "step=111.ckpt", + "step=127.ckpt", + ], + ) + +def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): + """ Tests that checkpoints are taken every 30 steps and every epochs """ + model = LogInTwoMethods() + trainer = Trainer( + default_root_dir=tmpdir, + min_epochs=2, + max_epochs=2, + progress_bar_refresh_rate=0, + checkpoint_callback=ModelCheckpoint( + every_n_epochs=1, + every_n_steps=30, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), + logger=False, + ) + trainer.fit(model) + self.assertCountEqual( + os.listdir(tmpdir), + [ + "epoch=0-step=29.ckpt", + "epoch=0-step=59.ckpt", + "epoch=0-step=63.ckpt", + "epoch=1-step=89.ckpt", + "epoch=1-step=119.ckpt", + "epoch=1-step=127.ckpt", + "epoch=1-step=127-v0.ckpt", + ], + ) + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. 
""" From f89ea0354f2c1d5443b7390b56da3cad157a364e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:24:51 -0800 Subject: [PATCH 05/54] fix tests --- .../callbacks/model_checkpoint.py | 7 -- tests/checkpointing/test_model_checkpoint.py | 65 +++++++++++-------- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6739cce3b3dff..77d74bdd57171 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -98,17 +98,10 @@ class ModelCheckpoint(Callback): every_n_epochs: Interval (number of epochs) between checkpoints. every_n_batches: Interval (number of batches) between checkpoints. period: Interval (number of epochs) between checkpoints. -<<<<<<< HEAD -======= .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. Use ``every_n_epochs`` instead. - prefix: A string to put at the beginning of checkpoint filename. - - .. warning:: - This argument has been deprecated in v1.1 and will be removed in v1.3 ->>>>>>> add tests Note: For extra customization, ModelCheckpoint includes the following attributes: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ee1456c990865..cfa75bc09e323 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -514,6 +514,7 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ with pytest.raises( @@ -529,6 +530,7 @@ def test_invalid_every_n_epoch(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. 
""" with pytest.raises( @@ -613,7 +615,12 @@ def test_model_checkpoint_period(tmpdir, period: int): def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_epochs=every_n_epochs + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -629,11 +636,18 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) + @pytest.mark.parametrize("every_n_epochs", list(range(4))) def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_epochs=every_n_epochs, + period=None + ) trainer = Trainer( default_root_dir=tmpdir, callbacks=[checkpoint_callback], @@ -649,6 +663,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) + def test_ckpt_every_n_steps(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ @@ -671,19 +686,17 @@ def test_ckpt_every_n_steps(tmpdir): ) trainer.fit(model) - self.assertCountEqual( - os.listdir(tmpdir), - [ - "step=15.ckpt", - "step=31.ckpt", - "step=47.ckpt", - "step=63.ckpt", - "step=79.ckpt", - "step=95.ckpt", - "step=111.ckpt", - "step=127.ckpt", - ], - ) + expected = [ + "step=15.ckpt", + "step=31.ckpt", + "step=47.ckpt", + "step=63.ckpt", + "step=79.ckpt", + "step=95.ckpt", + "step=111.ckpt", + "step=127.ckpt", + ] + assert set(os.listdir(tmpdir)) == set(expected) def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ @@ -703,18 +716,16 @@ def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): logger=False, ) trainer.fit(model) - self.assertCountEqual( - os.listdir(tmpdir), - [ - "epoch=0-step=29.ckpt", - "epoch=0-step=59.ckpt", - "epoch=0-step=63.ckpt", - "epoch=1-step=89.ckpt", - "epoch=1-step=119.ckpt", - "epoch=1-step=127.ckpt", - "epoch=1-step=127-v0.ckpt", - ], - ) + expected = [ + "epoch=0-step=29.ckpt", + "epoch=0-step=59.ckpt", + "epoch=0-step=63.ckpt", + "epoch=1-step=89.ckpt", + "epoch=1-step=119.ckpt", + "epoch=1-step=127.ckpt", + "epoch=1-step=127-v0.ckpt", + ] + assert set(os.listdir(tmpdir)) == set(expected) def test_model_checkpoint_topk_zero(tmpdir): From f857ffa891f90174b7c10eb5cd99d5bce0797ad9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 21:25:23 -0800 Subject: [PATCH 06/54] every_n_batches --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 77d74bdd57171..6e5ca2ab0d2e7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -223,7 +223,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) if skip_step: return self.save_checkpoint(trainer, pl_module) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index cfa75bc09e323..4c677dec5df9b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -664,7 +664,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_steps(tmpdir): +def test_ckpt_every_n_batches(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() @@ -677,7 +677,7 @@ def test_ckpt_every_n_steps(tmpdir): checkpoint_callback=ModelCheckpoint( filename="{step}", every_n_epochs=-1, - every_n_steps=16, + every_n_batches=16, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -698,7 +698,7 @@ def test_ckpt_every_n_steps(tmpdir): ] assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): +def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() trainer = Trainer( @@ -708,7 +708,7 @@ def test_ckpt_every_n_steps_and_every_n_epochs(tmpdir): progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( every_n_epochs=1, - every_n_steps=30, + every_n_batches=30, dirpath=tmpdir, save_top_k=-1, save_last=False, From 1763ea44cd5dfe35e327f0ee66a03dfd2b74dcd8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:18:40 -0800 Subject: [PATCH 07/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4c677dec5df9b..041dd13eb568e 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -520,15 +520,15 @@ def test_invalid_every_n_epoch(tmpdir): with pytest.raises( MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' ): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None) with pytest.raises( MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' ): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, period=None) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3, period=None) def test_invalid_every_n_batches(tmpdir): From e86305c9d7788f76a892271f464437d8c48a9c1b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:30:25 -0800 Subject: [PATCH 08/54] defaults --- .../callbacks/model_checkpoint.py | 1 + tests/checkpointing/test_model_checkpoint.py | 28 +++++-------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py 
b/pytorch_lightning/callbacks/model_checkpoint.py index 6e5ca2ab0d2e7..38a62e846ba86 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -204,6 +204,7 @@ def __init__( ' Please use `every_n_epochs` instead.', DeprecationWarning ) self.every_n_epochs = period + self.period = period self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 041dd13eb568e..31f251fb7896b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -517,13 +517,9 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None) # These should not fail @@ -533,13 +529,9 @@ def test_invalid_every_n_epoch(tmpdir): def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail @@ -616,10 +608,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - filename='{epoch}', - save_top_k=-1, - every_n_epochs=every_n_epochs + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -642,11 +631,7 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, - filename='{epoch}', - save_top_k=-1, - every_n_epochs=every_n_epochs, - period=None + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None ) trainer = Trainer( default_root_dir=tmpdir, @@ -698,6 +683,7 @@ def test_ckpt_every_n_batches(tmpdir): ] assert set(os.listdir(tmpdir)) == set(expected) + def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() From 45f16dd069e6b1c406d313cf7e031f2f8296a61a Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:38:04 -0800 Subject: [PATCH 09/54] rm tests --- .../callbacks/model_checkpoint.py | 8 ------- tests/checkpointing/test_model_checkpoint.py | 24 ------------------- 2 files changed, 32 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py 
b/pytorch_lightning/callbacks/model_checkpoint.py index 38a62e846ba86..789cf83f50db9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -296,14 +296,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs == 0 or self.every_n_epochs < -1: - raise MisconfigurationException( - f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' - ) - if self.every_n_batches == 0 or self.every_n_batches < -1: - raise MisconfigurationException( - f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' - ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 31f251fb7896b..1dcf93bf93a1d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -515,30 +515,6 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=0) -def test_invalid_every_n_epoch(tmpdir): - """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, period=None) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2, period=None) - - # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, period=None) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3, period=None) - - -def test_invalid_every_n_batches(tmpdir): - """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) - - # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) - ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) - - def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. 
""" with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'): From fd90771ade06309a7e365689ae1c3d4f48a24192 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 23 Feb 2021 09:38:39 -0800 Subject: [PATCH 10/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 789cf83f50db9..6ba78e3c4ee63 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -233,7 +233,7 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 0: + if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1: return epoch = trainer.current_epoch if (epoch + 1) % self.every_n_epochs == 0: From be2ae2e0855de4b3cebcb97fa04e6cc890c772cf Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 23 Feb 2021 23:33:31 -0800 Subject: [PATCH 11/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 1dcf93bf93a1d..06d7238cfa779 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -685,7 +685,6 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): "epoch=1-step=89.ckpt", "epoch=1-step=119.ckpt", "epoch=1-step=127.ckpt", - "epoch=1-step=127-v0.ckpt", ] assert set(os.listdir(tmpdir)) == set(expected) From 572fc9df78f06cef163e29b6c4dd65bd908902ae Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 12:09:01 +0100 Subject: [PATCH 12/54] Prune deprecated metrics for 1.3 (#6161) * prune deprecated metrics for 1.3 * isort / yapf --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 047de38707c76..777a9e057e5fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) +- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) From 59d0ce80c2d8d1aa0201a1559ecf73caa58335ab Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 19:44:52 -0800 Subject: [PATCH 13/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6ba78e3c4ee63..0716776d37120 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -218,13 +218,10 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.save_function = trainer.save_checkpoint def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: - """ - Save a checkpoint during the training loop if configured to do so. - """ if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_step: return self.save_checkpoint(trainer, pl_module) @@ -252,6 +249,14 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] + def _should_skip_saving_checkpoint(self, trainer) -> bool: + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.running_sanity_check # don't save anything during sanity check + or self.save_top_k == 0 # no models are saved + or self._last_global_step_saved == global_step # already saved at the last step + ) + def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. @@ -296,6 +301,10 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self.every_n_epochs == 0 or self.every_n_epochs < -1: + raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + if self.every_n_batches == 0 or self.every_n_batches < -1: + raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. 
Must be positive or -1') if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): From c26dd03e6fd0904ae18311554ac743ec01929f6e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 20:11:05 -0800 Subject: [PATCH 14/54] add tests --- CHANGELOG.md | 1 + tests/checkpointing/test_model_checkpoint.py | 30 ++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 777a9e057e5fc..15c39ebb0446f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 06d7238cfa779..bee12b655e9ad 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -514,6 +514,36 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) +def test_invalid_every_n_epoch(tmpdir): + """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + +def test_invalid_every_n_batches(tmpdir): + """ Test that an exception is raised for every_n_batches = 0 or < -1. """ + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=0*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises( + MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' + ): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) + + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. 
""" From f7c510051b8361af305da44cd4d257b6d627cc06 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 22 Feb 2021 22:30:25 -0800 Subject: [PATCH 15/54] defaults --- tests/checkpointing/test_model_checkpoint.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bee12b655e9ad..0cfdf53dacfc5 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -531,13 +531,9 @@ def test_invalid_every_n_epoch(tmpdir): def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_batches=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail From 81a1434e278f480cade04ba3b6dafe5e7f1a0982 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 24 Feb 2021 09:30:28 -0800 Subject: [PATCH 16/54] Update CHANGELOG.md --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15c39ebb0446f..047de38707c76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) -- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) -- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) - - Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) From 6423e1d20a067a990782057b11262f1bdbbbc1d8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 24 Feb 2021 09:50:30 -0800 Subject: [PATCH 17/54] pre-commit --- pytorch_lightning/callbacks/model_checkpoint.py | 12 ++++++++---- tests/checkpointing/test_model_checkpoint.py | 10 ++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0716776d37120..1949f565d6ca1 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -251,9 +251,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): def _should_skip_saving_checkpoint(self, trainer) -> bool: return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run + trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved + or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == global_step # already saved at the last 
step ) @@ -302,9 +302,13 @@ def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') if self.every_n_epochs == 0 or self.every_n_epochs < -1: - raise MisconfigurationException(f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + ) if self.every_n_batches == 0 or self.every_n_batches < -1: - raise MisconfigurationException(f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1') + raise MisconfigurationException( + f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0cfdf53dacfc5..cd24cead79f88 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -514,21 +514,19 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=-1) ModelCheckpoint(dirpath=tmpdir, save_top_k=0) + def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=0*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) - with pytest.raises( - MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*' - ): + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) # These should not fail ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. """ with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): From a7a469b44748f5e7cbf05f0112c33a86164a19a9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:41:06 -0800 Subject: [PATCH 18/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1949f565d6ca1..10e3c4059bb49 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -249,14 +249,6 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def _should_skip_saving_checkpoint(self, trainer) -> bool: - return ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.running_sanity_check # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved - or self._last_global_step_saved == global_step # already saved at the last step - ) - def save_checkpoint(self, trainer, pl_module): """ Performs the main logic around saving a checkpoint. 
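
[Editorial note: the following is commentary, not part of the patch series.]
Patches 13 through 18 above, and several that follow, iterate on the
interval-skip logic in `on_train_batch_end` and `on_validation_end`. The
behaviour the series converges on reduces to two independent modulo checks.
Below is a minimal standalone sketch of that logic, assuming "save on every
Nth" counting from 1, with values below 1 meaning "disabled" (the validation
updated in the next patch requires both intervals to be non-negative). The
helper names here are hypothetical, introduced only for illustration:

    def should_save_on_batch(global_step: int, every_n_batches: int) -> bool:
        # every_n_batches < 1 disables batch-interval checkpointing
        return every_n_batches >= 1 and (global_step + 1) % every_n_batches == 0

    def should_save_on_epoch(current_epoch: int, every_n_epochs: int) -> bool:
        # every_n_epochs < 1 disables epoch-interval checkpointing
        return every_n_epochs >= 1 and (current_epoch + 1) % every_n_epochs == 0

With every_n_batches=16 this fires at global steps 15, 31, 47, ..., which is
exactly the set of file names asserted in test_ckpt_every_n_batches above.
For completeness, a hypothetical end-user configuration under these
semantics (argument names as they stand at this point in the series):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Save at the end of every epoch and additionally every 100 training
    # batches, keeping every checkpoint produced (save_top_k=-1).
    ckpt = ModelCheckpoint(
        dirpath="checkpoints/",
        save_top_k=-1,
        every_n_epochs=1,
        every_n_batches=100,
    )
    trainer = Trainer(max_epochs=5, callbacks=[ckpt])
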
From 3cfc44a7d67fe38cf793a8e6fff1f700fbc9361b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:44:52 -0800 Subject: [PATCH 19/54] update defaults --- pytorch_lightning/callbacks/model_checkpoint.py | 10 +++++----- tests/checkpointing/test_model_checkpoint.py | 14 +++++++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 10e3c4059bb49..9ddfa56b9551d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -175,7 +175,7 @@ def __init__( mode: str = "min", auto_insert_metric_name: bool = True every_n_epochs: int = 1, - every_n_batches: int = -1, + every_n_batches: int = 0, period: Optional[int] = None, ): super().__init__() @@ -293,13 +293,13 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs == 0 or self.every_n_epochs < -1: + if self.every_n_epochs <= -1: raise MisconfigurationException( - f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be positive or -1' + f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' ) - if self.every_n_batches == 0 or self.every_n_batches < -1: + if self.every_n_batches <= -1: raise MisconfigurationException( - f'Invalid value for every_n_batches={self.every_n_batches}. Must be positive or -1' + f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index cd24cead79f88..6cd1c800d5853 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -517,25 +517,25 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0) + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-1*'): + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) def test_invalid_every_n_batches(tmpdir): """ Test that an exception is raised for every_n_batches = 0 or < -1. 
""" - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=0*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) + with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-1*'): + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) # These should not fail - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) @@ -661,7 +661,7 @@ def test_ckpt_every_n_batches(tmpdir): progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( filename="{step}", - every_n_epochs=-1, + every_n_epochs=0, every_n_batches=16, dirpath=tmpdir, save_top_k=-1, From 70dc43869e5997aa9458873739ffd62d1a462843 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 21:45:57 -0800 Subject: [PATCH 20/54] Update test_remove_1-5.py --- tests/deprecated_api/test_remove_1-5.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7d8c7d2adeea1..f205f124f4e86 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -78,6 +78,7 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +<<<<<<< HEAD def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -104,3 +105,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) +======= +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1) +>>>>>>> Update test_remove_1-5.py From ddaa78394f42cd19387bb67726e4c427b1d328b9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 22:00:37 -0800 Subject: [PATCH 21/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 9ddfa56b9551d..6a433e0cd232b 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -101,6 +101,7 @@ class ModelCheckpoint(Callback): .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. + Use ``every_n_epochs`` instead. 
Note: From 2384874905f9b3376b4c9181c01b9471d2af08f1 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 22:13:45 -0800 Subject: [PATCH 22/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6a433e0cd232b..6916c22f61d0e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -222,8 +222,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_step = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) - if skip_step: + skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + if skip_batch: return self.save_checkpoint(trainer, pl_module) From 1f2d0f20fe1d852ae06845793f56f98ade377f8d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:52:58 -0800 Subject: [PATCH 23/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6916c22f61d0e..74a8d63ea2e77 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -231,11 +231,14 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1: - return epoch = trainer.current_epoch - if (epoch + 1) % self.every_n_epochs == 0: - self.save_checkpoint(trainer, pl_module) + skip = ( + self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 + or (epoch + 1) % self.every_n_epochs != 0 + ) + if skip: + return + self.save_checkpoint(trainer, pl_module) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { From a7cec2bbb0f45a09a71ff4a303987d248b50df20 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:53:51 -0800 Subject: [PATCH 24/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 74a8d63ea2e77..579061d990fa9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -231,10 +231,9 @@ def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - epoch = trainer.current_epoch skip = ( self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 - or (epoch + 1) % self.every_n_epochs != 0 + or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) if skip: return From 6e06b8a8124fd03dbf5f6c1c48fb347d608610a6 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 1 Mar 2021 23:56:59 -0800 Subject: [PATCH 25/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 579061d990fa9..4ab44c2fd8d3f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -223,6 +223,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data return step = trainer.global_step skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + log.warning(f"in on_train_batch_end at step {step}, every_n_batches={self.every_n_batches}, going to skip batch? {skip_batch}") if skip_batch: return self.save_checkpoint(trainer, pl_module) From 28e36833b9bbca005c1f84fb558585ad4c8b6430 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 2 Mar 2021 00:25:53 -0800 Subject: [PATCH 26/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4ab44c2fd8d3f..c90b680b9a72a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -222,8 +222,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.global_step - skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.warning(f"in on_train_batch_end at step {step}, every_n_batches={self.every_n_batches}, going to skip batch? {skip_batch}") + skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_batch: return self.save_checkpoint(trainer, pl_module) From 9239325de51627bf3a1c94f0322516c43a0c0d2b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:56:50 -0800 Subject: [PATCH 27/54] fix tests --- CHANGELOG.md | 2 ++ .../callbacks/model_checkpoint.py | 8 ++---- tests/checkpointing/test_model_checkpoint.py | 28 +++++-------------- 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 047de38707c76..6c40046eea537 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index c90b680b9a72a..cc5ccd2415086 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -186,8 +186,8 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name - self.period = every_n_epochs - self.every_n_epochs = every_n_epochs + self.every_n_epochs = period or every_n_epochs + self.period = self.every_n_epochs self.every_n_batches = every_n_batches self._last_global_step_saved = -1 self.current_score = None @@ -204,8 +204,6 @@ def __init__( 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' 
' Please use `every_n_epochs` instead.', DeprecationWarning ) - self.every_n_epochs = period - self.period = period self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) @@ -221,7 +219,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): return - step = trainer.global_step + step = trainer.total_batch_idx skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) if skip_batch: return diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 6cd1c800d5853..ba00ba3212605 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -614,9 +614,8 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -637,9 +636,8 @@ def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -653,16 +651,15 @@ def test_ckpt_every_n_batches(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() - + every_n_batches = 16 trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( filename="{step}", every_n_epochs=0, - every_n_batches=16, + every_n_batches=every_n_batches, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -671,16 +668,7 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected = [ - "step=15.ckpt", - "step=31.ckpt", - "step=47.ckpt", - "step=63.ckpt", - "step=79.ckpt", - "step=95.ckpt", - "step=111.ckpt", - "step=127.ckpt", - ] + expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] assert set(os.listdir(tmpdir)) == set(expected) @@ -689,9 +677,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): model = LogInTwoMethods() trainer = Trainer( default_root_dir=tmpdir, - min_epochs=2, max_epochs=2, - progress_bar_refresh_rate=0, checkpoint_callback=ModelCheckpoint( every_n_epochs=1, every_n_batches=30, From b9152a1810522487b2507cf93fea74b8fbf3c4e8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:57:37 -0800 Subject: [PATCH 28/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ba00ba3212605..05eb9003dd11f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -668,7 +668,7 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected=[f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] + expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] assert set(os.listdir(tmpdir)) == set(expected) From dbbb446f2f4bd5996f4683d6100c2a8a3866af45 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 21:59:41 -0800 
Subject: [PATCH 29/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cc5ccd2415086..48605c77f0069 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -294,11 +294,11 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs <= -1: + if self.every_n_epochs < 0: raise MisconfigurationException( f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' ) - if self.every_n_batches <= -1: + if self.every_n_batches < 0: raise MisconfigurationException( f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' ) From 22e917b0236a8820641022800b82470adaf3c6b8 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 22:18:23 -0800 Subject: [PATCH 30/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 48605c77f0069..e34cb34ef326a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -220,7 +220,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data if self._should_skip_saving_checkpoint(trainer): return step = trainer.total_batch_idx - skip_batch = self.every_n_steps < 1 or ((step + 1) % self.every_n_steps != 0) + skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) if skip_batch: return self.save_checkpoint(trainer, pl_module) From bd45c53d65f4822fa4f27416585e30a7163e6eaf Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 23:44:15 -0800 Subject: [PATCH 31/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e34cb34ef326a..717a1c7ff4ec9 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -38,8 +38,6 @@ log = logging.getLogger(__name__) warning_cache = WarningCache() -log = logging.getLogger(__name__) - class ModelCheckpoint(Callback): r""" @@ -218,9 +216,11 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): + log.critical("in train batch end, not saving checkpoint after trainer check") return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + log.critical(f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? 
{skip_batch}") if skip_batch: return self.save_checkpoint(trainer, pl_module) @@ -294,14 +294,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') - if self.every_n_epochs < 0: - raise MisconfigurationException( - f'Invalid value for every_n_epochs={self.every_n_epochs}. Must be non-negative.' - ) - if self.every_n_batches < 0: - raise MisconfigurationException( - f'Invalid value for every_n_batches={self.every_n_batches}. Must be non-negative.' - ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): From 744a0783e11069da5982c73081a9dd04831b7ae2 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Thu, 4 Mar 2021 23:46:22 -0800 Subject: [PATCH 32/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 05eb9003dd11f..30af7ff29a2b8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -516,27 +516,17 @@ def test_none_monitor_top_k(tmpdir): def test_invalid_every_n_epoch(tmpdir): - """ Test that an exception is raised for every_n_epochs = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-1*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_epochs=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-2) - - # These should not fail + """ Test different configurations for every_n_epochs. """ ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, every_n_batches=1) def test_invalid_every_n_batches(tmpdir): - """ Test that an exception is raised for every_n_batches = 0 or < -1. """ - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-1*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1) - with pytest.raises(MisconfigurationException, match=r'Invalid value for every_n_batches=-2*'): - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-2) - - # These should not fail + """ Test different configurations for every_n_batches. 
""" ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) + ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1, every_n_epochs=2) def test_none_monitor_save_last(tmpdir): From 7b7ca5db0b0bf21acaa824dfc21c301b5038622d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:12:22 -0800 Subject: [PATCH 33/54] ckpt-callback --- .../callbacks/model_checkpoint.py | 4 ++- tests/checkpointing/test_model_checkpoint.py | 32 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 717a1c7ff4ec9..30f2bef20d6fc 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -220,7 +220,9 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.critical(f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}") + log.critical( + f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}" + ) if skip_batch: return self.save_checkpoint(trainer, pl_module) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 30af7ff29a2b8..f6c851430f5fe 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -642,18 +642,19 @@ def test_ckpt_every_n_batches(tmpdir): model = LogInTwoMethods() every_n_batches = 16 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_epochs=0, + every_n_batches=every_n_batches, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, progress_bar_refresh_rate=0, - checkpoint_callback=ModelCheckpoint( - filename="{step}", - every_n_epochs=0, - every_n_batches=every_n_batches, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - ), + callbacks=[checkpoint_callback], logger=False, ) @@ -665,16 +666,17 @@ def test_ckpt_every_n_batches(tmpdir): def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() + checkpoint_callback = ModelCheckpoint( + every_n_epochs=1, + every_n_batches=30, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ), trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, - checkpoint_callback=ModelCheckpoint( - every_n_epochs=1, - every_n_batches=30, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - ), + callbacks=[checkpoint_callback], logger=False, ) trainer.fit(model) From b32e500fdb7b565243105c2e5f9ca58b2718a955 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:31:16 -0800 Subject: [PATCH 34/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f6c851430f5fe..564af3c41f926 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -672,7 +672,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): dirpath=tmpdir, save_top_k=-1, save_last=False, - ), + ) trainer = Trainer( default_root_dir=tmpdir, max_epochs=2, From 
7ff63acff2c6b11645853ad212f7310eaa37278f Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:31:59 -0800 Subject: [PATCH 35/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 30f2bef20d6fc..aad1e88b99672 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -94,7 +94,7 @@ class ModelCheckpoint(Callback): saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). every_n_epochs: Interval (number of epochs) between checkpoints. - every_n_batches: Interval (number of batches) between checkpoints. + every_n_batches: Interval (number of training batches) between checkpoints. period: Interval (number of epochs) between checkpoints. .. warning:: From 77b70fdc2331afa4cbed3b5470742b51fec9c9a0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 00:49:52 -0800 Subject: [PATCH 36/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index aad1e88b99672..412fd9f7e1355 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -216,13 +216,9 @@ def on_pretrain_routine_start(self, trainer, pl_module): def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: if self._should_skip_saving_checkpoint(trainer): - log.critical("in train batch end, not saving checkpoint after trainer check") return step = trainer.total_batch_idx skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) - log.critical( - f"in train batch end, every_n_batches={self.every_n_batches}, step={step}, skip_batch? {skip_batch}" - ) if skip_batch: return self.save_checkpoint(trainer, pl_module) @@ -235,6 +231,9 @@ def on_validation_end(self, trainer, pl_module): self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) + log.critical( + f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, step={step}, skip? {skip}" + ) if skip: return self.save_checkpoint(trainer, pl_module) From 791f876705e2354b544feb82d31aef0f69f493e0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 01:06:07 -0800 Subject: [PATCH 37/54] validation-end --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 412fd9f7e1355..2b590217d8e07 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -232,7 +232,7 @@ def on_validation_end(self, trainer, pl_module): or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) log.critical( - f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, step={step}, skip? {skip}" + f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, epoch={trainer.current_epoch}, skip? 
{skip}" ) if skip: return diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 564af3c41f926..b1f94cc0d1858 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -581,9 +581,8 @@ def test_model_checkpoint_period(tmpdir, period: int): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) From 2e802c408aa9b969ebc4778113eea37fa7b9b4ea Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 01:24:16 -0800 Subject: [PATCH 38/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2b590217d8e07..44b84021c9794 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -184,7 +184,7 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name - self.every_n_epochs = period or every_n_epochs + self.every_n_epochs = period if period is not None else every_n_epochs self.period = self.every_n_epochs self.every_n_batches = every_n_batches self._last_global_step_saved = -1 @@ -231,9 +231,6 @@ def on_validation_end(self, trainer, pl_module): self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_epochs != 0 ) - log.critical( - f"in validation end, every_n_epochs={self.every_n_epochs}, period={self.period}, epoch={trainer.current_epoch}, skip? {skip}" - ) if skip: return self.save_checkpoint(trainer, pl_module) From fd9f6617362cbe4e3b5216ddee69d53811008b31 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 09:59:30 -0800 Subject: [PATCH 39/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 42 +++++++------------- 1 file changed, 15 insertions(+), 27 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index b1f94cc0d1858..9bcf7fac11f94 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -515,20 +515,6 @@ def test_none_monitor_top_k(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_top_k=0) -def test_invalid_every_n_epoch(tmpdir): - """ Test different configurations for every_n_epochs. """ - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=0, every_n_batches=1) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=3) - ModelCheckpoint(dirpath=tmpdir, every_n_epochs=-1, every_n_batches=1) - - -def test_invalid_every_n_batches(tmpdir): - """ Test different configurations for every_n_batches. """ - ModelCheckpoint(dirpath=tmpdir, every_n_batches=0) - ModelCheckpoint(dirpath=tmpdir, every_n_batches=3) - ModelCheckpoint(dirpath=tmpdir, every_n_batches=-1, every_n_epochs=2) - - def test_none_monitor_save_last(tmpdir): """ Test that a warning appears for save_last=True with monitor=None. 
""" with pytest.warns(UserWarning, match=r'ModelCheckpoint.*is a redundant.*'): @@ -615,11 +601,12 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): @pytest.mark.parametrize("every_n_epochs", list(range(4))) -def test_model_checkpoint_every_n_epochs_and_no_period(tmpdir, every_n_epochs): +def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): + """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """ model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs, period=None + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -662,32 +649,33 @@ def test_ckpt_every_n_batches(tmpdir): assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir): +@pytest.mark.parametrize("every_n_epochs", [1, 3]) +def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() checkpoint_callback = ModelCheckpoint( - every_n_epochs=1, + every_n_epochs=every_n_epochs, every_n_batches=30, dirpath=tmpdir, save_top_k=-1, save_last=False, + filename="{step}", ) + max_epochs = 3 + epoch_step_length = 64 trainer = Trainer( default_root_dir=tmpdir, - max_epochs=2, + max_epochs=max_epochs, callbacks=[checkpoint_callback], logger=False, ) trainer.fit(model) - expected = [ - "epoch=0-step=29.ckpt", - "epoch=0-step=59.ckpt", - "epoch=0-step=63.ckpt", - "epoch=1-step=89.ckpt", - "epoch=1-step=119.ckpt", - "epoch=1-step=127.ckpt", + expected_steps_for_ckpt = [ + i for i in range(epoch_step_length * max_epochs) + if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) ] - assert set(os.listdir(tmpdir)) == set(expected) + expected_ckpt_files = [f"step={step}.ckpt" for step in step] + assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) def test_model_checkpoint_topk_zero(tmpdir): From f56578629060ae2b167414d0ebea97c3d95a2ad4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:02:22 -0800 Subject: [PATCH 40/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 9bcf7fac11f94..f7ebe534ff8bf 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -653,9 +653,10 @@ def test_ckpt_every_n_batches(tmpdir): def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() + every_n_batches = 30 checkpoint_callback = ModelCheckpoint( every_n_epochs=every_n_epochs, - every_n_batches=30, + every_n_batches=every_n_batches, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -674,7 +675,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): i for i in range(epoch_step_length * max_epochs) if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) ] - expected_ckpt_files = [f"step={step}.ckpt" for step in step] + expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) From 
bceba8b316dc82ec46814a34243d17ebd5ae138d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:24:51 -0800 Subject: [PATCH 41/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f7ebe534ff8bf..12e3356af7228 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -673,7 +673,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if (i % every_n_batches) == 0 or (i * epoch_step_length % every_n_epochs == 0) + if ((i+1) % every_n_batches) == 0 or (i+1) % (every_n_epochs * epoch_step_length) == 0) ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) From 1e9244b007e7fcf8581437d6f66d0d11257838ea Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 5 Mar 2021 10:27:16 -0800 Subject: [PATCH 42/54] Update test_model_checkpoint.py --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 12e3356af7228..56a5a442cb190 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -673,7 +673,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if ((i+1) % every_n_batches) == 0 or (i+1) % (every_n_epochs * epoch_step_length) == 0) + if ((i + 1) % every_n_batches) == 0 or (i + 1) % (every_n_epochs * epoch_step_length) == 0 ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files)
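With the stray parenthesis removed in patch 42, the comprehension enumerates both trigger kinds in a single pass. A quick standalone check of what it yields for one parameterization (a sketch reusing the test's constants and assuming 64 training steps per epoch, as in the test):

    epoch_step_length, max_epochs = 64, 3
    every_n_batches, every_n_epochs = 30, 1

    expected_steps_for_ckpt = [
        i for i in range(epoch_step_length * max_epochs)
        if (i + 1) % every_n_batches == 0 or (i + 1) % (every_n_epochs * epoch_step_length) == 0
    ]
    # Batch triggers land at steps 29, 59, 89, 119, 149, 179;
    # epoch-end triggers land at steps 63, 127, 191.
    assert expected_steps_for_ckpt == [29, 59, 63, 89, 119, 127, 149, 179, 191]

From 47868d1e40aaafde184c5fc976db188238524062 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:00:58 -0800 Subject: [PATCH 43/54] clarify-names - Make names explicit as to which hooks they apply to - Use step instead of batch for consistency with global step --- CHANGELOG.md | 2 +- .../callbacks/model_checkpoint.py | 47 ++++++----- tests/checkpointing/test_model_checkpoint.py | 66 +++++++++++------ tests/deprecated_api/test_remove_1-5.py | 5 +- 4 files changed, 78 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c40046eea537..8dd87c5ca9e16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,7 +56,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).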
### Deprecated -- `period` has been deprecated in favor of `every_n_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 44b84021c9794..732a84944b03c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -93,14 +93,18 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). - every_n_epochs: Interval (number of epochs) between checkpoints. - every_n_batches: Interval (number of training batches) between checkpoints. + every_n_train_steps: Number of training steps between checkpoints. + To disable, set ``every_n_train_steps = 0``. This value must be non-negative. + every_n_val_epochs: Number of validation epochs between checkpoints. + To disable, set ``every_n_val_epochs = 0``. This value must be non-negative. + This is not mutually exclusive with ``every_n_val_epochs``. If both are set, pay extreme caution if also setting ``monitor`` + as the ``monitor`` value must be available in both training and validation. This can have unintended consequences with tracking the top K models. period: Interval (number of epochs) between checkpoints. .. warning:: This argument has been deprecated in v1.3 and will be removed in v1.5. - Use ``every_n_epochs`` instead. + Use ``every_n_val_epochs`` instead. Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -173,8 +177,8 @@ def __init__( save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True - every_n_epochs: int = 1, - every_n_batches: int = 0, + every_n_train_steps: int = 0, + every_n_val_epochs: int = 1, period: Optional[int] = None, ): super().__init__() @@ -184,9 +188,9 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name - self.every_n_epochs = period if period is not None else every_n_epochs - self.period = self.every_n_epochs - self.every_n_batches = every_n_batches + self.every_n_val_epochs = period if period is not None else every_n_val_epochs + self.period = self.every_n_val_epochs + self.every_n_train_steps = every_n_train_steps self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -200,7 +204,7 @@ def __init__( if period is not None: rank_zero_warn( 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' 
- ' Please use `every_n_epochs` instead.', DeprecationWarning + ' Please use `every_n_val_epochs` instead.', DeprecationWarning ) self.__init_monitor_mode(monitor, mode) @@ -214,11 +218,12 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None: + """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ if self._should_skip_saving_checkpoint(trainer): return - step = trainer.total_batch_idx - skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0) + step = trainer.global_step + skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0) if skip_batch: return self.save_checkpoint(trainer, pl_module) @@ -228,8 +233,8 @@ def on_validation_end(self, trainer, pl_module): checkpoints can be saved at the end of the val loop """ skip = ( - self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1 - or (trainer.current_epoch + 1) % self.every_n_epochs != 0 + self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1 + or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0 ) if skip: return @@ -248,7 +253,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, pl_module): + def save_checkpoint(self, trainer, unused): """ Performs the main logic around saving a checkpoint. This method runs on all ranks, it is the responsibility of `self.save_function` @@ -292,6 +297,14 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self.every_n_train_steps < 0: + raise MisconfigurationException( + f'Invalid value for every_n_train_batches={self.every_n_train_steps}. Must be >= 0' + ) + if self.every_n_val_epochs < 0: + raise MisconfigurationException( + f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -610,9 +623,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A self._save_model(trainer, filepath) if ( - self.save_top_k is None - and self.best_model_path - and self.best_model_path != filepath + self.save_top_k is None and self.best_model_path and self.best_model_path != filepath and trainer.is_global_zero ): self._del_model(self.best_model_path) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 56a5a442cb190..bd01bf96ff381 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -524,6 +524,26 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_last=False) +def test_invalid_every_n_val_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. 
""" + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None. """ seed_everything() @@ -578,12 +598,12 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_epochs", list(range(4))) -def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=every_n_epochs + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -596,17 +616,22 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_epochs", list(range(4))) -def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): - """ Tests that if period is set, it takes precedence over every_n_epochs for backwards compatibility. """ +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): + """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. 
""" model = LogInTwoMethods() epochs = 5 checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_epochs=(2 * every_n_epochs), period=every_n_epochs + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs ) trainer = Trainer( default_root_dir=tmpdir, @@ -619,19 +644,20 @@ def test_model_checkpoint_every_n_epochs_and_period(tmpdir, every_n_epochs): trainer.fit(model) # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else [] + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] assert set(os.listdir(tmpdir)) == set(expected) -def test_ckpt_every_n_batches(tmpdir): +def test_ckpt_every_n_train_steps(tmpdir): """ Tests that the checkpoints are saved every n training steps. """ model = LogInTwoMethods() - every_n_batches = 16 + every_n_train_steps = 16 checkpoint_callback = ModelCheckpoint( filename="{step}", - every_n_epochs=0, - every_n_batches=every_n_batches, + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -645,18 +671,18 @@ def test_ckpt_every_n_batches(tmpdir): ) trainer.fit(model) - expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_batches)] + expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)] assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_epochs", [1, 3]) -def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): +@pytest.mark.parametrize("every_n_val_epochs", [1, 3]) +def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs): """ Tests that checkpoints are taken every 30 steps and every epochs """ model = LogInTwoMethods() - every_n_batches = 30 + every_n_train_steps = 30 checkpoint_callback = ModelCheckpoint( - every_n_epochs=every_n_epochs, - every_n_batches=every_n_batches, + every_n_val_epochs=every_n_val_epochs, + every_n_train_steps=every_n_train_steps, dirpath=tmpdir, save_top_k=-1, save_last=False, @@ -673,7 +699,7 @@ def test_ckpt_every_n_batches_and_every_n_epochs(tmpdir, every_n_epochs): trainer.fit(model) expected_steps_for_ckpt = [ i for i in range(epoch_step_length * max_epochs) - if ((i + 1) % every_n_batches) == 0 or (i + 1) % (every_n_epochs * epoch_step_length) == 0 + if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0 ] expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f205f124f4e86..e65ebbab254de 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -78,7 +78,6 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) -<<<<<<< HEAD def test_v1_5_0_running_sanity_check(): trainer = Trainer() with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'): @@ -105,10 +104,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) -======= + + def test_v1_5_0_model_checkpoint_period(tmpdir): with no_warning_call(DeprecationWarning): ModelCheckpoint(dirpath=tmpdir) with 
pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): ModelCheckpoint(dirpath=tmpdir, period=1) ->>>>>>> Update test_remove_1-5.py From 4b964038d94a5b0cf5708644e479be7c6f4ae6da Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:03:58 -0800 Subject: [PATCH 44/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 732a84944b03c..3b39f0101cd7c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -253,7 +253,7 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]): self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - def save_checkpoint(self, trainer, unused): + def save_checkpoint(self, trainer, unused: Optional = None): """ Performs the main logic around saving a checkpoint. This method runs on all ranks, it is the responsibility of `self.save_function` @@ -289,7 +289,8 @@ def save_checkpoint(self, trainer, unused): def _should_skip_saving_checkpoint(self, trainer) -> bool: return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.running_sanity_check # don't save anything during sanity check + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == trainer.global_step # already saved at the last step ) From 989fafaa2602dbf0ed3dd18a17ed06ed40b69752 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:12:24 -0800 Subject: [PATCH 45/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3b39f0101cd7c..8f5c795d4bf5a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -287,6 +287,7 @@ def save_checkpoint(self, trainer, unused: Optional = None): self._save_last_checkpoint(trainer, monitor_candidates) def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.state != TrainerState.FITTING # don't save anything during non-fit From 99be7202a5dd127037cc745f678f1f1282cc2a6e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:15:26 -0800 Subject: [PATCH 46/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8f5c795d4bf5a..5e2d49b9d9432 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -265,7 +265,6 @@ def save_checkpoint(self, trainer, unused: Optional = None): " has been removed. 
Support for the old signature will be removed in v1.5", DeprecationWarning ) - epoch = trainer.current_epoch global_step = trainer.global_step self._add_backward_monitor_support(trainer) From 524ba6891cb871b6eab59cd4319b576db43fc872 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:31:47 -0800 Subject: [PATCH 47/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 5e2d49b9d9432..2252a5ee19768 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -97,8 +97,10 @@ class ModelCheckpoint(Callback): To disable, set ``every_n_train_steps = 0``. This value must be non-negative. every_n_val_epochs: Number of validation epochs between checkpoints. To disable, set ``every_n_val_epochs = 0``. This value must be non-negative. - This is not mutually exclusive with ``every_n_val_epochs``. If both are set, pay extreme caution if also setting ``monitor`` - as the ``monitor`` value must be available in both training and validation. This can have unintended consequences with tracking the top K models. + This is not mutually exclusive with ``every_n_val_epochs``. + If both are set, pay extreme caution if also setting ``monitor`` + as the ``monitor`` value must be available in both training and validation. + This can have unintended consequences with tracking the top k models. period: Interval (number of epochs) between checkpoints. .. warning:: @@ -289,7 +291,7 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerState return ( trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.state != TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == trainer.global_step # already saved at the last step From c2120ff2d342a200951010275c01de6f792b74d9 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 8 Mar 2021 23:54:08 -0800 Subject: [PATCH 48/54] Update model_checkpoint.py --- pytorch_lightning/callbacks/model_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2252a5ee19768..08260da1a669c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -293,7 +293,6 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool: trainer.fast_dev_run # disable checkpointing with fast_dev_run or trainer.state != TrainerState.FITTING # don't save anything during non-fit or trainer.sanity_checking # don't save anything during sanity check - or self.save_top_k == 0 # no models are saved or self._last_global_step_saved == trainer.global_step # already saved at the last step ) From 3ccf8d147badfcdb1e9ad6640bad714b76d660ba Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 10 Mar 2021 16:24:19 -0800 Subject: [PATCH 49/54] mutual-exclusive Make every_n_train_steps and every_n_val_epochs mutually exclusive --- .../callbacks/model_checkpoint.py | 75 ++++++++++++++----- tests/checkpointing/test_model_checkpoint.py | 54 ++++++------- 2 files changed, 77 insertions(+), 52 
deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 08260da1a669c..493129822b3f5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -94,13 +94,13 @@ class ModelCheckpoint(Callback): saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). every_n_train_steps: Number of training steps between checkpoints. - To disable, set ``every_n_train_steps = 0``. This value must be non-negative. + If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training + To disable, set ``every_n_train_steps = 0``. This value must be ``None`` non-negative. + This must be mutually exclusive with ``every_n_val_epochs``. every_n_val_epochs: Number of validation epochs between checkpoints. - To disable, set ``every_n_val_epochs = 0``. This value must be non-negative. - This is not mutually exclusive with ``every_n_val_epochs``. - If both are set, pay extreme caution if also setting ``monitor`` - as the ``monitor`` value must be available in both training and validation. - This can have unintended consequences with tracking the top k models. + If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end + To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_train_steps``. period: Interval (number of epochs) between checkpoints. .. warning:: @@ -179,8 +179,8 @@ def __init__( save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True - every_n_train_steps: int = 0, - every_n_val_epochs: int = 1, + every_n_train_steps: Optional[int] = None, + every_n_val_epochs: Optional[int] = None, period: Optional[int] = None, ): super().__init__() @@ -190,9 +190,6 @@ def __init__( self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name - self.every_n_val_epochs = period if period is not None else every_n_val_epochs - self.period = self.every_n_val_epochs - self.every_n_train_steps = every_n_train_steps self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -203,14 +200,9 @@ def __init__( self.save_function = None self.warned_result_obj = False - if period is not None: - rank_zero_warn( - 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' - ' Please use `every_n_val_epochs` instead.', DeprecationWarning - ) - self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) self.__validate_init_configuration() def on_pretrain_routine_start(self, trainer, pl_module): @@ -235,8 +227,8 @@ def on_validation_end(self, trainer, pl_module): checkpoints can be saved at the end of the val loop """ skip = ( - self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1 - or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0 + self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None + or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0 ) if skip: return @@ -301,12 +293,17 @@ def __validate_init_configuration(self): raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. 
Must be None or >= -1') if self.every_n_train_steps < 0: raise MisconfigurationException( - f'Invalid value for every_n_train_batches={self.every_n_train_steps}. Must be >= 0' + f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0' ) if self.every_n_val_epochs < 0: raise MisconfigurationException( f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0' ) + if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0: + raise MisconfigurationException( + f'Invalid values for every_n_train_steps={self.every_n_train_steps} and every_n_val_epochs={self.every_n_val_epochs}.' + 'Both cannot be enabled at the same time.' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -353,6 +350,44 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] + def __init_triggers( + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_val_epochs is set + self.every_n_val_epochs = every_n_val_epochs or 0 + self.every_n_train_steps = every_n_train_steps or 0 + if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0: + self.every_n_val_epochs = 1 + log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + + # period takes precedence for every_n_val_epochs for backwards compatibility + if period is not None: + rank_zero_warn( + 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self.every_n_val_epochs = period + + self._period = self.every_n_val_epochs + + @property + def period(self) -> Optional[int]: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + return self._period + + @period.setter + def period(self, value: Optional[int]) -> None: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self._period = value + @rank_zero_only def _del_model(self, filepath: str): if self._fs.exists(filepath): diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bd01bf96ff381..6a4c637ad485b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -544,6 +544,22 @@ def test_invalid_every_n_train_steps(tmpdir): ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) +def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir): + """ Make sure that a MisconfigurationException is raised if both every_n_val_epochs and every_n_train_steps are enabled together. 
""" + with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + + +def test_none_every_n_train_steps_val_epochs(tmpdir): + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) + assert checkpoint_callback.period == 1 + assert checkpoint_callback.every_n_val_epochs == 1 + assert checkpoint_callback.every_n_train_steps == 0 + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None. """ seed_everything() @@ -598,7 +614,7 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4))) def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): model = LogInTwoMethods() epochs = 5 @@ -654,6 +670,8 @@ def test_ckpt_every_n_train_steps(tmpdir): model = LogInTwoMethods() every_n_train_steps = 16 + max_epochs = 2 + epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", every_n_val_epochs=0, @@ -671,38 +689,10 @@ def test_ckpt_every_n_train_steps(tmpdir): ) trainer.fit(model) - expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)] - assert set(os.listdir(tmpdir)) == set(expected) - - -@pytest.mark.parametrize("every_n_val_epochs", [1, 3]) -def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs): - """ Tests that checkpoints are taken every 30 steps and every epochs """ - model = LogInTwoMethods() - every_n_train_steps = 30 - checkpoint_callback = ModelCheckpoint( - every_n_val_epochs=every_n_val_epochs, - every_n_train_steps=every_n_train_steps, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - filename="{step}", - ) - max_epochs = 3 - epoch_step_length = 64 - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=max_epochs, - callbacks=[checkpoint_callback], - logger=False, - ) - trainer.fit(model) - expected_steps_for_ckpt = [ - i for i in range(epoch_step_length * max_epochs) - if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0 + expected = [ + f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) ] - expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] - assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) + assert set(os.listdir(tmpdir)) == set(expected) def test_model_checkpoint_topk_zero(tmpdir): From 548e9382b89a2706b5c2529bc6140b08967886c7 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 10 Mar 2021 16:39:18 -0800 Subject: [PATCH 50/54] fix-default-0 --- pytorch_lightning/callbacks/model_checkpoint.py | 10 ++++++---- tests/checkpointing/test_model_checkpoint.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 493129822b3f5..211b20036e28a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -356,13 +356,15 @@ def __init_triggers( # Default to running once after each validation epoch if neither # every_n_train_steps nor 
every_n_val_epochs is set - self.every_n_val_epochs = every_n_val_epochs or 0 - self.every_n_train_steps = every_n_train_steps or 0 - if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0: + if every_n_train_steps is None and every_n_val_epochs is None: self.every_n_val_epochs = 1 + self.every_n_train_steps = 0 log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + else: + self.every_n_val_epochs = every_n_val_epochs or 0 + self.every_n_train_steps = every_n_train_steps or 0 - # period takes precedence for every_n_val_epochs for backwards compatibility + # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: rank_zero_warn( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 6a4c637ad485b..a2e7e89cab97b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -614,7 +614,7 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4))) +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): model = LogInTwoMethods() epochs = 5 From 376962b3b33e8a66b8bbafed1d3c822252242def Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 10 Mar 2021 16:41:19 -0800 Subject: [PATCH 51/54] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dd87c5ca9e16..9dcdea4c1601d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) -- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) +- Added support to checkpoint after training steps in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) From 1e7f64087d0bf20afd09673da73ad7ba839851c0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 10 Mar 2021 16:56:59 -0800 Subject: [PATCH 52/54] formatting --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- tests/checkpointing/test_model_checkpoint.py | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 211b20036e28a..731232612cd0c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -301,7 +301,8 @@ def __validate_init_configuration(self): ) if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0: raise MisconfigurationException( - f'Invalid values for every_n_train_steps={self.every_n_train_steps} and every_n_val_epochs={self.every_n_val_epochs}.' + f'Invalid values for every_n_train_steps={self.every_n_train_steps}' + ' and every_n_val_epochs={self.every_n_val_epochs}.' 'Both cannot be enabled at the same time.' 
             )
         if self.monitor is None:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index a2e7e89cab97b..e9d5e2daa85a7 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -545,7 +545,10 @@ def test_invalid_every_n_train_steps(tmpdir):
 
 
 def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
-    """ Make sure that a MisconfigurationException is raised if both every_n_val_epochs and every_n_train_steps are enabled together. """
+    """
+    Test that a MisconfigurationException is raised if both
+    every_n_val_epochs and every_n_train_steps are enabled together.
+    """
     with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
         ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
     # These should not fail

From ab4012d0eaa4d1bab8d0caa41626cee3e8d680b5 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 10 Mar 2021 22:47:06 -0800
Subject: [PATCH 53/54] make-private

make attributes private to the class

---
 .../callbacks/model_checkpoint.py            | 32 +++++++++----------
 tests/checkpointing/test_model_checkpoint.py |  4 +--
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 731232612cd0c..9c9244b9db317 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -217,7 +217,7 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
         if self._should_skip_saving_checkpoint(trainer):
             return
         step = trainer.global_step
-        skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
+        skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
         if skip_batch:
             return
         self.save_checkpoint(trainer, pl_module)
@@ -227,8 +227,8 @@ def on_validation_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the val loop
         """
         skip = (
-            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
-            or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
+            self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
+            or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
         )
         if skip:
             return
@@ -291,18 +291,18 @@ def _should_skip_saving_checkpoint(self, trainer) -> bool:
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
-        if self.every_n_train_steps < 0:
+        if self._every_n_train_steps < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
+                f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
             )
-        if self.every_n_val_epochs < 0:
+        if self._every_n_val_epochs < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
+                f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
             )
-        if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
+        if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
             raise MisconfigurationException(
-                f'Invalid values for every_n_train_steps={self.every_n_train_steps}'
-                f' and every_n_val_epochs={self.every_n_val_epochs}.'
+                f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
+                f' and every_n_val_epochs={self._every_n_val_epochs}.'
                 'Both cannot be enabled at the same time.'
             )
         if self.monitor is None:
@@ -358,12 +358,12 @@ def __init_triggers(
         # Default to running once after each validation epoch if neither
         # every_n_train_steps nor every_n_val_epochs is set
         if every_n_train_steps is None and every_n_val_epochs is None:
-            self.every_n_val_epochs = 1
-            self.every_n_train_steps = 0
+            self._every_n_val_epochs = 1
+            self._every_n_train_steps = 0
             log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
         else:
-            self.every_n_val_epochs = every_n_val_epochs or 0
-            self.every_n_train_steps = every_n_train_steps or 0
+            self._every_n_val_epochs = every_n_val_epochs or 0
+            self._every_n_train_steps = every_n_train_steps or 0
 
         # period takes precedence over every_n_val_epochs for backwards compatibility
         if period is not None:
             rank_zero_warn(
                 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
                 ' Please use `every_n_val_epochs` instead.', DeprecationWarning
             )
-            self.every_n_val_epochs = period
+            self._every_n_val_epochs = period
 
-        self._period = self.every_n_val_epochs
+        self._period = self._every_n_val_epochs
 
     @property
     def period(self) -> Optional[int]:
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index e9d5e2daa85a7..a9e86b0b2a223 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -559,8 +559,8 @@ def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
 def test_none_every_n_train_steps_val_epochs(tmpdir):
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
     assert checkpoint_callback.period == 1
-    assert checkpoint_callback.every_n_val_epochs == 1
-    assert checkpoint_callback.every_n_train_steps == 0
+    assert checkpoint_callback._every_n_val_epochs == 1
+    assert checkpoint_callback._every_n_train_steps == 0
 
 
 def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):

From 89c2df7c8f0cbe36883968a63911284f6276505c Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Thu, 11 Mar 2021 08:58:18 -0800
Subject: [PATCH 54/54] rebase

---
 .../callbacks/model_checkpoint.py            | 23 ++++++++++---------
 tests/checkpointing/test_model_checkpoint.py |  7 ++-----
 2 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 9c9244b9db317..bf6c799ef728a 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -101,6 +101,10 @@ class ModelCheckpoint(Callback):
         If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
         To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
         This must be mutually exclusive with ``every_n_train_steps``.
+        Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
+        ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
+        will only save checkpoints at epochs 0 < E <= N
+        where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
         period: Interval (number of epochs) between checkpoints.
 
             .. warning::
                This argument has been deprecated in v1.3 and will be removed in v1.5.
                Use ``every_n_val_epochs`` instead.
@@ -178,7 +182,7 @@ def __init__(
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
         mode: str = "min",
-        auto_insert_metric_name: bool = True
+        auto_insert_metric_name: bool = True,
         every_n_train_steps: Optional[int] = None,
         every_n_val_epochs: Optional[int] = None,
         period: Optional[int] = None,
@@ -212,7 +216,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
         self.__resolve_ckpt_dir(trainer)
         self.save_function = trainer.save_checkpoint
 
-    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
+    def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
         """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
         if self._should_skip_saving_checkpoint(trainer):
             return
@@ -220,9 +224,9 @@ def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
         skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
         if skip_batch:
             return
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)
 
-    def on_validation_end(self, trainer, pl_module):
+    def on_validation_end(self, trainer, *args, **kwargs) -> None:
         """
         checkpoints can be saved at the end of the val loop
         """
@@ -232,7 +236,7 @@
         )
         if skip:
             return
-        self.save_checkpoint(trainer, pl_module)
+        self.save_checkpoint(trainer)
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
@@ -303,7 +307,7 @@ def __validate_init_configuration(self):
                     f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
                     f' and every_n_val_epochs={self._every_n_val_epochs}.'
-                    'Both cannot be enabled at the same time.'
+                    ' Both cannot be enabled at the same time.'
                 )
         if self.monitor is None:
@@ -504,11 +508,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],
         """
         filename = self._format_checkpoint_name(
-            self.filename,
-            epoch,
-            step,
-            metrics,
-            auto_insert_metric_name=self.auto_insert_metric_name)
+            self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name
+        )
 
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index a9e86b0b2a223..e5583b9bbdf86 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -434,11 +434,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
 
     # auto_insert_metric_name=False
     ckpt_name = ModelCheckpoint._format_checkpoint_name(
-        'epoch={epoch:03d}-val_acc={val/acc}',
-        3,
-        2,
-        {'val/acc': 0.03},
-        auto_insert_metric_name=False)
+        'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False
+    )
 
     assert ckpt_name == 'epoch=003-val_acc=0.03'
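
A minimal usage sketch of the triggers defined by this series, assuming the v1.3 `ModelCheckpoint` API exactly as introduced in the patches above; the `dirpath` values are placeholder paths and no model/fit call is shown.

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Step-based trigger: save every 100 optimizer steps. The epoch-based trigger
# must be disabled (0), since the two triggers are mutually exclusive.
step_ckpt = ModelCheckpoint(
    dirpath=os.path.join(os.getcwd(), "checkpoints"),  # placeholder path
    filename="{step}",
    every_n_train_steps=100,
    every_n_val_epochs=0,
    save_top_k=-1,
)

# Epoch-based trigger: save after every second validation epoch.
epoch_ckpt = ModelCheckpoint(dirpath="checkpoints", every_n_val_epochs=2)

# Default: with neither argument set, __init_triggers falls back to
# every_n_val_epochs=1 and every_n_train_steps=0 (save once per validation epoch).
default_ckpt = ModelCheckpoint(dirpath="checkpoints")

trainer = Trainer(max_epochs=4, callbacks=[step_ckpt])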
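
Enabling both triggers at once is rejected by `__validate_init_configuration`; a quick sketch of the expected failure, mirroring `test_invalid_every_n_train_steps_val_epochs_combination` above:

import pytest

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Both triggers positive at the same time is a misconfiguration.
with pytest.raises(MisconfigurationException, match='Both cannot be enabled at the same time'):
    ModelCheckpoint(dirpath="checkpoints", every_n_train_steps=1, every_n_val_epochs=2)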
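
The `format_checkpoint_name` change in the final patch is a pure reflow; the behavior it exercises can be sketched from the assertion in `test_model_checkpoint_format_checkpoint_name`, assuming `_format_checkpoint_name` keeps the signature used there:

from pytorch_lightning.callbacks import ModelCheckpoint

# With auto_insert_metric_name=False the metric name is not re-inserted into
# the template, so the "{val/acc}" placeholder renders as just the value.
name = ModelCheckpoint._format_checkpoint_name(
    'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False
)
assert name == 'epoch=003-val_acc=0.03'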