Skip to content

Commit

Permalink
fix-default-0
Browse files Browse the repository at this point in the history
  • Loading branch information
ananthsub committed Mar 11, 2021
1 parent fc92f8a commit 1022c8a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
10 changes: 6 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,15 @@ def __init_triggers(

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
self.every_n_val_epochs = every_n_val_epochs or 0
self.every_n_train_steps = every_n_train_steps or 0
if self.every_n_train_steps == 0 and self.every_n_val_epochs == 0:
if every_n_train_steps is None and every_n_val_epochs is None:
self.every_n_val_epochs = 1
self.every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
else:
self.every_n_val_epochs = every_n_val_epochs or 0
self.every_n_train_steps = every_n_train_steps or 0

# period takes precedence for every_n_val_epochs for backwards compatibility
# period takes precedence over every_n_val_epochs for backwards compatibility
if period is not None:
rank_zero_warn(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
Expand Down
2 changes: 1 addition & 1 deletion tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4)))
@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
model = LogInTwoMethods()
epochs = 5
Expand Down

0 comments on commit 1022c8a

Please sign in to comment.