diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index fbef267f933829..1060e24109604a 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -94,13 +94,13 @@ class ModelCheckpoint(Callback):
             saved (``model.save_weights(filepath)``), else the full model is
             saved (``model.save(filepath)``).
         every_n_train_steps: Number of training steps between checkpoints.
-            To disable, set ``every_n_train_steps = 0``. This value must be non-negative.
+            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
+            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``every_n_val_epochs``.
         every_n_val_epochs: Number of validation epochs between checkpoints.
-            To disable, set ``every_n_val_epochs = 0``. This value must be non-negative.
-            This is not mutually exclusive with ``every_n_val_epochs``.
-            If both are set, pay extreme caution if also setting ``monitor``
-            as the ``monitor`` value must be available in both training and validation.
-            This can have unintended consequences with tracking the top k models.
+            If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
+            To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
+            This must be mutually exclusive with ``every_n_train_steps``.
         period: Interval (number of epochs) between checkpoints.

             .. warning::
@@ -168,8 +168,8 @@ def __init__(
         save_top_k: Optional[int] = None,
         save_weights_only: bool = False,
         mode: str = "min",
-        every_n_train_steps: int = 0,
-        every_n_val_epochs: int = 1,
+        every_n_train_steps: Optional[int] = None,
+        every_n_val_epochs: Optional[int] = None,
         period: Optional[int] = None,
     ):
         super().__init__()
@@ -178,9 +178,6 @@ def __init__(
         self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
-        self.every_n_val_epochs = period if period is not None else every_n_val_epochs
-        self.period = self.every_n_val_epochs
-        self.every_n_train_steps = every_n_train_steps
         self._last_global_step_saved = -1
         self.current_score = None
         self.best_k_models = {}
@@ -191,14 +188,9 @@ def __init__(
         self.save_function = None
         self.warned_result_obj = False

-        if period is not None:
-            rank_zero_warn(
-                'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
-                ' Please use `every_n_val_epochs` instead.', DeprecationWarning
-            )
-
         self.__init_monitor_mode(monitor, mode)
         self.__init_ckpt_dir(dirpath, filename, save_top_k)
+        self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
         self.__validate_init_configuration()

     def on_pretrain_routine_start(self, trainer, pl_module):
@@ -223,8 +215,8 @@ def on_validation_end(self, trainer, pl_module):
         checkpoints can be saved at the end of the val loop
         """
         skip = (
-            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1
-            or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
+            self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
+            or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
         )
         if skip:
             return
@@ -289,12 +281,17 @@ def __validate_init_configuration(self):
             raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
         if self.every_n_train_steps < 0:
             raise MisconfigurationException(
-                f'Invalid value for every_n_train_batches={self.every_n_train_steps}. Must be >= 0'
+                f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
             )
         if self.every_n_val_epochs < 0:
             raise MisconfigurationException(
                 f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
             )
+        if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
+            raise MisconfigurationException(
+                f'Invalid values for every_n_train_steps={self.every_n_train_steps} and every_n_val_epochs={self.every_n_val_epochs}.'
+                ' Both cannot be enabled at the same time.'
+            )
         if self.monitor is None:
             # None: save last epoch, -1: save all epochs, 0: nothing is saved
             if self.save_top_k not in (None, -1, 0):
@@ -341,6 +338,44 @@ def __init_monitor_mode(self, monitor, mode):

         self.kth_value, self.mode = mode_dict[mode]

+    def __init_triggers(
+        self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
+    ) -> None:
+
+        # Default to running once after each validation epoch if neither
+        # every_n_train_steps nor every_n_val_epochs is set
+        self.every_n_val_epochs = every_n_val_epochs or 0
+        self.every_n_train_steps = every_n_train_steps or 0
+        if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0:
+            self.every_n_val_epochs = 1
+            log.debug("Neither every_n_train_steps nor every_n_val_epochs is set. Setting every_n_val_epochs=1")
+
+        # period takes precedence over every_n_val_epochs for backwards compatibility
+        if period is not None:
+            rank_zero_warn(
+                'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+                ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+            )
+            self.every_n_val_epochs = period
+
+        self._period = self.every_n_val_epochs
+
+    @property
+    def period(self) -> Optional[int]:
+        rank_zero_warn(
+            'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+        )
+        return self._period
+
+    @period.setter
+    def period(self, value: Optional[int]) -> None:
+        rank_zero_warn(
+            'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
+            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
+        )
+        self._period = value
+
     @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index b56e5188172315..db102a696d8f40 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -535,6 +535,22 @@ def test_invalid_every_n_train_steps(tmpdir):
         ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)


+def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
+    """ Make sure that a MisconfigurationException is raised if both every_n_val_epochs and every_n_train_steps are enabled together. """
""" + with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'): + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2) + # These should not fail + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3) + ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0) + + +def test_none_every_n_train_steps_val_epochs(tmpdir): + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir) + assert checkpoint_callback.period == 1 + assert checkpoint_callback.every_n_val_epochs == 1 + assert checkpoint_callback.every_n_train_steps == 0 + + def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog): """ Test that it is possible to save all checkpoints when monitor=None. """ seed_everything() @@ -589,7 +605,7 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) +@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4))) def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): model = LogInTwoMethods() epochs = 5 @@ -645,6 +661,8 @@ def test_ckpt_every_n_train_steps(tmpdir): model = LogInTwoMethods() every_n_train_steps = 16 + max_epochs = 2 + epoch_length = 64 checkpoint_callback = ModelCheckpoint( filename="{step}", every_n_val_epochs=0, @@ -662,38 +680,10 @@ def test_ckpt_every_n_train_steps(tmpdir): ) trainer.fit(model) - expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)] - assert set(os.listdir(tmpdir)) == set(expected) - - -@pytest.mark.parametrize("every_n_val_epochs", [1, 3]) -def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs): - """ Tests that checkpoints are taken every 30 steps and every epochs """ - model = LogInTwoMethods() - every_n_train_steps = 30 - checkpoint_callback = ModelCheckpoint( - every_n_val_epochs=every_n_val_epochs, - every_n_train_steps=every_n_train_steps, - dirpath=tmpdir, - save_top_k=-1, - save_last=False, - filename="{step}", - ) - max_epochs = 3 - epoch_step_length = 64 - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=max_epochs, - callbacks=[checkpoint_callback], - logger=False, - ) - trainer.fit(model) - expected_steps_for_ckpt = [ - i for i in range(epoch_step_length * max_epochs) - if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0 + expected = [ + f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps) ] - expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt] - assert set(os.listdir(tmpdir)) == set(expected_ckpt_files) + assert set(os.listdir(tmpdir)) == set(expected) def test_model_checkpoint_topk_zero(tmpdir):