diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f721b263668f..9dcdea4c1601d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) +- Added support for checkpointing after training steps in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) - Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072)) @@ -55,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146)) + - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f05a10a41996b..bf6c799ef728a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -93,8 +93,25 @@ class ModelCheckpoint(Callback): save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). + every_n_train_steps: Number of training steps between checkpoints. + If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. + To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_val_epochs``. + every_n_val_epochs: Number of validation epochs between checkpoints. + If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end. + To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative. + This must be mutually exclusive with ``every_n_train_steps``. + Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and + ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` + will only save checkpoints at epochs 0 < E <= N + where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E. period: Interval (number of epochs) between checkpoints. + .. warning:: + This argument has been deprecated in v1.3 and will be removed in v1.5. + + Use ``every_n_val_epochs`` instead.
+ Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -165,8 +182,10 @@ def __init__( save_top_k: Optional[int] = None, save_weights_only: bool = False, mode: str = "min", - period: int = 1, - auto_insert_metric_name: bool = True + auto_insert_metric_name: bool = True, + every_n_train_steps: Optional[int] = None, + every_n_val_epochs: Optional[int] = None, + period: Optional[int] = None, ): super().__init__() self.monitor = monitor @@ -174,7 +193,6 @@ def __init__( self.save_last = save_last self.save_top_k = save_top_k self.save_weights_only = save_weights_only - self.period = period self.auto_insert_metric_name = auto_insert_metric_name self._last_global_step_saved = -1 self.current_score = None @@ -188,6 +206,7 @@ def __init__( self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) + self.__init_triggers(every_n_train_steps, every_n_val_epochs, period) self.__validate_init_configuration() def on_pretrain_routine_start(self, trainer, pl_module): @@ -197,10 +216,26 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.__resolve_ckpt_dir(trainer) self.save_function = trainer.save_checkpoint - def on_validation_end(self, trainer, pl_module): + def on_train_batch_end(self, trainer, *args, **kwargs) -> None: + """ Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """ + if self._should_skip_saving_checkpoint(trainer): + return + step = trainer.global_step + skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0) + if skip_batch: + return + self.save_checkpoint(trainer) + + def on_validation_end(self, trainer, *args, **kwargs) -> None: """ checkpoints can be saved at the end of the val loop """ + skip = ( + self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1 + or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0 + ) + if skip: + return self.save_checkpoint(trainer) def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -228,20 +263,8 @@ def save_checkpoint(self, trainer, unused: Optional = None): " has been removed. 
Support for the old signature will be removed in v1.5", DeprecationWarning ) - epoch = trainer.current_epoch global_step = trainer.global_step - from pytorch_lightning.trainer.states import TrainerState - if ( - trainer.fast_dev_run # disable checkpointing with fast_dev_run - or trainer.state != TrainerState.FITTING # don't save anything during non-fit - or trainer.sanity_checking # don't save anything during sanity check - or self.period < 1 # no models are saved - or (epoch + 1) % self.period # skip epoch - or self._last_global_step_saved == global_step # already saved at the last step - ): - return - self._add_backward_monitor_support(trainer) self._validate_monitor_key(trainer) @@ -260,9 +283,32 @@ def save_checkpoint(self, trainer, unused: Optional = None): # Mode 3: save last checkpoints self._save_last_checkpoint(trainer, monitor_candidates) + def _should_skip_saving_checkpoint(self, trainer) -> bool: + from pytorch_lightning.trainer.states import TrainerState + return ( + trainer.fast_dev_run # disable checkpointing with fast_dev_run + or trainer.state != TrainerState.FITTING # don't save anything during non-fit + or trainer.sanity_checking # don't save anything during sanity check + or self._last_global_step_saved == trainer.global_step # already saved at the last step + ) + def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1') + if self._every_n_train_steps < 0: + raise MisconfigurationException( + f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0' + ) + if self._every_n_val_epochs < 0: + raise MisconfigurationException( + f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0' + ) + if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0: + raise MisconfigurationException( + f'Invalid values for every_n_train_steps={self._every_n_train_steps}' + f' and every_n_val_epochs={self._every_n_val_epochs}.' + ' Both cannot be enabled at the same time.' + ) if self.monitor is None: # None: save last epoch, -1: save all epochs, 0: nothing is saved if self.save_top_k not in (None, -1, 0): @@ -309,6 +355,46 @@ def __init_monitor_mode(self, monitor, mode): self.kth_value, self.mode = mode_dict[mode] + def __init_triggers( + self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int] + ) -> None: + + # Default to running once after each validation epoch if neither + # every_n_train_steps nor every_n_val_epochs is set + if every_n_train_steps is None and every_n_val_epochs is None: + self._every_n_val_epochs = 1 + self._every_n_train_steps = 0 + log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + else: + self._every_n_val_epochs = every_n_val_epochs or 0 + self._every_n_train_steps = every_n_train_steps or 0 + + # period takes precedence over every_n_val_epochs for backwards compatibility + if period is not None: + rank_zero_warn( + 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self._every_n_val_epochs = period + + self._period = self._every_n_val_epochs + + @property + def period(self) -> Optional[int]: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+ ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + return self._period + + @period.setter + def period(self, value: Optional[int]) -> None: + rank_zero_warn( + 'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' + ' Please use `every_n_val_epochs` instead.', DeprecationWarning + ) + self._period = value + @rank_zero_only def _del_model(self, filepath: str): if self._fs.exists(filepath): @@ -422,11 +508,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], """ filename = self._format_checkpoint_name( - self.filename, - epoch, - step, - metrics, - auto_insert_metric_name=self.auto_insert_metric_name) + self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name + ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -581,9 +664,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A self._save_model(trainer, filepath) if ( - self.save_top_k is None - and self.best_model_path - and self.best_model_path != filepath + self.save_top_k is None and self.best_model_path and self.best_model_path != filepath and trainer.is_global_zero ): self._del_model(self.best_model_path) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4a8088070f041..e5583b9bbdf86 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -434,11 +434,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): # auto_insert_metric_name=False ckpt_name = ModelCheckpoint._format_checkpoint_name( - 'epoch={epoch:03d}-val_acc={val/acc}', - 3, - 2, - {'val/acc': 0.03}, - auto_insert_metric_name=False) + 'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False + ) assert ckpt_name == 'epoch=003-val_acc=0.03' @@ -524,6 +521,45 @@ def test_none_monitor_save_last(tmpdir): ModelCheckpoint(dirpath=tmpdir, save_last=False) +def test_invalid_every_n_val_epochs(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1) + ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2) + + +def test_invalid_every_n_train_steps(tmpdir): + """ Make sure that a MisconfigurationException is raised for a negative every_n_train_steps argument. """ + with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2) + + +def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir): + """ + Test that a MisconfigurationException is raised if both + every_n_val_epochs and every_n_train_steps are enabled together.
+ """ + with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + + +def test_none_every_n_train_steps_val_epochs(tmpdir): + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) + assert checkpoint_callback.period == 1 + assert checkpoint_callback._every_n_val_epochs == 1 + assert checkpoint_callback._every_n_train_steps == 0 + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None. """ seed_everything() @@ -567,9 +603,8 @@ def test_model_checkpoint_period(tmpdir, period: int): default_root_dir=tmpdir, callbacks=[checkpoint_callback], max_epochs=epochs, - limit_train_batches=0.1, - limit_val_batches=0.1, - val_check_interval=1.0, + limit_train_batches=1, + limit_val_batches=1, logger=False, ) trainer.fit(model) @@ -579,6 +614,87 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs): + """ Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """ + model = LogInTwoMethods() + epochs = 5 + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename='{epoch}', + save_top_k=-1, + every_n_val_epochs=(2 * every_n_val_epochs), + period=every_n_val_epochs + ) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + max_epochs=epochs, + limit_train_batches=1, + limit_val_batches=1, + logger=False, + ) + trainer.fit(model) + + # check that the correct ckpts were created + expected = [f'epoch={e}.ckpt' for e in range(epochs) + if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else [] + assert set(os.listdir(tmpdir)) == set(expected) + + +def test_ckpt_every_n_train_steps(tmpdir): + """ Tests that the checkpoints are saved every n training steps. 
""" + + model = LogInTwoMethods() + every_n_train_steps = 16 + max_epochs = 2 + epoch_length = 64 + checkpoint_callback = ModelCheckpoint( + filename="{step}", + every_n_val_epochs=0, + every_n_train_steps=every_n_train_steps, + dirpath=tmpdir, + save_top_k=-1, + save_last=False, + ) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + progress_bar_refresh_rate=0, + callbacks=[checkpoint_callback], + logger=False, + ) + + trainer.fit(model) + expected = [ + f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) + ] + assert set(os.listdir(tmpdir)) == set(expected) + + def test_model_checkpoint_topk_zero(tmpdir): """ Test that no checkpoints are saved when save_top_k=0. """ model = LogInTwoMethods() diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 7d8c7d2adeea1..e65ebbab254de 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -104,3 +104,10 @@ def configure_optimizers(self): with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"): trainer.fit(model) + + +def test_v1_5_0_model_checkpoint_period(tmpdir): + with no_warning_call(DeprecationWarning): + ModelCheckpoint(dirpath=tmpdir) + with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"): + ModelCheckpoint(dirpath=tmpdir, period=1)