Finish Allow on_save_checkpoint... #3688

Merged Sep 30, 2020, with 40 commits. The diff below shows the changes from 2 of the 40 commits.

Commits
bd08451
Finish #3562
williamFalcon Sep 28, 2020
211d82f
Apply suggestions from code review
Borda Sep 28, 2020
4ad52c2
Apply suggestions from code review
Borda Sep 28, 2020
742a46c
fix tests
williamFalcon Sep 28, 2020
49ce472
Finish #3562
williamFalcon Sep 28, 2020
cce9e16
Apply suggestions from code review
Borda Sep 28, 2020
d6b5e13
Apply suggestions from code review
Borda Sep 28, 2020
4a16581
fix tests
williamFalcon Sep 28, 2020
9d78a1f
fix structure
williamFalcon Sep 28, 2020
b9ab13f
fix structure
williamFalcon Sep 28, 2020
a957244
Merge branch 'r3n' of https://github.com/PyTorchLightning/pytorch-lig…
williamFalcon Sep 28, 2020
33eb5b8
Merge branch 'master' into r3n
Sep 29, 2020
6a7a7c0
make save_last test pass
awaelchli Sep 29, 2020
b8192fb
unnecessary global rank check
awaelchli Sep 29, 2020
4eea9de
fix test
awaelchli Sep 30, 2020
d792143
Merge remote-tracking branch 'PyTorchLightning/r3n' into r3n
awaelchli Sep 30, 2020
cbdcde9
update test
awaelchli Sep 30, 2020
8068b93
update test
awaelchli Sep 30, 2020
32f0754
test
awaelchli Sep 30, 2020
075f32b
test
awaelchli Sep 30, 2020
648f47a
run save on all
awaelchli Sep 30, 2020
ad643a8
remove assert
awaelchli Sep 30, 2020
b8fc751
tracking saves
awaelchli Sep 30, 2020
5cd5edd
check if fails
awaelchli Sep 30, 2020
26c8312
test
awaelchli Sep 30, 2020
84070cd
clean up
awaelchli Sep 30, 2020
0b268f1
adjust horovod test
awaelchli Sep 30, 2020
219aa6e
clean up
awaelchli Sep 30, 2020
a2f2fe5
remove unnecessary makdirs
awaelchli Sep 30, 2020
9ba223c
Merge branch 'master' into r3n
awaelchli Sep 30, 2020
48ecd41
change
awaelchli Sep 30, 2020
dfbcff3
undo
awaelchli Sep 30, 2020
d133adf
debug
awaelchli Sep 30, 2020
ad3dcfc
debug
awaelchli Sep 30, 2020
02523ab
debug
awaelchli Sep 30, 2020
dc9ab72
debug
awaelchli Sep 30, 2020
cc30c2a
mock
awaelchli Sep 30, 2020
22895a1
undo debug code
awaelchli Sep 30, 2020
9cdaaf8
add extra assertions
awaelchli Sep 30, 2020
8c3f19b
test
awaelchli Sep 30, 2020
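
The two file diffs below center on one change: ModelCheckpoint._save_model now runs on every process, while everything that touches the filesystem (creating the checkpoint directory, writing, deleting stale checkpoints) is guarded so that only global rank zero performs it. As a rough, hypothetical sketch of that pattern (the save_checkpoint function and its arguments are illustrative, not Lightning's API; only trainer.is_global_zero is taken from the diff):

```python
import os

import torch


def save_checkpoint(trainer, model, filepath: str):
    # Every rank executes this, so rank-local bookkeeping and hooks stay in sync...
    checkpoint = {"state_dict": model.state_dict()}

    # ...but only global rank zero creates directories and writes the file.
    if trainer.is_global_zero:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        torch.save(checkpoint, filepath)
```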
45 changes: 26 additions & 19 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -31,7 +31,7 @@
 import torch
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -285,17 +285,20 @@ def __init_monitor_mode(self, monitor, mode):

         self.kth_value, self.mode = mode_dict[mode]

     @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")

     def _save_model(self, filepath: str, trainer, pl_module):

         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

         # delegate the saving to the model
         if self.save_function is not None:
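
In the hunk above, _del_model appears behind a rank_zero_only decorator while the directory creation in _save_model moves under an explicit trainer.is_global_zero check. For orientation, here is a simplified sketch of how such a rank_zero_only decorator can be implemented; Lightning's real utility in pytorch_lightning.utilities differs in detail, for example in how it discovers the rank:

```python
import os
from functools import wraps


def rank_zero_only(fn):
    """Run the wrapped function only on global rank zero; other ranks get None."""

    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        # Assume the launcher exports RANK; on a single process it defaults to 0.
        if int(os.environ.get("RANK", 0)) == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn
```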
@@ -413,10 +416,8 @@ def __resolve_ckpt_dir(self, trainer, pl_module):

         self.dirpath = ckpt_path

-        assert (
-            trainer.global_rank == 0
-        ), "tried to make a checkpoint from non global_rank=0"
-        self._fs.makedirs(self.dirpath, exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(self.dirpath, exist_ok=True)

     def _add_backward_monitor_support(self, trainer):
         metrics = trainer.logger_connector.callback_metrics
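
The filesystem object self._fs used in these hunks comes from get_filesystem (imported at the top of the file) and behaves like an fsspec filesystem, so the same exists / rm / makedirs calls work for local paths and remote ones such as s3:// buckets. A minimal, local-only sketch of those calls using fsspec directly, not Lightning's helper:

```python
import fsspec

# "file" is the local filesystem; other protocols (s3, gcs, ...) expose the same API.
fs = fsspec.filesystem("file")


def remove_stale_checkpoint(filepath: str) -> None:
    # Delete a checkpoint only if it exists, mirroring _del_model above.
    if fs.exists(filepath):
        fs.rm(filepath)


fs.makedirs("/tmp/lightning_ckpts", exist_ok=True)
```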
@@ -479,7 +480,12 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi
         last_filepath = os.path.join(self.dirpath, f"{filename}.ckpt")

         self._save_model(last_filepath, trainer, pl_module)
-        if self.last_model_path and self.last_model_path != last_filepath and (self.save_top_k != -1 or self.save_last):
+        if (
+            self.last_model_path
+            and self.last_model_path != last_filepath
+            and (self.save_top_k != -1 or self.save_last)
+            and trainer.is_global_zero
+        ):
             self._del_model(self.last_model_path)
         self.last_model_path = last_filepath

@@ -502,13 +508,13 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch, trainer, pl_module)
         elif self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
             )

     def _save_all_checkpoints(self, trainer, pl_module, epoch, filepath):
         if self.verbose:
-            log.info(f"Epoch {epoch:d}: saving model to {filepath}")
+            rank_zero_info(f"Epoch {epoch:d}: saving model to {filepath}")

         assert (trainer.global_rank == 0), "tried to make a checkpoint from non global_rank=0"
         self._save_model(filepath, trainer, pl_module)
@@ -517,12 +523,12 @@ def _is_valid_monitor_key(self, metrics):
         return self.monitor in metrics or len(metrics) == 0

     def _do_check_save(
-        self,
-        filepath: str,
-        current: torch.Tensor,
-        epoch: int,
-        trainer,
-        pl_module,
+        self,
+        filepath: str,
+        current: torch.Tensor,
+        epoch: int,
+        trainer,
+        pl_module,
     ):
         # remove kth

@@ -546,16 +552,17 @@ def _do_check_save(
         self.best_model_score = self.best_k_models[self.best_model_path]

         if self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} reached"
                 f" {current:0.5f} (best {self.best_model_score:0.5f}),"
                 f" saving model to {filepath} as top {self.save_top_k}"
             )
         self._save_model(filepath, trainer, pl_module)

-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.is_global_zero:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)

     def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
         """
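
The remaining changes in this file swap log.info for rank_zero_info, so verbose checkpoint messages are emitted once per job rather than once per process under DDP. Lightning ships this helper; a sketch of the idea, assuming it is essentially a logger call wrapped in rank_zero_only (consistent with the imports in this file, though the exact implementation may differ):

```python
import logging

from pytorch_lightning.utilities import rank_zero_only

log = logging.getLogger(__name__)


@rank_zero_only
def info_once_per_job(*args, **kwargs):
    # Only global rank zero reaches the logger, so the message is printed once.
    log.info(*args, **kwargs)


info_once_per_job("Epoch 2: saving model to /tmp/ckpts/epoch=2.ckpt")
```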
91 changes: 68 additions & 23 deletions tests/callbacks/test_model_checkpoint.py
@@ -103,10 +103,10 @@ def _save_model(self, filepath, trainer, pl_module):

     def on_train_end(self, trainer, pl_module):
         super().on_train_end(trainer, pl_module)
-        # on rank 0 we expect the saved files and on all others no saves
-        assert (trainer.global_rank == 0 and self.count == self.expected_count) or (
-            trainer.global_rank > 0 and self.count == 0
-        )
+
+        # expect all ranks to run but only rank 0 will actually write
+        # the checkpoint file
+        assert self.count == self.expected_count


 @pytest.mark.skipif(
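
The new assertion reflects the behaviour change: _save_model is now invoked on every rank and the write itself is gated inside, so the per-rank call count no longer depends on the rank. The test class this method belongs to is only partially visible in the hunk; a hedged reconstruction of how such a counting subclass might look (class and attribute names are guesses, not necessarily the repository's):

```python
from pytorch_lightning.callbacks import ModelCheckpoint


class CountingModelCheckpoint(ModelCheckpoint):
    """Counts how often _save_model is invoked on this rank."""

    def __init__(self, expected_count, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.expected_count = expected_count
        self.count = 0

    def _save_model(self, filepath, trainer, pl_module):
        super()._save_model(filepath, trainer, pl_module)
        self.count += 1

    def on_train_end(self, trainer, pl_module):
        super().on_train_end(trainer, pl_module)
        # expect all ranks to run, but only rank 0 actually writes the file
        assert self.count == self.expected_count
```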
@@ -220,32 +220,28 @@ def test_none_monitor_save_last(tmpdir):
     ModelCheckpoint(filepath=tmpdir, save_last=False)


-def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
-    """Tests that the save_last checkpoint contains the latest information."""
-    seed_everything(100)
+def test_model_checkpoint_save_last(tmpdir):
+    """Tests that save_last produces only one last checkpoint."""
     model = EvalModelTemplate()
-    num_epochs = 3
-    model_checkpoint = ModelCheckpoint(monitor='val_loss', filepath=tmpdir, save_top_k=num_epochs, save_last=True)
+    epochs = 3
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
+    model_checkpoint = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, save_last=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         early_stop_callback=False,
         checkpoint_callback=model_checkpoint,
-        max_epochs=num_epochs,
+        max_epochs=epochs,
     )
     trainer.fit(model)
-
-    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
-    assert path_last_epoch != model_checkpoint.last_model_path
-
-    ckpt_last_epoch = torch.load(path_last_epoch)
-    ckpt_last = torch.load(model_checkpoint.last_model_path)
-    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
-
-    # it is easier to load the model objects than to iterate over the raw dict of tensors
-    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
-    model_last = EvalModelTemplate.load_from_checkpoint(model_checkpoint.last_model_path)
-    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
-        assert w0.eq(w1).all()
+    last_filename = model_checkpoint._format_checkpoint_name(
+        ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}
+    )
+    last_filename = last_filename + ".ckpt"
+    assert str(tmpdir / last_filename) == model_checkpoint.last_model_path
+    assert set(os.listdir(tmpdir)) == set(
+        [f"epoch={i}.ckpt" for i in range(epochs)] + [last_filename, "lightning_logs"]
+    )
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"


 def test_model_checkpoint_none_monitor(tmpdir):

Review comment from awaelchli (Contributor, Sep 30, 2020) on the removed test_model_checkpoint_save_last_checkpoint_contents: "this test is 1:1 the same as a few lines above, I remove it."
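
The rewritten test leans on ModelCheckpoint's name templating: CHECKPOINT_NAME_LAST = "last-{epoch}" is expanded by _format_checkpoint_name into a concrete file name for the "last" checkpoint. A small usage sketch, assuming the same 0.10-era constructor signature (filepath=...) and call shape as in the test:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
ckpt = ModelCheckpoint(filepath="/tmp/ckpts", save_top_k=-1, save_last=True)

# For the final of three epochs (index 2), the "last" file name would come out
# along the lines of "last-epoch=2.ckpt".
name = ckpt._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, 2, {})
print(name + ".ckpt")

# Restore the class-level default afterwards, exactly as the test does.
ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"
```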
@@ -398,3 +394,52 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     )
     trainer.fit(model)
     assert caplog.messages.count('Saving latest checkpoint...') == save_last
+
+
+def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
+    """Tests that the save_last checkpoint contains the latest information."""
+    seed_everything(100)
+    model = EvalModelTemplate()
+    num_epochs = 3
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = "last-{epoch}"
+    model_checkpoint = ModelCheckpoint(
+        filepath=tmpdir, save_top_k=num_epochs, save_last=True
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        early_stop_callback=False,
+        checkpoint_callback=model_checkpoint,
+        max_epochs=num_epochs,
+    )
+    trainer.fit(model)
+    last_filename = model_checkpoint._format_checkpoint_name(
+        ModelCheckpoint.CHECKPOINT_NAME_LAST, num_epochs - 1, {}
+    )
+    path_last_epoch = model_checkpoint.format_checkpoint_name(
+        num_epochs - 1, {}
+    )  # epoch=3.ckpt
+    path_last = str(tmpdir / f"{last_filename}.ckpt")  # last-epoch=3.ckpt
+    assert path_last_epoch != path_last
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(path_last)
+
+    path_last_epoch = model_checkpoint.format_checkpoint_name(num_epochs - 1, {})
+    assert path_last_epoch != model_checkpoint.last_model_path
+
+    ckpt_last_epoch = torch.load(path_last_epoch)
+    ckpt_last = torch.load(model_checkpoint.last_model_path)
+    assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
+    ch_type = type(model_checkpoint)
+    assert all(
+        ckpt_last["callbacks"][ch_type][k] == ckpt_last_epoch["callbacks"][ch_type][k]
+        for k in ("best_model_score", "best_model_path")
+    )
+
+    # it is easier to load the model objects than to iterate over the raw dict of tensors
+    model_last_epoch = EvalModelTemplate.load_from_checkpoint(path_last_epoch)
+    model_last = EvalModelTemplate.load_from_checkpoint(
+        model_checkpoint.last_model_path
+    )
+    for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
+        assert w0.eq(w1).all()
+    ModelCheckpoint.CHECKPOINT_NAME_LAST = "last"