Finish Allow on_save_checkpoint... #3688

Merged: 40 commits, Sep 30, 2020
Changes from 28 commits
Commits (40)
bd08451
Finish #3562
williamFalcon Sep 28, 2020
211d82f
Apply suggestions from code review
Borda Sep 28, 2020
4ad52c2
Apply suggestions from code review
Borda Sep 28, 2020
742a46c
fix tests
williamFalcon Sep 28, 2020
49ce472
Finish #3562
williamFalcon Sep 28, 2020
cce9e16
Apply suggestions from code review
Borda Sep 28, 2020
d6b5e13
Apply suggestions from code review
Borda Sep 28, 2020
4a16581
fix tests
williamFalcon Sep 28, 2020
9d78a1f
fix structure
williamFalcon Sep 28, 2020
b9ab13f
fix structure
williamFalcon Sep 28, 2020
a957244
Merge branch 'r3n' of https://github.com/PyTorchLightning/pytorch-lig…
williamFalcon Sep 28, 2020
33eb5b8
Merge branch 'master' into r3n
Sep 29, 2020
6a7a7c0
make save_last test pass
awaelchli Sep 29, 2020
b8192fb
unnecessary global rank check
awaelchli Sep 29, 2020
4eea9de
fix test
awaelchli Sep 30, 2020
d792143
Merge remote-tracking branch 'PyTorchLightning/r3n' into r3n
awaelchli Sep 30, 2020
cbdcde9
update test
awaelchli Sep 30, 2020
8068b93
update test
awaelchli Sep 30, 2020
32f0754
test
awaelchli Sep 30, 2020
075f32b
test
awaelchli Sep 30, 2020
648f47a
run save on all
awaelchli Sep 30, 2020
ad643a8
remove assert
awaelchli Sep 30, 2020
b8fc751
tracking saves
awaelchli Sep 30, 2020
5cd5edd
check if fails
awaelchli Sep 30, 2020
26c8312
test
awaelchli Sep 30, 2020
84070cd
clean up
awaelchli Sep 30, 2020
0b268f1
adjust horovod test
awaelchli Sep 30, 2020
219aa6e
clean up
awaelchli Sep 30, 2020
a2f2fe5
remove unnecessary makdirs
awaelchli Sep 30, 2020
9ba223c
Merge branch 'master' into r3n
awaelchli Sep 30, 2020
48ecd41
change
awaelchli Sep 30, 2020
dfbcff3
undo
awaelchli Sep 30, 2020
d133adf
debug
awaelchli Sep 30, 2020
ad3dcfc
debug
awaelchli Sep 30, 2020
02523ab
debug
awaelchli Sep 30, 2020
dc9ab72
debug
awaelchli Sep 30, 2020
cc30c2a
mock
awaelchli Sep 30, 2020
22895a1
undo debug code
awaelchli Sep 30, 2020
9cdaaf8
add extra assertions
awaelchli Sep 30, 2020
8c3f19b
test
awaelchli Sep 30, 2020
51 changes: 28 additions & 23 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -31,7 +31,7 @@
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -176,16 +176,16 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]

@rank_zero_only
def save_checkpoint(self, trainer, pl_module):
"""
Performs the main logic around saving a checkpoint
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
"""
epoch = trainer.current_epoch

if (
trainer.global_rank != 0 # only run on main process
or self.save_top_k == 0 # no models are saved
self.save_top_k == 0 # no models are saved
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
Expand All @@ -207,16 +207,16 @@ def save_checkpoint(self, trainer, pl_module):

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)

# Mode 2: save all checkpoints OR only the top k
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
if self.save_top_k == -1:
self._save_all_checkpoints(trainer, pl_module, epoch, filepath)
else:
self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
@ananthsub (Contributor), Sep 30, 2020:

This logic already feels too complex. I hope solving #2586 will force us to consolidate some of this.

Contributor:

Ananth, this is exactly what I was thinking. Saving all models is really a special case of save_top_k and should be handled there. It becomes obvious when we look at the fix here: #3735
Would appreciate your feedback there too!

Contributor (Author):

Yeah... I think we need to come back to this ModelCheckpoint after 1.0 with a nice, clean re-write.
It's gotten super messy now, haha.
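
To make the point above concrete, here is a minimal, hypothetical sketch (not the Lightning implementation): "save all checkpoints" is simply the k == -1 branch of a generic top-k decision.

# Hypothetical sketch: "save every checkpoint" falls out of a top-k policy when k == -1.

def should_save(score: float, kept_scores: list, k: int) -> bool:
    """Decide whether a checkpoint with `score` is kept under a top-k policy (lower is better)."""
    if k == -1:   # save-all case
        return True
    if k == 0:    # save nothing
        return False
    return len(kept_scores) < k or score < max(kept_scores)


print(should_save(0.5, [0.2, 0.3], k=-1))  # True: save-all
print(should_save(0.5, [0.2, 0.3], k=2))   # False: not among the top 2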


# Mode 2: save the last checkpoint
self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
Contributor:

I had to switch the order here. We need to run save_last after top-k, because the top-k code tracks the best model path.
Otherwise last.ckpt will not point to the correct "best" model.
The new test below checks that.
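
A minimal, hypothetical sketch of that ordering dependency (simplified names, not the callback code in this PR): the top-k step updates best_model_path, and the state captured for last.ckpt references that path, so save_last has to run second.

class ToyCheckpointCallback:
    """Toy stand-in that only tracks which file is currently 'best'."""

    def __init__(self):
        self.best_model_path = ""

    def save_top_k(self, filepath: str):
        # simplified: pretend every new checkpoint becomes the best one
        self.best_model_path = filepath

    def save_last_state(self) -> dict:
        # the callback state written into last.ckpt includes best_model_path
        return {"best_model_path": self.best_model_path}


cb = ToyCheckpointCallback()

# wrong order: last.ckpt would capture a stale (empty) best_model_path
stale = cb.save_last_state()
cb.save_top_k("epoch=2.ckpt")

# right order: run top-k first, then save_last sees the updated path
fresh = cb.save_last_state()
assert stale["best_model_path"] == ""
assert fresh["best_model_path"] == "epoch=2.ckpt"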


def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(
Contributor:

In __init_ckpt_dir, you could have L261 gated by trainer.is_global_zero to avoid unnecessary file I/O outside rank 0.

Contributor:

True, this one we missed. But we don't have access to the trainer, nor the global rank, at the point when init happens.
I'm wondering if we even need to call makedirs at this point. We could delay it, no?

Contributor:

Yes, it could be delayed to on pretrain routine start.
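
A rough sketch of that idea, with hypothetical names (this is not the code in this PR): keep __init__ free of file I/O and create the directory later, in a hook where the trainer, and therefore the global rank, is available.

import os
import tempfile


class DelayedDirCheckpoint:
    """Toy callback that postpones directory creation until a trainer hook runs."""

    def __init__(self, dirpath: str):
        # only remember the path here; no filesystem access during __init__
        self.dirpath = dirpath

    def on_pretrain_routine_start(self, trainer, pl_module):
        # the trainer exists now, so the I/O can be gated on the global rank
        if trainer.is_global_zero:
            os.makedirs(self.dirpath, exist_ok=True)


# quick check with a stand-in trainer object
class _FakeTrainer:
    is_global_zero = True


demo_dir = os.path.join(tempfile.mkdtemp(), "checkpoints")
DelayedDirCheckpoint(demo_dir).on_pretrain_routine_start(_FakeTrainer(), pl_module=None)
assert os.path.isdir(demo_dir)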

@@ -279,18 +279,21 @@ def __init_monitor_mode(self, monitor, mode):

self.kth_value, self.mode = mode_dict[mode]

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
self._fs.rm(filepath)
log.debug(f"Removed checkpoint: {filepath}")

def _save_model(self, filepath: str, trainer, pl_module):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
if trainer.is_global_zero:
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the model
# delegate the saving to the trainer
if self.save_function is not None:
self.save_function(filepath, self.save_weights_only)
else:
@@ -325,7 +328,7 @@ def _format_checkpoint_name(
filename = "{epoch}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if groups:
if len(groups) >= 0:
metrics["epoch"] = epoch
for group in groups:
name = group[1:]
@@ -404,10 +407,8 @@ def __resolve_ckpt_dir(self, trainer, pl_module):

self.dirpath = ckpt_path

assert (
trainer.global_rank == 0
), "tried to make a checkpoint from non global_rank=0"
self._fs.makedirs(self.dirpath, exist_ok=True)
if trainer.is_global_zero:
self._fs.makedirs(self.dirpath, exist_ok=True)

def _add_backward_monitor_support(self, trainer):
metrics = trainer.logger_connector.callback_metrics
@@ -460,13 +461,18 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi

# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
filename = self._format_checkpoint_name(
last_filepath = self._format_checkpoint_name(
self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

self._save_model(last_filepath, trainer, pl_module)
if self.last_model_path and self.last_model_path != last_filepath and (self.save_top_k != -1 or self.save_last):
if (
self.last_model_path
and self.last_model_path != last_filepath
and (self.save_top_k != -1 or self.save_last)
and trainer.is_global_zero
):
self._del_model(self.last_model_path)
self.last_model_path = last_filepath

@@ -491,15 +497,14 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose:
log.info(
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
)

def _save_all_checkpoints(self, trainer, pl_module, epoch, filepath):
if self.verbose:
log.info(f"Epoch {epoch:d}: saving model to {filepath}")
rank_zero_info(f"Epoch {epoch:d}: saving model to {filepath}")

assert (trainer.global_rank == 0), "tried to make a checkpoint from non global_rank=0"
self._save_model(filepath, trainer, pl_module)

def _is_valid_monitor_key(self, metrics):
@@ -535,7 +540,7 @@ def _do_check_save(
self.best_model_score = self.best_k_models[self.best_model_path]

if self.verbose:
log.info(
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {self.save_top_k}"
92 changes: 54 additions & 38 deletions tests/callbacks/test_model_checkpoint.py
@@ -1,4 +1,6 @@
import os
from unittest.mock import MagicMock

import yaml
import pickle
import platform
@@ -92,21 +94,24 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):

def __init__(self, expected_count, *args, **kwargs):
super().__init__(*args, **kwargs)
self.count = 0
self.expected_count = expected_count
self.on_save_checkpoint_count = 0

def on_train_start(self, trainer, pl_module):
torch.save = MagicMock()

def _save_model(self, filepath, trainer, pl_module):
# make sure we don't save twice
assert not os.path.isfile(filepath)
self.count += 1
super()._save_model(filepath, trainer, pl_module)
def on_save_checkpoint(self, trainer, pl_module):
# expect all ranks to run but only rank 0 will actually write the checkpoint file
super().on_save_checkpoint(trainer, pl_module)
self.on_save_checkpoint_count += 1

def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)
# on rank 0 we expect the saved files and on all others no saves
assert (trainer.global_rank == 0 and self.count == self.expected_count) or (
trainer.global_rank > 0 and self.count == 0
)
assert self.on_save_checkpoint_count == self.expected_count
if trainer.is_global_zero:
assert torch.save.call_count == self.expected_count
else:
assert torch.save.call_count == 0


@pytest.mark.skipif(
@@ -220,34 +225,6 @@ def test_none_monitor_save_last(tmpdir):
ModelCheckpoint(filepath=tmpdir, save_last=False)


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
@awaelchli (Contributor), Sep 30, 2020:

This test is 1:1 the same as one a few lines above; I removed it.

"""Tests that the save_last checkpoint contains the latest information."""
seed_everything(100)
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
)
trainer.fit(model)

path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
assert path_last_epoch != model_checkpoint.last_model_path

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(model_checkpoint.last_model_path)
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
assert w0.eq(w1).all()


def test_model_checkpoint_none_monitor(tmpdir):
""" Test that it is possible to save all checkpoints when monitor=None. """
model = EvalModelTemplate()
@@ -419,3 +396,42 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
)
trainer.fit(model)
assert caplog.messages.count('Saving latest checkpoint...') == save_last


def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
""" Tests that the save_last checkpoint contains the latest information. """
seed_everything(100)
model = EvalModelTemplate()
num_epochs = 3
model_checkpoint = ModelCheckpoint(
monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True
)
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=False,
checkpoint_callback=model_checkpoint,
max_epochs=num_epochs,
)
trainer.fit(model)

path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt")
path_last = str(tmpdir / f"last.ckpt")
assert path_last == model_checkpoint.last_model_path

ckpt_last_epoch = torch.load(path_last_epoch)
ckpt_last = torch.load(path_last)
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

ch_type = type(model_checkpoint)
assert all(list(
ckpt_last["callbacks"][ch_type][k] == ckpt_last_epoch["callbacks"][ch_type][k]
for k in ("best_model_score", "best_model_path")
))

# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
model_last = EvalModelTemplate.load_from_checkpoint(
model_checkpoint.last_model_path
)
for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
assert w0.eq(w1).all()
3 changes: 0 additions & 3 deletions tests/models/data/horovod/train_default_model.py
@@ -66,9 +66,6 @@ def run_test_from_config(trainer_options):
assert hvd.size() == 2

if trainer.global_rank > 0:
# on higher ranks the checkpoint location is unknown
# we want to test checkpointing on rank 0 only
assert not trainer.checkpoint_callback.best_model_path
return

# test model loading