Finish Allow on_save_checkpoint... #3688
@@ -31,7 +31,7 @@
 import torch
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -176,16 +176,16 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
         self.best_model_score = checkpointed_state["best_model_score"]
         self.best_model_path = checkpointed_state["best_model_path"]
 
-    @rank_zero_only
     def save_checkpoint(self, trainer, pl_module):
         """
-        Performs the main logic around saving a checkpoint
+        Performs the main logic around saving a checkpoint.
+        This method runs on all ranks, it is the responsibility of `self.save_function`
+        to handle correct behaviour in distributed training, i.e., saving only on rank 0.
        """
         epoch = trainer.current_epoch
 
         if (
-            trainer.global_rank != 0  # only run on main process
-            or self.save_top_k == 0  # no models are saved
+            self.save_top_k == 0  # no models are saved
             or self.period < 1  # no models are saved
             or (epoch + 1) % self.period  # skip epoch
             or trainer.running_sanity_check  # don't save anything during sanity check
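The new docstring above shifts the rank handling onto `self.save_function`. As a minimal sketch of what "saving only on rank 0" means for such a function (assuming a standard torch.distributed setup; this is not the trainer's actual implementation):

```python
import torch
import torch.distributed as dist


def save_function(filepath: str, weights_only: bool = False):
    # Sketch only: every rank may call this, but only global rank 0 writes.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank == 0:
        checkpoint = {"state_dict": {}}  # placeholder payload
        torch.save(checkpoint, filepath)
```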
@@ -207,16 +207,16 @@ def save_checkpoint(self, trainer, pl_module):
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
-        # Mode 1: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
-
-        # Mode 2: save all checkpoints OR only the top k
+        # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
             if self.save_top_k == -1:
                 self._save_all_checkpoints(trainer, pl_module, epoch, filepath)
             else:
                 self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
 
+        # Mode 2: save the last checkpoint
+        self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
Review comment: I had to switch the order here. We need to run save_last after top-k, because the top-k code tracks the best model path.
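To make the ordering concern concrete, here is a small standalone sketch (hypothetical class and method names, not the callback's real code): if the "last" checkpoint is written before the top-k bookkeeping runs, it captures a stale best model path.

```python
# Toy illustration of the ordering issue, not PyTorch Lightning code.

class TinyCheckpointCallback:
    def __init__(self):
        self.best_model_path = ""

    def save_top_k(self, path: str):
        # the top-k pass is what updates the best model path ...
        self.best_model_path = path

    def save_last(self) -> dict:
        # ... and the "last" checkpoint snapshots whatever state exists now
        return {"best_model_path": self.best_model_path}


cb = TinyCheckpointCallback()
stale = cb.save_last()            # "last" written first: best path still empty
cb.save_top_k("epoch=2.ckpt")
fresh = cb.save_last()            # "last" written after top-k: best path present
assert stale["best_model_path"] == ""
assert fresh["best_model_path"] == "epoch=2.ckpt"
```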
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
             raise MisconfigurationException(
Review comment: In __init_ckpt_dir, you could gate L261 on trainer.is_global_zero to avoid unnecessary file I/O outside rank 0.
Reply: True, this one we missed, but we don't have access to the trainer, nor the global rank, at the point when init happens.
Reply: Yes, it could be delayed to on_pretrain_routine_start.
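A sketch of that suggested follow-up (not part of this diff), assuming the `on_pretrain_routine_start` callback hook mentioned in the reply: move the directory creation out of `__init__`, where no trainer is available, into a hook that receives the trainer, and gate it on rank zero.

```python
from pytorch_lightning.callbacks import ModelCheckpoint


class DeferredDirModelCheckpoint(ModelCheckpoint):
    def on_pretrain_routine_start(self, trainer, pl_module):
        # __init__ has no access to the trainer or the global rank, so the
        # checkpoint directory is created here instead, and only on rank 0
        # to avoid redundant file I/O on the other processes.
        if trainer.is_global_zero and self.dirpath:
            self._fs.makedirs(self.dirpath, exist_ok=True)
```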
@@ -279,18 +279,21 @@ def __init_monitor_mode(self, monitor, mode):
 
         self.kth_value, self.mode = mode_dict[mode]
 
+    @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
     def _save_model(self, filepath: str, trainer, pl_module):
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
-        # delegate the saving to the model
+        # delegate the saving to the trainer
         if self.save_function is not None:
             self.save_function(filepath, self.save_weights_only)
         else:
@@ -325,7 +328,7 @@ def _format_checkpoint_name(
             filename = "{epoch}"
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
-        if groups:
+        if len(groups) >= 0:
             metrics["epoch"] = epoch
             for group in groups:
                 name = group[1:]
@@ -404,10 +407,8 @@ def __resolve_ckpt_dir(self, trainer, pl_module):
 
         self.dirpath = ckpt_path
 
-        assert (
-            trainer.global_rank == 0
-        ), "tried to make a checkpoint from non global_rank=0"
-        self._fs.makedirs(self.dirpath, exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer):
         metrics = trainer.logger_connector.callback_metrics
@@ -460,13 +461,18 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
 
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
-            filename = self._format_checkpoint_name(
+            last_filepath = self._format_checkpoint_name(
                 self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
             )
-            last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
+            last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
 
             self._save_model(last_filepath, trainer, pl_module)
-            if self.last_model_path and self.last_model_path != last_filepath and (self.save_top_k != -1 or self.save_last):
+            if (
+                self.last_model_path
+                and self.last_model_path != last_filepath
+                and (self.save_top_k != -1 or self.save_last)
+                and trainer.is_global_zero
+            ):
                 self._del_model(self.last_model_path)
             self.last_model_path = last_filepath
@@ -491,15 +497,14 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
             )
 
     def _save_all_checkpoints(self, trainer, pl_module, epoch, filepath):
         if self.verbose:
-            log.info(f"Epoch {epoch:d}: saving model to {filepath}")
+            rank_zero_info(f"Epoch {epoch:d}: saving model to {filepath}")
 
-        assert (trainer.global_rank == 0), "tried to make a checkpoint from non global_rank=0"
         self._save_model(filepath, trainer, pl_module)
 
     def _is_valid_monitor_key(self, metrics):
@@ -535,7 +540,7 @@ def _do_check_save(
         self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} reached"
                 f" {current:0.5f} (best {self.best_model_score:0.5f}),"
                 f" saving model to {filepath} as top {self.save_top_k}"
@@ -1,4 +1,6 @@
 import os
+from unittest.mock import MagicMock
+
 import yaml
 import pickle
 import platform
@@ -92,21 +94,24 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
 
     def __init__(self, expected_count, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.count = 0
         self.expected_count = expected_count
+        self.on_save_checkpoint_count = 0
 
-    def _save_model(self, filepath, trainer, pl_module):
-        # make sure we don't save twice
-        assert not os.path.isfile(filepath)
-        self.count += 1
-        super()._save_model(filepath, trainer, pl_module)
+    def on_train_start(self, trainer, pl_module):
+        torch.save = MagicMock()
+
+    def on_save_checkpoint(self, trainer, pl_module):
+        # expect all ranks to run but only rank 0 will actually write the checkpoint file
+        super().on_save_checkpoint(trainer, pl_module)
+        self.on_save_checkpoint_count += 1
 
     def on_train_end(self, trainer, pl_module):
         super().on_train_end(trainer, pl_module)
-        # on rank 0 we expect the saved files and on all others no saves
-        assert (trainer.global_rank == 0 and self.count == self.expected_count) or (
-            trainer.global_rank > 0 and self.count == 0
-        )
+        assert self.on_save_checkpoint_count == self.expected_count
+        if trainer.is_global_zero:
+            assert torch.save.call_count == self.expected_count
+        else:
+            assert torch.save.call_count == 0
 
 
 @pytest.mark.skipif(
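The patched `torch.save` above is a `MagicMock`, so its `call_count` records how many writes each process attempted without ever touching the disk. A tiny standalone illustration of that pattern (not taken from the test suite):

```python
from unittest.mock import MagicMock

import torch

real_save = torch.save
torch.save = MagicMock()   # every call is recorded instead of hitting disk
try:
    torch.save({"weights": 1}, "epoch=0.ckpt")
    torch.save({"weights": 2}, "epoch=1.ckpt")
    assert torch.save.call_count == 2
finally:
    torch.save = real_save  # restore the real function afterwards
```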
@@ -220,34 +225,6 @@ def test_none_monitor_save_last(tmpdir):
     ModelCheckpoint(filepath=tmpdir, save_last=False)
 
 
-def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
Review comment: this test is 1:1 the same as the one a few lines above, so I remove it.
-    """Tests that the save_last checkpoint contains the latest information."""
-    seed_everything(100)
-    model = EvalModelTemplate()
-    num_epochs = 3
-    model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        early_stop_callback=False,
-        checkpoint_callback=model_checkpoint,
-        max_epochs=num_epochs,
-    )
-    trainer.fit(model)
-
-    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
-    assert path_last_epoch != model_checkpoint.last_model_path
-
-    ckpt_last_epoch = torch.load(path_last_epoch)
-    ckpt_last = torch.load(model_checkpoint.last_model_path)
-    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
-
-    # it is easier to load the model objects than to iterate over the raw dict of tensors
-    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
-    model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
-    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
-        assert w0.eq(w1).all()
-
-
 def test_model_checkpoint_none_monitor(tmpdir):
     """ Test that it is possible to save all checkpoints when monitor=None. """
     model = EvalModelTemplate()
@@ -419,3 +396,42 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     )
     trainer.fit(model)
     assert caplog.messages.count('Saving latest checkpoint...') == save_last
+
+
+def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
+    """ Tests that the save_last checkpoint contains the latest information. """
+    seed_everything(100)
+    model = EvalModelTemplate()
+    num_epochs = 3
+    model_checkpoint = ModelCheckpoint(
+        monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
+    )
+    trainer.fit(model)
+
+    path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
+    path_last = str(tmpdir / f"last.ckpt")
+    assert path_last == model_checkpoint.last_model_path
+
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(path_last)
+    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
+
+    ch_type = type(model_checkpoint)
+    assert all(list(
+        ckpt_last["callbacks"][ch_type][k] == ckpt_last_epoch["callbacks"][ch_type][k]
+        for k in ("best_model_score", "best_model_path")
+    ))
+
+    # it is easier to load the model objects than to iterate over the raw dict of tensors
+    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
+    model_last = EvalModelTemplate.load_from_checkpoint(
+        model_checkpoint.last_model_path
+    )
+    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
+        assert w0.eq(w1).all()
Review comment: this logic already feels too complex. I hope solving #2586 will force us to consolidate some of this.
Reply: Ananth, this is exactly what I was thinking. Saving all models is really a special case of save_top_k and should be handled there. It becomes obvious when we look at the fix here: #3735. Would appreciate your feedback there too!
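To illustrate that point (a hypothetical simplification, not the actual fix in #3735): "save all checkpoints" behaves like the top-k rule with k == -1, where nothing is ever evicted, so a separate save-all branch is arguably redundant.

```python
# Hypothetical helper, for illustration only: with save_top_k == -1 the
# top-k rule degenerates to "keep every checkpoint", which is exactly the
# save-all behaviour currently handled by a separate code path.

def should_evict(num_kept: int, save_top_k: int) -> bool:
    if save_top_k == -1:
        return False          # keep everything: save-all as a top-k special case
    return num_kept >= save_top_k


assert should_evict(num_kept=10, save_top_k=-1) is False
assert should_evict(num_kept=3, save_top_k=3) is True
```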
Reply: yeah... I think we need to come back to this ModelCheckpoint after 1.0 with a nice clean re-write. It's gotten super messy now haha.