Skip to content

Commit

Permalink
mutual-exclusive
Browse files Browse the repository at this point in the history
Make every_n_train_steps and every_n_val_epochs mutually exclusive
  • Loading branch information
ananthsub committed Mar 11, 2021
1 parent 2995d88 commit fc92f8a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 52 deletions.
75 changes: 55 additions & 20 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ class ModelCheckpoint(Callback):
saved (``model.save_weights(filepath)``), else the full model
is saved (``model.save(filepath)``).
every_n_train_steps: Number of training steps between checkpoints.
To disable, set ``every_n_train_steps = 0``. This value must be non-negative.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
To disable, set ``every_n_val_epochs = 0``. This value must be non-negative.
This is not mutually exclusive with ``every_n_val_epochs``.
If both are set, pay extreme caution if also setting ``monitor``
as the ``monitor`` value must be available in both training and validation.
This can have unintended consequences with tracking the top k models.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps``.
period: Interval (number of epochs) between checkpoints.
.. warning::
Expand Down Expand Up @@ -168,8 +168,8 @@ def __init__(
save_top_k: Optional[int] = None,
save_weights_only: bool = False,
mode: str = "min",
every_n_train_steps: int = 0,
every_n_val_epochs: int = 1,
every_n_train_steps: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
period: Optional[int] = None,
):
super().__init__()
Expand All @@ -178,9 +178,6 @@ def __init__(
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.every_n_val_epochs = period if period is not None else every_n_val_epochs
self.period = self.every_n_val_epochs
self.every_n_train_steps = every_n_train_steps
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
Expand All @@ -191,14 +188,9 @@ def __init__(
self.save_function = None
self.warned_result_obj = False

if period is not None:
rank_zero_warn(
'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__validate_init_configuration()

def on_pretrain_routine_start(self, trainer, pl_module):
Expand All @@ -223,8 +215,8 @@ def on_validation_end(self, trainer, pl_module):
checkpoints can be saved at the end of the val loop
"""
skip = (
self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs is None
or self.every_n_val_epochs < 1 or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
)
if skip:
return
Expand Down Expand Up @@ -289,12 +281,17 @@ def __validate_init_configuration(self):
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self.every_n_train_steps < 0:
raise MisconfigurationException(
f'Invalid value for every_n_train_batches={self.every_n_train_steps}. Must be >= 0'
f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
)
if self.every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
)
if self.every_n_train_steps > 0 and self.every_n_val_epochs > 0:
raise MisconfigurationException(
f'Invalid values for every_n_train_steps={self.every_n_train_steps} and every_n_val_epochs={self.every_n_val_epochs}.'
'Both cannot be enabled at the same time.'
)
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in (None, -1, 0):
Expand Down Expand Up @@ -341,6 +338,44 @@ def __init_monitor_mode(self, monitor, mode):

self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
    self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
) -> None:
    """Resolve the step/epoch checkpoint triggers from the constructor arguments.

    ``None`` means "not set" and is normalized to ``0`` (disabled). If neither
    trigger is set, fall back to saving once per validation epoch. The
    deprecated ``period`` argument, when given, overrides ``every_n_val_epochs``.
    """
    # Normalize "not provided" (None) to 0 for both triggers.
    self.every_n_val_epochs = 0 if every_n_val_epochs is None else every_n_val_epochs
    self.every_n_train_steps = 0 if every_n_train_steps is None else every_n_train_steps

    # Default to running once after each validation epoch if neither
    # every_n_train_steps nor every_n_val_epochs is set.
    if self.every_n_val_epochs == 0 and self.every_n_train_steps == 0:
        self.every_n_val_epochs = 1
        log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")

    # Backwards compatibility: an explicit `period` takes precedence over every_n_val_epochs.
    if period is not None:
        rank_zero_warn(
            'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
            ' Please use `every_n_val_epochs` instead.', DeprecationWarning
        )
        self.every_n_val_epochs = period

    # Keep the deprecated `period` backing field in sync for the property accessors.
    self._period = self.every_n_val_epochs

@property
def period(self) -> Optional[int]:
    """Deprecated read accessor for the saving interval; use ``every_n_val_epochs`` instead."""
    deprecation_msg = (
        'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
        ' Please use `every_n_val_epochs` instead.'
    )
    rank_zero_warn(deprecation_msg, DeprecationWarning)
    return self._period

@period.setter
def period(self, value: Optional[int]) -> None:
    """Deprecated write accessor for the saving interval; use ``every_n_val_epochs`` instead."""
    deprecation_msg = (
        'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
        ' Please use `every_n_val_epochs` instead.'
    )
    rank_zero_warn(deprecation_msg, DeprecationWarning)
    self._period = value

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
Expand Down
54 changes: 22 additions & 32 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,22 @@ def test_invalid_every_n_train_steps(tmpdir):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)


def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
    """Enabling both every_n_train_steps and every_n_val_epochs must raise a MisconfigurationException."""
    with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
        ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
    # Disabling one of the two triggers (value 0) is a valid configuration and must not raise.
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
    ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)


def test_none_every_n_train_steps_val_epochs(tmpdir):
    """With neither trigger given, checkpointing defaults to once per validation epoch."""
    callback = ModelCheckpoint(dirpath=tmpdir)
    assert callback.every_n_train_steps == 0
    assert callback.every_n_val_epochs == 1
    # The deprecated `period` property mirrors every_n_val_epochs.
    assert callback.period == 1


def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
""" Test that it is possible to save all checkpoints when monitor=None. """
seed_everything()
Expand Down Expand Up @@ -589,7 +605,7 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
@pytest.mark.parametrize("every_n_val_epochs", list(range(1, 4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
model = LogInTwoMethods()
epochs = 5
Expand Down Expand Up @@ -645,6 +661,8 @@ def test_ckpt_every_n_train_steps(tmpdir):

model = LogInTwoMethods()
every_n_train_steps = 16
max_epochs = 2
epoch_length = 64
checkpoint_callback = ModelCheckpoint(
filename="{step}",
every_n_val_epochs=0,
Expand All @@ -662,38 +680,10 @@ def test_ckpt_every_n_train_steps(tmpdir):
)

trainer.fit(model)
expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)]
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", [1, 3])
def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs):
""" Tests that checkpoints are taken every 30 steps and every epochs """
model = LogInTwoMethods()
every_n_train_steps = 30
checkpoint_callback = ModelCheckpoint(
every_n_val_epochs=every_n_val_epochs,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
save_last=False,
filename="{step}",
)
max_epochs = 3
epoch_step_length = 64
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=max_epochs,
callbacks=[checkpoint_callback],
logger=False,
)
trainer.fit(model)
expected_steps_for_ckpt = [
i for i in range(epoch_step_length * max_epochs)
if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0
expected = [
f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps)
]
expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt]
assert set(os.listdir(tmpdir)) == set(expected_ckpt_files)
assert set(os.listdir(tmpdir)) == set(expected)


def test_model_checkpoint_topk_zero(tmpdir):
Expand Down

0 comments on commit fc92f8a

Please sign in to comment.