From 89131c2ac262a4ed21be35b8654e53a205f3ef68 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:06:51 +0200
Subject: [PATCH 01/31] class name as key

---
 pytorch_lightning/trainer/callback_hook.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 606f6b2e4b52b..62bfa00ddcf1a 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -243,7 +243,7 @@ def __is_old_signature(fn: Callable) -> bool:
             return True
         return False

-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
@@ -257,10 +257,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.__class__.__name__] = state
         return callback_states

-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
@@ -268,7 +268,7 @@ def on_load_checkpoint(self, checkpoint):
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(type(callback))
+                state = callback_states.get(callback.__class__.__name__)
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

From 63fb9830fbc7b4f81807035c03d1d8343df07bfd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:37:22 +0200
Subject: [PATCH 02/31] string state identifier

---
 pytorch_lightning/callbacks/base.py | 4 ++++
 pytorch_lightning/trainer/callback_hook.py | 5 +++--
 .../trainer/connectors/checkpoint_connector.py | 2 +-
 tests/callbacks/test_early_stopping.py | 2 +-
 tests/trainer/connectors/test_callback_connector.py | 6 +++---
 tests/trainer/logging_/test_logger_connector.py | 2 +-
 6 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 768e4ebca30ee..a214904b1c31e 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -29,6 +29,10 @@ class Callback(abc.ABC):
     Subclass this class and override any of the relevant hooks
     """

+    @property
+    def state_identifier(self) -> str:
+        return self.__class__.__name__
+
     def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
         """Called before configure sharded model"""

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 62bfa00ddcf1a..dfc4922e14277 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -257,18 +257,19 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback.__class__.__name__] = state
+                callback_states[callback.state_identifier] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
+        version = checkpoint.get('pytorch-lightning_version')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.__class__.__name__)
+                state = callback_states.get(callback.state_identifier)
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 4ae42e4bad6ac..887bd4064dc6c 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -244,7 +244,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         structured dictionary: {
             'epoch': training epoch
             'global_step': training global step
-            'pytorch-lightning_version': PyTorch Lightning's version
+            'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
             'callbacks': "callback specific state"[] # if not weights_only
             'optimizer_states': "PT optim's state_dict"[] # if not weights_only
             'lr_schedulers': "PT sched's state_dict"[] # if not weights_only

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index cc619077ee136..7768e112d27cc 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -76,7 +76,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
     assert 4 == len(early_stop_callback.saved_states)
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 34149e2231bf5..aa9faa59a3188 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -68,11 +68,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     trainer.fit(model)

     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
     assert "content0" in state0 and state0["content0"] == 0
     assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]


 def test_attach_model_callbacks():

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 923821a5e50e4..310f43f94a177 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -268,7 +268,7 @@ def test_dataloader(self):


 def test_call_back_validator(tmpdir):

-    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_')])
+    funcs_name = sorted([f for f in dir(Callback) if not f.startswith('_') and callable(getattr(Callback, f))])

     callbacks_func = [
         'on_after_backward',

From 7dc218a7b491879b2d745be40ba6a1f18da7cd6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 03:51:04 +0200
Subject: [PATCH 03/31] add legacy state loading

---
 pytorch_lightning/callbacks/base.py | 6 +++++-
 pytorch_lightning/trainer/callback_hook.py | 6 +++++-
 tests/checkpointing/test_legacy_checkpoints.py | 1 +
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a214904b1c31e..bf82899373230 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
 """

 import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type

 from pytorch_lightning.core.lightning import LightningModule
@@ -33,6 +33,10 @@ class Callback(abc.ABC):
     def state_identifier(self) -> str:
         return self.__class__.__name__

+    @property
+    def _legacy_state_identifier(self) -> Type:
+        return type(self)
+
     def on_configure_sharded_model(self, trainer, pl_module: LightningModule) -> None:
         """Called before configure sharded model"""

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index dfc4922e14277..0f7d6a6fe5066 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,6 +14,7 @@
 from abc import ABC
 from copy import deepcopy
+from distutils.version import LooseVersion
 from inspect import signature
 from typing import Any, Callable, Dict, List, Optional
@@ -269,7 +270,10 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.state_identifier)
+                state = (
+                    callback_states.get(callback.state_identifier)
+                    or callback_states.get(callback._legacy_state_identifier)
+                )
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 7d1284ee0d329..4080eb1deb788 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -60,6 +60,7 @@
         "1.2.5",
         "1.2.6",
         "1.2.7",
+        "1.2.8",
     ]
 )
 def test_resume_legacy_checkpoints(tmpdir, pl_version: str):

From 04b588b7ebf11894ebefa2b7bd8cdfc9936b6464 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:03:11 +0200
Subject: [PATCH 04/31] update test

---
 tests/checkpointing/test_model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f58ff768759e8..5143e4e1eb33f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -840,7 +840,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     ckpt_last = torch.load(path_last)
     assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))

-    ch_type = type(model_checkpoint)
+    ch_type = "ModelCheckpoint"
     assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
@@ -1098,7 +1098,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]

From bb11e2872356ce8bbcf1f333cdec59cbe554d501 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:07:56 +0200
Subject: [PATCH 05/31] update tests

---
 tests/checkpointing/test_model_checkpoint.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 5143e4e1eb33f..d2a57a6c4d2b8 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -147,7 +147,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score
@@ -251,7 +251,7 @@ def configure_optimizers(self):
         assert chk['epoch'] == epoch + 1
         assert chk['global_step'] == per_epoch_steps * (global_ix + 1)

-        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
         assert mc_specific_data['dirpath'] == checkpoint.dirpath
         assert mc_specific_data['monitor'] == monitor
         assert mc_specific_data['current_score'] == score

From 0259ecbdf72fe3fdbecd1524bcae62e9e328b871 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 21 Apr 2021 11:20:23 +0200
Subject: [PATCH 06/31] flake8

---
 pytorch_lightning/trainer/callback_hook.py | 4 +---
 tests/checkpointing/test_model_checkpoint.py | 8 ++++----
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 0f7d6a6fe5066..fb5c82a2330db 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -14,9 +14,8 @@
 from abc import ABC
 from copy import deepcopy
-from distutils.version import LooseVersion
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional

 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -264,7 +263,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
         callback_states = checkpoint.get('callbacks')
-        version = checkpoint.get('pytorch-lightning_version')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 17c1f0ea589a5..13236d32ea522 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -264,10 +264,10 @@ def _make_assertions(epoch, ix, add=''):
         expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
         assert chk['global_step'] == expected_global_step

-        mc_specific_data = chk['callbacks']["ModelCheckpoint"] 
-        assert mc_specific_data['dirpath'] == checkpoint.dirpath
-        assert mc_specific_data['monitor'] == monitor
-        assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score

         if not reduce_lr_on_plateau:
             lr_scheduler_specific_data = chk['lr_schedulers'][0]

From d56e5e47d7cf74199650eebe7d409acea7c29285 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 21 Apr 2021 12:33:07 +0200
Subject: [PATCH 07/31] add test

---
 .../checkpointing/test_legacy_checkpoints.py | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 4080eb1deb788..d4d053f6136ac 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,11 +14,18 @@
 import glob
 import os
 import sys
+from copy import deepcopy
+from pathlib import Path

 import pytest
+import torch
+from pytorch_lightning.callbacks.base import Callback
+
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

 from pytorch_lightning import Trainer
 from tests import PATH_LEGACY
+from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(PATH_LEGACY, 'checkpoints')
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -87,3 +94,34 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
         # assert result

     sys.path = orig_sys_paths
+
+
+class StatefulCallback(Callback):
+
+    def on_save_checkpoint(self, *args):
+        return {"content": 123}
+
+
+def test_callback_state_loading_by_type(tmpdir):
+    """ Test that legacy checkpoints that don't use a state identifier can still be loaded. """
+    model = BoringModel()
+    callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[callback],
+    )
+    trainer.fit(model)
+    # simulate old format where type(callback) was the key
+    new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
+    old_checkpiont = deepcopy(new_checkpoint)
+    old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
+    torch.save(old_checkpiont, Path(tmpdir, "old.ckpt"))
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_steps=2,
+        callbacks=[callback],
+        resume_from_checkpoint=Path(tmpdir, "old.ckpt"),
+    )
+    trainer.fit(model)

From 72ba44026c411476bae3544ebcd67d2961c6ffa3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 21:39:11 +0200
Subject: [PATCH 08/31] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/trainer/callback_hook.py | 7 ++-----
 tests/checkpointing/test_legacy_checkpoints.py | 6 +++---
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 839d78d276efe..4f35e30f7ee5d 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -293,16 +293,13 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-        callback_states = checkpoint.get('callbacks')
+        callback_states: Dict[str, dict] = checkpoint.get('callbacks')
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = (
-                    callback_states.get(callback.state_identifier)
-                    or callback_states.get(callback._legacy_state_identifier)
-                )
+                state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier))
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index d4d053f6136ac..c44dc526a1314 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -114,9 +114,9 @@ def test_callback_state_loading_by_type(tmpdir):
     trainer.fit(model)
     # simulate old format where type(callback) was the key
     new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
-    old_checkpiont = deepcopy(new_checkpoint)
-    old_checkpiont["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
-    torch.save(old_checkpiont, Path(tmpdir, "old.ckpt"))
+    old_checkpoint = deepcopy(new_checkpoint)
+    old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
+    torch.save(old_checkpoint, Path(tmpdir, "old.ckpt"))

     trainer = Trainer(
         default_root_dir=tmpdir,

From 79d85684d80443d449e937aba433e5425a736020 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 22:21:37 +0200
Subject: [PATCH 09/31] improve test

---
 .../checkpointing/test_legacy_checkpoints.py | 35 +++++++++++--------
 1 file changed, 20 insertions(+), 15 deletions(-)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index c44dc526a1314..b3e639c0ee3d5 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,15 +14,11 @@
 import glob
 import os
 import sys
-from copy import deepcopy
 from pathlib import Path

 import pytest
-import torch
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-
 from pytorch_lightning import Trainer
 from tests import PATH_LEGACY
 from tests.helpers import BoringModel
@@ -96,32 +92,41 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
     sys.path = orig_sys_paths


-class StatefulCallback(Callback):
+class OldStatefulCallback(Callback):
+
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_identifier(self):
+        return type(self)

     def on_save_checkpoint(self, *args):
-        return {"content": 123}
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, callback_state):
+        self.state = callback_state["state"]


-def test_callback_state_loading_by_type(tmpdir):
-    """ Test that legacy checkpoints that don't use a state identifier can still be loaded. """
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. """
     model = BoringModel()
-    callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
+    callback = OldStatefulCallback(state=111)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=1,
         callbacks=[callback],
     )
     trainer.fit(model)
-    # simulate old format where type(callback) was the key
-    new_checkpoint = torch.load(Path(tmpdir, "last.ckpt"))
-    old_checkpoint = deepcopy(new_checkpoint)
-    old_checkpoint["callbacks"] = {type(callback): new_checkpoint["callbacks"]["ModelCheckpoint"]}
-    torch.save(old_checkpoint, Path(tmpdir, "old.ckpt"))
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()

+    callback = OldStatefulCallback(state=222)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=2,
         callbacks=[callback],
-        resume_from_checkpoint=Path(tmpdir, "old.ckpt"),
+        resume_from_checkpoint=ckpt_path,
     )
     trainer.fit(model)
+    assert callback.state == 111

From d9ea8c165fc7983b6611e5c769a3fdbe70d59575 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Apr 2021 22:29:34 +0200
Subject: [PATCH 10/31] flake

---
 pytorch_lightning/trainer/callback_hook.py | 4 +++-
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 4f35e30f7ee5d..82bfd644ca3bc 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -299,7 +299,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # https://github.com/pytorch/xla/issues/2773
         if callback_states is not None:
             for callback in self.callbacks:
-                state = callback_states.get(callback.state_identifier, callback_states.get(callback._legacy_state_identifier))
+                state = callback_states.get(
+                    callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+                )
                 if state:
                     state = deepcopy(state)
                     callback.on_load_checkpoint(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index b3e639c0ee3d5..40d755d76f2c6 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -17,9 +17,9 @@
 from pathlib import Path

 import pytest
-from pytorch_lightning.callbacks.base import Callback

 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.base import Callback
 from tests import PATH_LEGACY
 from tests.helpers import BoringModel

From 0851f0d439d30e1cda22310ba76eeb5c22d92e7f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:16:11 +0200
Subject: [PATCH 11/31] fix merge

---
 pytorch_lightning/trainer/callback_hook.py | 17 +++++------------
 tests/checkpointing/test_legacy_checkpoints.py | 5 +++--
 2 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 2ecdb5159d146..343019171f406 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import torch
@@ -282,19 +282,10 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
-        if callback_states is not None:
-            for callback in self.callbacks:
-                state = callback_states.get(
-                    callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
-                )
-                if state:
-                    state = deepcopy(state)
-                    callback.on_load_checkpoint(state)
-
         if callback_states is None:
             return
@@ -309,7 +300,9 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
             )

         for callback in self.callbacks:
-            state = callback_states.get(type(callback))
+            state = callback_states.get(
+                callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+            )
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 8c5adf19357dc..4bbccb5c85006 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -18,8 +18,9 @@

 import pytest

-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from tests import _PATH_LEGACY
+from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, 'checkpoints')
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -110,7 +111,7 @@ def state_identifier(self):
     def on_save_checkpoint(self, *args):
         return {"state": self.state}

-    def on_load_checkpoint(self, callback_state):
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
         self.state = callback_state["state"]

From 82d5658a80d63974d9d4fe5dc1227cd0d8051c4c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 09:17:52 +0000
Subject: [PATCH 12/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 docs/source/governance.rst | 2 +-
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/governance.rst b/docs/source/governance.rst
index 5c29f7d0da544..4114ccdb8a818 100644
--- a/docs/source/governance.rst
+++ b/docs/source/governance.rst
@@ -39,7 +39,7 @@ Board
 Alumni
 ------
-- Jeff Yang (`ydcjeff `_)
+- Jeff Yang (`ydcjeff `_)
 - Jeff Ling (`jeffling `_)
 - Teddy Koker (`teddykoker `_)
 - Nate Raw (`nateraw `_)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 4bbccb5c85006..2dc499c1882d6 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -18,7 +18,7 @@

 import pytest

-from pytorch_lightning import Trainer, Callback
+from pytorch_lightning import Callback, Trainer
 from tests import _PATH_LEGACY
 from tests.helpers import BoringModel

From 334fd4a9d191db99d4cb10048fc9be82c64f27cb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:21:18 +0200
Subject: [PATCH 13/31] use qualname

---
 pytorch_lightning/callbacks/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 05685f7e9a688..a0af628c0aefc 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -35,7 +35,7 @@ class Callback(abc.ABC):

     @property
     def state_identifier(self) -> str:
-        return self.__class__.__name__
+        return self.__class__.__qualname__

     @property
     def _legacy_state_identifier(self) -> Type:

From f144fd188af921b245eea9177c43d559aec9a818 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:44:57 +0200
Subject: [PATCH 14/31] rename state_id

---
 pytorch_lightning/callbacks/base.py | 4 ++--
 pytorch_lightning/trainer/callback_hook.py | 4 ++--
 tests/checkpointing/test_legacy_checkpoints.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a0af628c0aefc..f05ddd9217844 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -34,11 +34,11 @@ class Callback(abc.ABC):
     """

     @property
-    def state_identifier(self) -> str:
+    def state_id(self) -> str:
         return self.__class__.__qualname__

     @property
-    def _legacy_state_identifier(self) -> Type:
+    def _legacy_state_id(self) -> Type:
         return type(self)

     def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 343019171f406..5083bcae51263 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -277,7 +277,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[callback.state_identifier] = state
+                callback_states[callback.state_id] = state
         return callback_states

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
@@ -301,7 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

         for callback in self.callbacks:
             state = callback_states.get(
-                callback.state_identifier, callback_states.get(callback._legacy_state_identifier)
+                callback.state_id, callback_states.get(callback._legacy_state_id)
             )
             if state:
                 state = deepcopy(state)

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 2dc499c1882d6..cb43a65ac5610 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -105,7 +105,7 @@ def __init__(self, state):
         self.state = state

     @property
-    def state_identifier(self):
+    def state_id(self):
         return type(self)

     def on_save_checkpoint(self, *args):

From 615498670d3f9ae886c5f617af39963b4fd22df6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 11:45:33 +0200
Subject: [PATCH 15/31] fix diff

---
 pytorch_lightning/trainer/callback_hook.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 5083bcae51263..3075a3cfa8c77 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -282,10 +282,11 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+
         if callback_states is None:
             return

From 0ec9bd21e57b232f84fc9c35b6210e052da34b8e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 12:48:03 +0200
Subject: [PATCH 16/31] update fx validator

---
 .../trainer/connectors/logger_connector/fx_validator.py | 1 +
 tests/trainer/logging_/test_logger_connector.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 3604574fd1e81..7ad74001ea686 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -65,6 +65,7 @@ class FxValidator:
         on_save_checkpoint=None,
         on_load_checkpoint=None,
         setup=None,
+        state_id=None,
         teardown=None,
         configure_sharded_model=None,
         training_step=dict(on_step=(False, True), on_epoch=(False, True)),

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 27598b40fbd31..64543c4357de0 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -78,6 +78,7 @@ def test_fx_validator(tmpdir):
         "on_predict_epoch_start",
         "on_predict_start",
         'setup',
+        "state_id",
         'teardown',
     ]
@@ -105,6 +106,7 @@ def test_fx_validator(tmpdir):
         "on_train_end",
         "on_validation_end",
         "setup",
+        "state_id",
         "teardown",
     ]

From 049f14d1d8deedea98b709234a3c19bf3e2f6b75 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 10:50:19 +0000
Subject: [PATCH 17/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 pytorch_lightning/trainer/callback_hook.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 3075a3cfa8c77..4fae7edc2aa97 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -286,7 +286,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
         callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
-        
+
         if callback_states is None:
             return
@@ -301,9 +301,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:

         for callback in self.callbacks:
-            state = callback_states.get(
-                callback.state_id, callback_states.get(callback._legacy_state_id)
-            )
+            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):

From 3eca3c5acfff1dbda294d9ee988a19aeadf3badb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 13:55:58 +0200
Subject: [PATCH 18/31] black

---
 pytorch_lightning/callbacks/base.py | 2 +-
 .../metrics/functional/precision_recall_curve.py | 3 ++-
 tests/checkpointing/test_legacy_checkpoints.py | 1 -
 tests/checkpointing/test_model_checkpoint.py | 16 ++++++++--------
 tests/trainer/logging_/test_logger_connector.py | 4 ++--
 5 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index de568a8c60522..851d2b953c8c6 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -41,7 +41,7 @@ def state_id(self) -> str:
     def _legacy_state_id(self) -> Type:
         return type(self)

-    def on_configure_sharded_model(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
+    def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before configure sharded model"""

     def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py
index 93914c146e82f..93b203fae129b 100644
--- a/pytorch_lightning/metrics/functional/precision_recall_curve.py
+++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py
@@ -27,7 +27,8 @@ def precision_recall_curve(
     pos_label: Optional[int] = None,
     sample_weights: Optional[Sequence] = None,
 ) -> Union[
-    Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]],
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]],
 ]:
     """
     .. deprecated::

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index fa9397671320f..34df42a28e81f 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -100,7 +100,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):


 class OldStatefulCallback(Callback):
-
     def __init__(self, state):
         self.state = state

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index a84ca6bb1c852..270ecdbf51b9b 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -148,10 +148,10 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
-        assert mc_specific_data['dirpath'] == checkpoint.dirpath
-        assert mc_specific_data['monitor'] == monitor
-        assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        assert mc_specific_data["dirpath"] == checkpoint.dirpath
+        assert mc_specific_data["monitor"] == monitor
+        assert mc_specific_data["current_score"] == score

         if not reduce_lr_on_plateau:
             actual_step_count = chk["lr_schedulers"][0]["_step_count"]
@@ -259,10 +259,10 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step

-        mc_specific_data = chk['callbacks']["ModelCheckpoint"]
-        assert mc_specific_data['dirpath'] == checkpoint.dirpath
-        assert mc_specific_data['monitor'] == monitor
-        assert mc_specific_data['current_score'] == score
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        assert mc_specific_data["dirpath"] == checkpoint.dirpath
+        assert mc_specific_data["monitor"] == monitor
+        assert mc_specific_data["current_score"] == score

         if not reduce_lr_on_plateau:
             actual_step_count = chk["lr_schedulers"][0]["_step_count"]

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 94c9d0ec8df30..e1102224179a4 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -77,9 +77,9 @@ def test_fx_validator(tmpdir):
         "on_predict_epoch_end",
         "on_predict_epoch_start",
         "on_predict_start",
-        'setup',
+        "setup",
         "state_id",
-        'teardown',
+        "teardown",
     ]

     not_supported = [

From ff190fa61759e540bba42ced3a5c25a226c19776 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 11:56:58 +0000
Subject: [PATCH 19/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../metrics/functional/precision_recall_curve.py | 3 +--
 tests/checkpointing/test_legacy_checkpoints.py | 15 +++------------
 2 files changed, 4 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py
index 93b203fae129b..93914c146e82f 100644
--- a/pytorch_lightning/metrics/functional/precision_recall_curve.py
+++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py
@@ -27,8 +27,7 @@ def precision_recall_curve(
     pos_label: Optional[int] = None,
     sample_weights: Optional[Sequence] = None,
 ) -> Union[
-    Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-    Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]],
+    Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]],
 ]:
     """
     .. deprecated::

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 34df42a28e81f..242dba8830475 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -115,24 +115,15 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state):


 def test_resume_callback_state_saved_by_type(tmpdir):
-    """ Test that a legacy checkpoint that didn't use a state identifier before can still be loaded. """
+    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
     model = BoringModel()
     callback = OldStatefulCallback(state=111)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=1,
-        callbacks=[callback],
-    )
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
     trainer.fit(model)
     ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
     assert ckpt_path.exists()

     callback = OldStatefulCallback(state=222)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=2,
-        callbacks=[callback],
-        resume_from_checkpoint=ckpt_path,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
     trainer.fit(model)
     assert callback.state == 111

From a1b5b23c79a607e7170095a3eb851417f34c4ff7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 26 Jul 2021 14:07:54 +0200
Subject: [PATCH 20/31] update test to ignore properties

---
 .../trainer/connectors/logger_connector/fx_validator.py | 1 -
 tests/trainer/logging_/test_logger_connector.py | 5 ++---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
index 8b449f0323429..f2ad8f1130993 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py
@@ -65,7 +65,6 @@ class FxValidator:
         on_save_checkpoint=None,
         on_load_checkpoint=None,
         setup=None,
-        state_id=None,
         teardown=None,
         configure_sharded_model=None,
         training_step=dict(on_step=(False, True), on_epoch=(False, True)),

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index e1102224179a4..b4628515db424 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from inspect import isfunction, getmembers
 from unittest import mock

 import pytest
@@ -29,7 +30,7 @@


 def test_fx_validator(tmpdir):
-    funcs_name = sorted([f for f in dir(Callback) if not f.startswith("_")])
+    funcs_name = sorted([f for f, _ in getmembers(Callback, predicate=isfunction) if not f.startswith("_")])

     callbacks_func = [
         "on_before_backward",
@@ -78,7 +79,6 @@ def test_fx_validator(tmpdir):
         "on_predict_epoch_start",
         "on_predict_start",
         "setup",
-        "state_id",
         "teardown",
     ]
@@ -106,7 +106,6 @@ def test_fx_validator(tmpdir):
         "on_train_end",
         "on_validation_end",
         "setup",
-        "state_id",
         "teardown",
     ]

From bffbd53d8b74ef5fbde3fdf40874f0a8aacf8ef5 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 26 Jul 2021 12:08:49 +0000
Subject: [PATCH 21/31] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/trainer/logging_/test_logger_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index b4628515db424..2c324be50d7f0 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from inspect import isfunction, getmembers
+from inspect import getmembers, isfunction
 from unittest import mock

 import pytest

From b10a45a76d8d0d3479d7c7d0ac4b1bf4a148ab1d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 27 Jul 2021 23:47:27 +0200
Subject: [PATCH 22/31] update changelog

---
 CHANGELOG.md | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7872af715d68a..131249c91259b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,10 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
-
-
--
+- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

 -
@@ -28,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Replace `iteration_count` and other index attributes in the loops with progress dataclasses ([#8477](https://github.com/PyTorchLightning/pytorch-lightning/pull/8477))

--
+- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))

 -

From d40b2cc5fd32cce7f4b44d973381392ea9fd7855 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 27 Jul 2021 23:52:23 +0200
Subject: [PATCH 23/31] update test_fx_validator test

---
 tests/trainer/logging_/test_logger_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 7c18b1c932117..181e7a8b6f592 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -30,7 +30,7 @@


 def test_fx_validator(tmpdir):
-    funcs_name = sorted(f for f in dir(Callback) if not f.startswith("_"))
+    funcs_name = sorted(m for m, _ in getmembers(Callback, predicate=isfunction) if not m.startswith("_"))

     callbacks_func = [
         "on_before_backward",

From a52ad31c1d39467d2ad1285f144aa4781a205e2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 09:33:48 +0200
Subject: [PATCH 24/31] add docs for state id

---
 pytorch_lightning/callbacks/base.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 851d2b953c8c6..2b6166a8e2480 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -35,10 +35,15 @@ class Callback(abc.ABC):

     @property
     def state_id(self) -> str:
+        """
+        Identifier for the state of the callback. Used to store and retrieve a callback's state from the
+        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``.
+        """
         return self.__class__.__qualname__

     @property
     def _legacy_state_id(self) -> Type:
+        """State identifier for checkpoints saved prior to version 1.5.0."""
         return type(self)

     def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

From a3ec5710780dd1230dbcd59680d633c5df27bbbb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 09:48:44 +0200
Subject: [PATCH 25/31] update docs for state id

---
 pytorch_lightning/callbacks/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 2b6166a8e2480..2907f18a4f8b1 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -37,7 +37,9 @@ class Callback(abc.ABC):
     def state_id(self) -> str:
         """
         Identifier for the state of the callback. Used to store and retrieve a callback's state from the
-        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``.
+        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
+        provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+        multiple instances of that callback.
         """
         return self.__class__.__qualname__

From 140c71b490e0ef2a9aa0bb801433a09cd702858c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 14:34:02 +0200
Subject: [PATCH 26/31] Update pytorch_lightning/callbacks/base.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/callbacks/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index 2907f18a4f8b1..17ab2fba10c07 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -44,7 +44,7 @@ def state_id(self) -> str:
         return self.__class__.__qualname__

     @property
-    def _legacy_state_id(self) -> Type:
+    def _legacy_state_id(self) -> Type["Callback"]:
         """State identifier for checkpoints saved prior to version 1.5.0."""
         return type(self)

From d1b59db03104717a98ef9b3252273dd2e6e5f5a0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 14:37:20 +0200
Subject: [PATCH 27/31] Update tests/trainer/logging_/test_logger_connector.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 tests/trainer/logging_/test_logger_connector.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 181e7a8b6f592..709b1ae7d091a 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -30,7 +30,9 @@


 def test_fx_validator(tmpdir):
-    funcs_name = sorted(m for m, _ in getmembers(Callback, predicate=isfunction) if not m.startswith("_"))
+from tests.models.test_hooks import get_members
+
+    funcs_name = sorted(get_members(Callback))

     callbacks_func = [
         "on_before_backward",

From 8dcb54e970562849c621dcde2337f562743d453b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 14:44:35 +0200
Subject: [PATCH 28/31] Update tests/checkpointing/test_model_checkpoint.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 tests/checkpointing/test_model_checkpoint.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index ac1e1267b1333..72acbfbfb8a8f 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -858,8 +858,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]

-    ch_type = "ModelCheckpoint"
-    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
+    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)

From eea2dce7c10ce74b9e36492db6cf0d6a785d4599 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 14:46:08 +0200
Subject: [PATCH 29/31] remove an empty line

---
 tests/checkpointing/test_model_checkpoint.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 72acbfbfb8a8f..bad7d1592f328 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -857,7 +857,6 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-
     assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]

     # it is easier to load the model objects than to iterate over the raw dict of tensors

From e94f6df7d2a310034990db9afc74bfcf864f462a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 23:24:26 +0200
Subject: [PATCH 30/31] fix import error

---
 tests/trainer/logging_/test_logger_connector.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 408d59099b1bb..518a401a72037 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from inspect import getmembers, isfunction
 from unittest import mock

 import pytest
@@ -27,11 +26,10 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.models.test_hooks import get_members


 def test_fx_validator(tmpdir):
-from tests.models.test_hooks import get_members
-
     funcs_name = sorted(get_members(Callback))

     callbacks_func = [

From 302e72458b15007496226d72d137be10c8b2ab61 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 28 Jul 2021 23:26:52 +0200
Subject: [PATCH 31/31] move test

---
 tests/callbacks/test_callbacks.py | 33 +++++++++++++++++-
 .../checkpointing/test_legacy_checkpoints.py | 34 +------------------
 2 files changed, 33 insertions(+), 34 deletions(-)

diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index dedc74f021f81..d190feed7e1f7 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 from unittest.mock import call, Mock

-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from tests.helpers import BoringModel
@@ -101,3 +102,33 @@ def configure_callbacks(self):
         trainer_fn(model)
         callbacks_after = trainer.callbacks.copy()
         assert callbacks_after == callbacks_after_fit
+
+
+class OldStatefulCallback(Callback):
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_id(self):
+        return type(self)
+
+    def on_save_checkpoint(self, *args):
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
+        self.state = callback_state["state"]
+
+
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()
+
+    callback = OldStatefulCallback(state=222)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
+    trainer.fit(model)
+    assert callback.state == 111

diff --git a/tests/checkpointing/test_legacy_checkpoints.py b/tests/checkpointing/test_legacy_checkpoints.py
index 242dba8830475..8693965a52abc 100644
--- a/tests/checkpointing/test_legacy_checkpoints.py
+++ b/tests/checkpointing/test_legacy_checkpoints.py
@@ -14,13 +14,11 @@
 import glob
 import os
 import sys
-from pathlib import Path

 import pytest

-from pytorch_lightning import Callback, Trainer
+from pytorch_lightning import Trainer
 from tests import _PATH_LEGACY
-from tests.helpers import BoringModel

 LEGACY_CHECKPOINTS_PATH = os.path.join(_PATH_LEGACY, "checkpoints")
 CHECKPOINT_EXTENSION = ".ckpt"
@@ -97,33 +95,3 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
         # trainer.fit(model)

     sys.path = orig_sys_paths
-
-
-class OldStatefulCallback(Callback):
-    def __init__(self, state):
-        self.state = state
-
-    @property
-    def state_id(self):
-        return type(self)
-
-    def on_save_checkpoint(self, *args):
-        return {"state": self.state}
-
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state = callback_state["state"]
-
-
-def test_resume_callback_state_saved_by_type(tmpdir):
-    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
-    model = BoringModel()
-    callback = OldStatefulCallback(state=111)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
-    trainer.fit(model)
-    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
-    assert ckpt_path.exists()
-
-    callback = OldStatefulCallback(state=222)
-    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
-    trainer.fit(model)
-    assert callback.state == 111
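
Taken together, the series changes how `checkpoint["callbacks"]` is keyed: it used to be keyed by `type(callback)`, which required unpickling callback classes when loading a checkpoint, and it is now keyed by the new `Callback.state_id` string (the class `__qualname__`), with a `_legacy_state_id` fallback of `type(self)` so checkpoints written before the change still resume. Below is a condensed, runnable sketch of that scheme. The `Counter` callback and the two helper functions are illustrative stand-ins with simplified hook signatures, not Lightning APIs; only the `state_id` / `_legacy_state_id` keying logic mirrors the patched code.

from copy import deepcopy
from typing import Any, Dict


class Callback:
    @property
    def state_id(self) -> str:
        # New checkpoints key callback state by the class qualname string.
        return self.__class__.__qualname__

    @property
    def _legacy_state_id(self) -> type:
        # Checkpoints written before the change used the class object itself.
        return type(self)

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {}

    def on_load_checkpoint(self, state: Dict[str, Any]) -> None:
        pass


class Counter(Callback):
    def __init__(self) -> None:
        self.count = 0

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {"count": self.count}

    def on_load_checkpoint(self, state: Dict[str, Any]) -> None:
        self.count = state["count"]


def save_callbacks(callbacks) -> Dict[str, dict]:
    # Mirrors the trainer's on_save_checkpoint after the series: the key is a
    # plain string, so loading never needs to unpickle a callback class.
    callback_states = {}
    for callback in callbacks:
        state = callback.on_save_checkpoint()
        if state:
            callback_states[callback.state_id] = state
    return callback_states


def load_callbacks(callbacks, callback_states) -> None:
    # Mirrors the trainer's on_load_checkpoint: try the new string key first,
    # then fall back to the legacy type(callback) key for old checkpoints.
    for callback in callbacks:
        state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
        if state:
            callback.on_load_checkpoint(deepcopy(state))


cb = Counter()
cb.count = 7
assert save_callbacks([cb]) == {"Counter": {"count": 7}}

legacy = Counter()
load_callbacks([legacy], {type(legacy): {"count": 3}})  # pre-1.5.0-style checkpoint
assert legacy.count == 3

Note that under this scheme two instances of the same callback class collide on the same qualname key; the docstring added in PATCH 24/25 anticipates exactly this by asking stateful callbacks that may be registered more than once to override `state_id` with something instance-unique.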