diff --git a/CHANGELOG.md b/CHANGELOG.md index d1c347c00a3f1f..b64e4572c75e4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -46,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f457e9de7d0fa8..fbef267f933829 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -93,8 +93,21 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). + every_n_train_steps: Number of training steps between checkpoints. + To disable, set ``every_n_train_steps = 0``. This value must be non-negative. + every_n_val_epochs: Number of validation epochs between checkpoints. + To disable, set ``every_n_val_epochs = 0``. This value must be non-negative. + This is not mutually exclusive with ``every_n_train_steps``. 
+ If both are set, use extreme caution when also setting ``monitor`` + as the ``monitor`` value must be available in both training and validation. + This can have unintended consequences with tracking the top k models. period: Interval (number of epochs) between checkpoints. + .. warning:: + This argument has been deprecated in v1.3 and will be removed in v1.5. + + Use ``every_n_val_epochs`` instead. + Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -155,7 +168,9 @@ def __init__( save_top_k: Optional[int] = None, save_weights_only: bool = False, mode: str = "min", - period: int = 1, + every_n_train_steps: int = 0, + every_n_val_epochs: int = 1, + period: Optional[int] = None, ): super().__init__() self.monitor = monitor @@ -163,7 +178,9 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period + self.every_n_val_epochs = period if period is not None else every_n_val_epochs + self.period = self.every_n_val_epochs + self.every_n_train_steps = every_n_train_steps self._last_global_step_saved = -1 self.current_score = None self.best_k_models = {} @@ -174,6 +191,12 @@ def __init__( self.save_function = None self.warned_result_obj = False + if period is not None: + rank_zero_warn( + 'Argument `period` is deprecated in v1.3 and will be removed in v1.5.' 
+ ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__validate_init_configuration() @@ -185,11 +208,27 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint + def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None: + """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0) + if skip_batch: + return + self.save_checkpoint(trainer, pl_module) + def on_validation_end(self, trainer, pl_module): """ checkpoints can be saved at the end of the val loop """ - self.save_checkpoint(trainer) + skip = ( + self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1 + or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0 + ) + if skip: + return + self.save_checkpoint(trainer, pl_module) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return { @@ -216,20 +255,8 @@ def save_checkpoint(self, trainer, unused: Optional = None): " has been removed. 
Support for the old signature will be removed in v1.5", DeprecationWarning ) - epoch = trainer.current_epoch global_step = trainer.global_step - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit - or trainer.sanity_checking # don't save anything during sanity check - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or self._last_global_step_saved == global_step # already saved at the last step - ): - return - self._add_backward_monitor_support(trainer) self._validate_monitor_key(trainer) @@ -248,9 +275,26 @@ def save_checkpoint(self, trainer, unused: Optional = None): # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) + def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or self._last_global_step_saved == trainer.global_step # already saved at the last step + ) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self.every_n_train_steps < 0: + raise MisconfigurationException( + f'Invalid value for every_n_train_batches={self.every_n_train_steps}. Must be >= 0' + ) + if self.every_n_val_epochs < 0: + raise MisconfigurationException( + f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. 
Must be >= 0' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -554,9 +598,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A self._save_model(trainer, filepath) if ( - self.save_top_k is None - and self.best_model_path - and self.best_model_path != filepath + self.save_top_k is None and self.best_model_path and self.best_model_path != filepath and trainer.is_global_zero ): self._del_model(self.best_model_path) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 845b05aed9b380..b56e5188172315 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -515,6 +515,26 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_last=False) +def test_invalid_every_n_val_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. 
""" + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None. """ seed_everything() @@ -558,9 +578,8 @@ def test_model_checkpoint_period(tmpdir, period: int): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -570,6 +589,113 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): + """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. 
""" + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +def test_ckpt_every_n_train_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. """ + + model = LogInTwoMethods() + every_n_train_steps = 16 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + progress_bar_refresh_rate=0, + callbacks=[checkpoint_callback], + logger=False, + ) + + trainer.fit(model) + expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", [1, 3]) +def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs): + """ Tests that checkpoints are taken every 30 steps and every epochs """ + model = LogInTwoMethods() + every_n_train_steps = 30 + checkpoint_callback = ModelCheckpoint( + every_n_val_epochs=every_n_val_epochs, + every_n_train_steps=every_n_train_steps, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + filename="{step}", + ) + max_epochs = 3 + epoch_step_length = 64 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=max_epochs, + callbacks=[checkpoint_callback], + logger=False, + ) + trainer.fit(model) + 
expected_steps_for_ckpt = [ + i for i in range(epoch_step_length * max_epochs) + if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0 + ] + expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] + assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. """ model = LogInTwoMethods() diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7d8c7d2adeea10..e65ebbab254de1 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -104,3 +104,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) + + +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1)