[feat] Support iteration-based checkpointing in model checkpoint callback #6146

Merged 54 commits on Mar 11, 2021

Changes from all commits
727591c
Update model_checkpoint.py
ananthsub Feb 23, 2021
eeeffd8
add tests
ananthsub Feb 23, 2021
be36e86
Update model_checkpoint.py
ananthsub Feb 23, 2021
218737f
Update test_model_checkpoint.py
ananthsub Feb 23, 2021
f89ea03
fix tests
ananthsub Feb 23, 2021
f857ffa
every_n_batches
ananthsub Feb 23, 2021
1763ea4
Update test_model_checkpoint.py
ananthsub Feb 23, 2021
e86305c
defaults
ananthsub Feb 23, 2021
45f16dd
rm tests
ananthsub Feb 23, 2021
fd90771
Update model_checkpoint.py
ananthsub Feb 23, 2021
be2ae2e
Update test_model_checkpoint.py
ananthsub Feb 24, 2021
572fc9d
Prune deprecated metrics for 1.3 (#6161)
Borda Feb 24, 2021
59d0ce8
Update model_checkpoint.py
ananthsub Feb 23, 2021
c26dd03
add tests
ananthsub Feb 23, 2021
f7c5100
defaults
ananthsub Feb 23, 2021
81a1434
Update CHANGELOG.md
ananthsub Feb 24, 2021
6423e1d
pre-commit
ananthsub Feb 24, 2021
a7a469b
Update model_checkpoint.py
ananthsub Mar 2, 2021
3cfc44a
update defaults
ananthsub Mar 2, 2021
70dc438
Update test_remove_1-5.py
ananthsub Mar 2, 2021
ddaa783
Update model_checkpoint.py
ananthsub Mar 2, 2021
2384874
Update model_checkpoint.py
ananthsub Mar 2, 2021
1f2d0f2
Update model_checkpoint.py
ananthsub Mar 2, 2021
a7cec2b
Update model_checkpoint.py
ananthsub Mar 2, 2021
6e06b8a
Update model_checkpoint.py
ananthsub Mar 2, 2021
28e3683
Update model_checkpoint.py
ananthsub Mar 2, 2021
9239325
fix tests
ananthsub Mar 5, 2021
b9152a1
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
dbbb446
Update model_checkpoint.py
ananthsub Mar 5, 2021
22e917b
Update model_checkpoint.py
ananthsub Mar 5, 2021
bd45c53
Update model_checkpoint.py
ananthsub Mar 5, 2021
744a078
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
7b7ca5d
ckpt-callback
ananthsub Mar 5, 2021
b32e500
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
7ff63ac
Update model_checkpoint.py
ananthsub Mar 5, 2021
77b70fd
Update model_checkpoint.py
ananthsub Mar 5, 2021
791f876
validation-end
ananthsub Mar 5, 2021
2e802c4
Update model_checkpoint.py
ananthsub Mar 5, 2021
fd9f661
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
f565786
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
bceba8b
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
1e9244b
Update test_model_checkpoint.py
ananthsub Mar 5, 2021
47868d1
clarify-names
ananthsub Mar 9, 2021
4b96403
Update model_checkpoint.py
ananthsub Mar 9, 2021
989fafa
Update model_checkpoint.py
ananthsub Mar 9, 2021
99be720
Update model_checkpoint.py
ananthsub Mar 9, 2021
524ba68
Update model_checkpoint.py
ananthsub Mar 9, 2021
c2120ff
Update model_checkpoint.py
ananthsub Mar 9, 2021
3ccf8d1
mutual-exclusive
ananthsub Mar 11, 2021
548e938
fix-default-0
ananthsub Mar 11, 2021
376962b
Update CHANGELOG.md
ananthsub Mar 11, 2021
1e7f640
formatting
ananthsub Mar 11, 2021
ab4012d
make-private
ananthsub Mar 11, 2021
89c2df7
rebase
ananthsub Mar 11, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

- Added support for checkpointing after a fixed number of training steps in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))

- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))

@@ -55,6 +56,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))


- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))

129 changes: 105 additions & 24 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -93,8 +93,25 @@ class ModelCheckpoint(Callback):
save_weights_only: if ``True``, then only the model's weights will be
saved (``model.save_weights(filepath)``), else the full model
is saved (``model.save(filepath)``).
every_n_train_steps: Number of training steps between checkpoints.
If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_val_epochs``.
every_n_val_epochs: Number of validation epochs between checkpoints.
If ``every_n_val_epochs == None or every_n_val_epochs == 0``, we skip saving on validation end.
To disable, set ``every_n_val_epochs = 0``. This value must be ``None`` or non-negative.
This must be mutually exclusive with ``every_n_train_steps``.
Setting both ``ModelCheckpoint(..., every_n_val_epochs=V)`` and
``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
will only save checkpoints at epochs 0 < E <= N
where both values for ``every_n_val_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
period: Interval (number of epochs) between checkpoints.

.. warning::
This argument has been deprecated in v1.3 and will be removed in v1.5.

Use ``every_n_val_epochs`` instead.

Note:
For extra customization, ModelCheckpoint includes the following attributes:

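A minimal usage sketch of the two triggers documented above (``dirpath`` and all numbers are illustrative, not part of this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Checkpoint every 100 optimizer steps; validation-epoch checkpointing stays off.
step_ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_train_steps=100)

# Checkpoint every 2nd validation epoch. With check_val_every_n_epoch=3 below,
# checkpoints land only at epochs E (1-indexed) divisible by both 2 and 3,
# i.e. E = 6 and E = 12.
epoch_ckpt = ModelCheckpoint(dirpath="checkpoints/", every_n_val_epochs=2)
trainer = Trainer(max_epochs=12, check_val_every_n_epoch=3, callbacks=[epoch_ckpt])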
@@ -165,16 +182,17 @@ def __init__(
save_top_k: Optional[int] = None,
save_weights_only: bool = False,
mode: str = "min",
- period: int = 1,
- auto_insert_metric_name: bool = True
auto_insert_metric_name: bool = True,
every_n_train_steps: Optional[int] = None,
every_n_val_epochs: Optional[int] = None,
period: Optional[int] = None,
):
super().__init__()
self.monitor = monitor
self.verbose = verbose
self.save_last = save_last
self.save_top_k = save_top_k
self.save_weights_only = save_weights_only
- self.period = period
self.auto_insert_metric_name = auto_insert_metric_name
self._last_global_step_saved = -1
self.current_score = None
@@ -188,6 +206,7 @@ def __init__(

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__validate_init_configuration()

def on_pretrain_routine_start(self, trainer, pl_module):
@@ -197,10 +216,26 @@ def on_pretrain_routine_start(self, trainer, pl_module):
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint

- def on_validation_end(self, trainer, pl_module):
def on_train_batch_end(self, trainer, *args, **kwargs) -> None:
""" Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps` """
if self._should_skip_saving_checkpoint(trainer):
return
step = trainer.global_step
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
if skip_batch:
return
self.save_checkpoint(trainer)
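# Example (illustrative, not part of this diff): with every_n_train_steps=4,
# a checkpoint is saved whenever (trainer.global_step + 1) % 4 == 0, i.e.
# after the 4th, 8th, 12th, ... training batch (global_step values 3, 7, 11, ...).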

def on_validation_end(self, trainer, *args, **kwargs) -> None:
"""
checkpoints can be saved at the end of the val loop
"""
skip = (
self._should_skip_saving_checkpoint(trainer) or self._every_n_val_epochs < 1
or (trainer.current_epoch + 1) % self._every_n_val_epochs != 0
)
if skip:
return
self.save_checkpoint(trainer)
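# Example (illustrative, not part of this diff): with every_n_val_epochs=2,
# a checkpoint is saved whenever (trainer.current_epoch + 1) % 2 == 0, i.e.
# at zero-indexed epochs 1, 3, 5, ...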

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
@@ -228,20 +263,8 @@ def save_checkpoint(self, trainer, unused: Optional = None):
" has been removed. Support for the old signature will be removed in v1.5", DeprecationWarning
)

- epoch = trainer.current_epoch
- global_step = trainer.global_step

- from pytorch_lightning.trainer.states import TrainerState
- if (
- trainer.fast_dev_run  # disable checkpointing with fast_dev_run
- or trainer.state != TrainerState.FITTING  # don't save anything during non-fit
- or trainer.sanity_checking  # don't save anything during sanity check
- or self.period < 1  # no models are saved
- or (epoch + 1) % self.period  # skip epoch
- or self._last_global_step_saved == global_step  # already saved at the last step
- ):
- return

self._add_backward_monitor_support(trainer)
self._validate_monitor_key(trainer)

@@ -260,9 +283,32 @@ def save_checkpoint(self, trainer, unused: Optional = None):
# Mode 3: save last checkpoints
self._save_last_checkpoint(trainer, monitor_candidates)

def _should_skip_saving_checkpoint(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
or trainer.state != TrainerState.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self._every_n_train_steps < 0:
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
)
if self._every_n_val_epochs < 0:
raise MisconfigurationException(
f'Invalid value for every_n_val_epochs={self._every_n_val_epochs}. Must be >= 0'
)
if self._every_n_train_steps > 0 and self._every_n_val_epochs > 0:
raise MisconfigurationException(
f'Invalid values for every_n_train_steps={self._every_n_train_steps}'
f' and every_n_val_epochs={self._every_n_val_epochs}.'
' Both cannot be enabled at the same time.'
)
if self.monitor is None:
# None: save last epoch, -1: save all epochs, 0: nothing is saved
if self.save_top_k not in (None, -1, 0):
@@ -309,6 +355,46 @@ def __init_monitor_mode(self, monitor, mode):

self.kth_value, self.mode = mode_dict[mode]

def __init_triggers(
self, every_n_train_steps: Optional[int], every_n_val_epochs: Optional[int], period: Optional[int]
) -> None:

# Default to running once after each validation epoch if neither
# every_n_train_steps nor every_n_val_epochs is set
if every_n_train_steps is None and every_n_val_epochs is None:
self._every_n_val_epochs = 1
self._every_n_train_steps = 0
log.debug("Both every_n_train_steps and every_n_val_epochs are not set. Setting every_n_val_epochs=1")
else:
self._every_n_val_epochs = every_n_val_epochs or 0
self._every_n_train_steps = every_n_train_steps or 0

# period takes precedence over every_n_val_epochs for backwards compatibility
if period is not None:
rank_zero_warn(
'Argument `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
self._every_n_val_epochs = period

self._period = self._every_n_val_epochs
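# Illustrative resolution of the branches above (values derived from this
# method, not additional diff content):
#   ModelCheckpoint()                               -> _every_n_val_epochs=1, _every_n_train_steps=0
#   ModelCheckpoint(every_n_train_steps=100)        -> _every_n_val_epochs=0, _every_n_train_steps=100
#   ModelCheckpoint(every_n_val_epochs=3, period=2) -> _every_n_val_epochs=2 (period wins)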

@property
def period(self) -> Optional[int]:
rank_zero_warn(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
return self._period

@period.setter
def period(self, value: Optional[int]) -> None:
rank_zero_warn(
'Property `period` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `every_n_val_epochs` instead.', DeprecationWarning
)
self._period = value
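# Illustrative behavior of the deprecated property (a sketch derived from the
# code above, not part of this diff):
#   ckpt = ModelCheckpoint()
#   ckpt.period       # emits DeprecationWarning, returns 1 (mirrors every_n_val_epochs)
#   ckpt.period = 5   # emits DeprecationWarning, updates only the internal _period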

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
@@ -422,11 +508,8 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any],

"""
filename = self._format_checkpoint_name(
- self.filename,
- epoch,
- step,
- metrics,
- auto_insert_metric_name=self.auto_insert_metric_name)
self.filename, epoch, step, metrics, auto_insert_metric_name=self.auto_insert_metric_name
)

if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -581,9 +664,7 @@ def _save_none_monitor_checkpoint(self, trainer, monitor_candidates: Dict[str, Any]):
self._save_model(trainer, filepath)

if (
- self.save_top_k is None
- and self.best_model_path
- and self.best_model_path != filepath
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and trainer.is_global_zero
):
self._del_model(self.best_model_path)
132 changes: 124 additions & 8 deletions tests/checkpointing/test_model_checkpoint.py
@@ -434,11 +434,8 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):

# auto_insert_metric_name=False
ckpt_name = ModelCheckpoint._format_checkpoint_name(
- 'epoch={epoch:03d}-val_acc={val/acc}',
- 3,
- 2,
- {'val/acc': 0.03},
- auto_insert_metric_name=False)
'epoch={epoch:03d}-val_acc={val/acc}', 3, 2, {'val/acc': 0.03}, auto_insert_metric_name=False
)
assert ckpt_name == 'epoch=003-val_acc=0.03'


@@ -524,6 +521,45 @@ def test_none_monitor_save_last(tmpdir):
ModelCheckpoint(dirpath=tmpdir, save_last=False)


def test_invalid_every_n_val_epochs(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=0)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=1)
ModelCheckpoint(dirpath=tmpdir, every_n_val_epochs=2)


def test_invalid_every_n_train_steps(tmpdir):
""" Make sure that a MisconfigurationException is raised for a negative every_n_val_epochs argument. """
with pytest.raises(MisconfigurationException, match=r'.*Must be >= 0'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=-3)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=2)


def test_invalid_every_n_train_steps_val_epochs_combination(tmpdir):
"""
Test that a MisconfigurationException is raised if both
every_n_val_epochs and every_n_train_steps are enabled together.
"""
with pytest.raises(MisconfigurationException, match=r'.*Both cannot be enabled at the same time'):
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=1, every_n_val_epochs=2)
# These should not fail
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=0, every_n_val_epochs=3)
ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=4, every_n_val_epochs=0)


def test_none_every_n_train_steps_val_epochs(tmpdir):
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir)
assert checkpoint_callback.period == 1
assert checkpoint_callback._every_n_val_epochs == 1
assert checkpoint_callback._every_n_train_steps == 0


def test_model_checkpoint_save_last_none_monitor(tmpdir, caplog):
""" Test that it is possible to save all checkpoints when monitor=None. """
seed_everything()
@@ -567,9 +603,8 @@ def test_model_checkpoint_period(tmpdir, period: int):
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
- limit_train_batches=0.1,
- limit_val_batches=0.1,
- val_check_interval=1.0,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)
@@ -579,6 +614,87 @@ def test_model_checkpoint_period(tmpdir, period: int):
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs(tmpdir, every_n_val_epochs):
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir, filename='{epoch}', save_top_k=-1, every_n_val_epochs=every_n_val_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


@pytest.mark.parametrize("every_n_val_epochs", list(range(4)))
def test_model_checkpoint_every_n_val_epochs_and_period(tmpdir, every_n_val_epochs):
""" Tests that if period is set, it takes precedence over every_n_val_epochs for backwards compatibility. """
model = LogInTwoMethods()
epochs = 5
checkpoint_callback = ModelCheckpoint(
dirpath=tmpdir,
filename='{epoch}',
save_top_k=-1,
every_n_val_epochs=(2 * every_n_val_epochs),
period=every_n_val_epochs
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
max_epochs=epochs,
limit_train_batches=1,
limit_val_batches=1,
logger=False,
)
trainer.fit(model)

# check that the correct ckpts were created
expected = [f'epoch={e}.ckpt' for e in range(epochs)
if not (e + 1) % every_n_val_epochs] if every_n_val_epochs > 0 else []
assert set(os.listdir(tmpdir)) == set(expected)


def test_ckpt_every_n_train_steps(tmpdir):
""" Tests that the checkpoints are saved every n training steps. """

model = LogInTwoMethods()
every_n_train_steps = 16
max_epochs = 2
epoch_length = 64
checkpoint_callback = ModelCheckpoint(
filename="{step}",
every_n_val_epochs=0,
every_n_train_steps=every_n_train_steps,
dirpath=tmpdir,
save_top_k=-1,
save_last=False,
)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
callbacks=[checkpoint_callback],
logger=False,
)

trainer.fit(model)
expected = [
f"step={i}.ckpt" for i in range(every_n_train_steps - 1, max_epochs * epoch_length, every_n_train_steps)
]
assert set(os.listdir(tmpdir)) == set(expected)
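# With the values above (and assuming the model's dataloader yields 64 batches
# per epoch, as epoch_length encodes), expected is step=15.ckpt, step=31.ckpt,
# ..., step=127.ckpt: 8 checkpoints across the 128 total training steps.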


def test_model_checkpoint_topk_zero(tmpdir):
""" Test that no checkpoints are saved when save_top_k=0. """
model = LogInTwoMethods()
Expand Down
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -104,3 +104,10 @@ def configure_optimizers(self):

with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
trainer.fit(model)


def test_v1_5_0_model_checkpoint_period(tmpdir):
with no_warning_call(DeprecationWarning):
ModelCheckpoint(dirpath=tmpdir)
with pytest.deprecated_call(match="is deprecated in v1.3 and will be removed in v1.5"):
ModelCheckpoint(dirpath=tmpdir, period=1)
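
A minimal migration sketch for the deprecation covered by the test above (``dirpath`` is illustrative):

from pytorch_lightning.callbacks import ModelCheckpoint

# Deprecated in v1.3, to be removed in v1.5:
ModelCheckpoint(dirpath="checkpoints/", period=2)

# Equivalent replacement:
ModelCheckpoint(dirpath="checkpoints/", every_n_val_epochs=2)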