Merge 81385f2 into 55dd3a4
ananthsub authored Mar 9, 2021
2 parents 55dd3a4 + 81385f2 commit 4065f42
Showing 4 changed files with 200 additions and 21 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


- Added support for checkpointing after a given number of training batches in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))


@@ -46,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))

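For orientation, a minimal usage sketch of the new checkpointing cadence described in the `ModelCheckpoint` entries above. `MyModel` stands in for any user-defined `LightningModule`, and all values are illustrative; only the `every_n_train_steps` and `every_n_val_epochs` parameter names come from this commit.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# save a checkpoint every 100 training batches and additionally every 2 validation epochs
checkpoint_callback = ModelCheckpoint(
    dirpath='checkpoints/',
    every_n_train_steps=100,  # set to 0 to disable step-based checkpointing
    every_n_val_epochs=2,     # set to 0 to disable epoch-based checkpointing
)
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(MyModel())  # MyModel is a placeholder LightningModule, not defined here
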
78 changes: 60 additions & 18 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -93,8 +93,21 @@ class ModelCheckpoint(Callback):
save_weights_only: if ``True``, then only the model's weights will be
saved (``model.save_weights(filepath)``), else the full model
is saved (``model.save(filepath)``).
every_n_train_steps: Number of training steps between checkpoints.
To disable, set ``every_n_train_steps = 0``. This value must be non-negative.
every_n_val_epochs: Number of validation epochs between checkpoints.
To disable, set ``every_n_val_epochs = 0``. This value must be non-negative.
This is not mutually exclusive with ``every_n_train_steps``.
If both are set, use extreme caution when also setting ``monitor``,
as the ``monitor`` value must be available in both training and validation.
This can have unintended consequences for tracking the top-k models.
period: Interval (number of epochs) between checkpoints.
.. warning::
This argument has been deprecated in v1.3 and will be removed in v1.5.
Use ``every_n_val_epochs`` instead.
Note:
For extra customization, ModelCheckpoint includes the following attributes:
@@ -155,15 +168,19 @@ def __init__(
save_top_k: Optional[int] = None,
save_weights_only: bool = False,
mode: str = "min",
period: int = 1,
every_n_train_steps: int = 0,
every_n_val_epochs: int = 1,
period: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
self.verbose = verbose
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
self.period = period
self.every_n_val_epochs = period if period is not None else every_n_val_epochs
self.period = self.every_n_val_epochs
self.every_n_train_steps = every_n_train_steps
self._last_global_step_saved = -1
self.current_score = None
self.best_k_models = {}
@@ -174,6 +191,12 @@ def __init__(
self.save_function = None
self.warned_result_obj = False

if period is not None:
rank_zero_warn(
'Argument `period` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__validate_init_configuration()
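
A short sketch of how the deprecation shim above resolves `period` versus `every_n_val_epochs`; the behavior is read off the `__init__` assignments in this diff, and the values are illustrative.

from pytorch_lightning.callbacks import ModelCheckpoint

# passing `period` emits a DeprecationWarning but still wins, for backwards compatibility
ckpt = ModelCheckpoint(dirpath='checkpoints/', every_n_val_epochs=4, period=2)
assert ckpt.every_n_val_epochs == 2  # the deprecated `period` value is used
assert ckpt.period == 2              # kept in sync for code that still reads `.period`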
@@ -185,11 +208,27 @@ def on_pretrain_routine_start(self, trainer, pl_module):
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

def on_train_batch_end(self, trainer, pl_module, *args, **kwargs) -> None:
""" Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
if self._should_skip_saving_checkpoint(trainer):
return
step = trainer.global_step
skip_batch = self.every_n_train_steps < 1 or ((step + 1) % self.every_n_train_steps != 0)
if skip_batch:
return
self.save_checkpoint(trainer, pl_module)

def on_validation_end(self, trainer, pl_module):
"""
checkpoints can be saved at the end of the val loop
"""
self.save_checkpoint(trainer)
skip = (
self._should_skip_saving_checkpoint(trainer) or self.every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self.every_n_val_epochs != 0
)
if skip:
return
self.save_checkpoint(trainer, pl_module)
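
A small sketch of the triggering arithmetic used by the two hooks above; the counts are illustrative.

every_n_train_steps = 3
# on_train_batch_end saves when (global_step + 1) % every_n_train_steps == 0,
# i.e. after the 3rd, 6th, 9th, ... training batch
saved_steps = [step for step in range(10) if (step + 1) % every_n_train_steps == 0]
assert saved_steps == [2, 5, 8]

every_n_val_epochs = 2
# on_validation_end saves when (current_epoch + 1) % every_n_val_epochs == 0
saved_epochs = [epoch for epoch in range(6) if (epoch + 1) % every_n_val_epochs == 0]
assert saved_epochs == [1, 3, 5]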

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
@@ -216,20 +255,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
" has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
)

epoch = trainer.current_epoch
global_step = trainer.global_step

from pytorch_lightning.trainer.states import TrainerState
if (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or self._last_global_step_saved == global_step # already saved at the last step
):
return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

@@ -248,9 +275,26 @@ def save_checkpoint(self, trainer, unused: Optional = None):
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)

def _should_skip_saving_checkpoint(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self.every_n_train_steps < 0:
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self.every_n_train_steps}. Must be >= 0'
)
if self.every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self.every_n_val_epochs}. Must be >= 0'
)
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in (None, -1, 0):
@@ -554,9 +598,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, A
self._save_model(trainer, filepath)

if (
self.save_top_k is None
and self.best_model_path
and self.best_model_path != filepath
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and trainer.is_global_zero
):
self._del_model(self.best_model_path)
132 changes: 129 additions & 3 deletions tests/checkpointing/test_model_checkpoint.py
@@ -515,6 +515,26 @@ def test_none_monitor_save_last(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_last=False)


def test_invalid_every_n_val_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)


def test_invalid_every_n_train_steps(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)


def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
""" Test that it is possible to save all checkpoints when monitor=None. """
seed_everything()
@@ -558,9 +578,8 @@ def test_model_checkpoint_period(tmpdir, period: int):
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=0.1,
limit_val_batches=0.1,
val_check_interval=1.0,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)
@@ -570,6 +589,113 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
""" Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
every_n_val_epochs=(2 * every_n_val_epochs),
period=every_n_val_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


def test_ckpt_every_n_train_steps(tmpdir):
""" Tests that the checkpoints are saved every n training steps. """

model = LogInTwoMethods()
every_n_train_steps = 16
checkpoint_callback = ModelCheckpoint(
filename="{step}",
every_n_val_epochs=0,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
save_last=False,
)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=2,
progress_bar_refresh_rate=0,
callbacks=[checkpoint_callback],
logger=False,
)

trainer.fit(model)
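# the model used here runs 64 training batches per epoch (cf. `epoch_step_length = 64` in the test
# below), so 2 epochs give global steps 0..127 and checkpoints land where (step + 1) % 16 == 0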
expected = [f"step={i}.ckpt" for i in range(15, 128, every_n_train_steps)]
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", [1, 3])
def test_ckpt_every_n_train_steps_and_every_n_val_epochs(tmpdir, every_n_val_epochs):
""" Tests that checkpoints are taken every 30 steps and every epochs """
model = LogInTwoMethods()
every_n_train_steps = 30
checkpoint_callback = ModelCheckpoint(
every_n_val_epochs=every_n_val_epochs,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
save_last=False,
filename="{step}",
)
max_epochs = 3
epoch_step_length = 64
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=max_epochs,
callbacks=[checkpoint_callback],
logger=False,
)
trainer.fit(model)
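# a checkpoint is expected whenever either criterion fires; e.g. for every_n_val_epochs=1 this is
# steps 29, 59, ..., 179 (every 30 steps) plus 63, 127, 191 (the end of each 64-step epoch)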
expected_steps_for_ckpt = [
i for i in range(epoch_step_length * max_epochs)
if ((i + 1) % every_n_train_steps) == 0 or (i + 1) % (every_n_val_epochs * epoch_step_length) == 0
]
expected_ckpt_files = [f"step={step}.ckpt" for step in expected_steps_for_ckpt]
assert set(os.listdir(tmpdir)) == set(expected_ckpt_files)


def test_model_checkpoint_topk_zero(tmpdir):
""" Test that no checkpoints are saved when save_top_k=0. """
model = LogInTwoMethods()
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -104,3 +104,10 @@ def configure_optimizers(self):

with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
trainer.fit(model)


def test_v1_5_0_model_checkpoint_period(tmpdir):
with no_warning_call(DeprecationWarning):
ModelCheckpoint(dirpath=tmpdir)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
ModelCheckpoint(dirpath=tmpdir, period=1)
