diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1060e24109604a..eac8349004ad98 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -344,13 +344,15 @@ def __init_triggers( # Default to running once after each validation epoch if neither # every_n_train_steps nor every_n_val_epochs is set - self.every_n_val_epochs = every_n_val_epochs or 0 - self.every_n_train_steps = every_n_train_steps or 0 - if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0: + if every_n_train_steps is None and every_n_val_epochs is None: self.every_n_val_epochs = 1 + self.every_n_train_steps = 0 log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1") + else: + self.every_n_val_epochs = every_n_val_epochs or 0 + self.every_n_train_steps = every_n_train_steps or 0 - # period takes precedence for every_n_val_epochs for backwards compatibility + # period takes precedence over every_n_val_epochs for backwards compatibility if period is not None: rank_zero_warn( 'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.' diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index db102a696d8f40..73fff72a8362a2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -605,7 +605,7 @@ def test_model_checkpoint_period(tmpdir, period: int): assert set(os.listdir(tmpdir)) == set(expected) -@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4))) +@pytest.mark.parametrize("every_n_val_epochs", list(range(4))) def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs): model = LogInTwoMethods() epochs = 5