diff --git a/CHANGELOG.md b/CHANGELOG.md
index 129744c10f0f1..3c3d2d1114e2c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,10 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- 
-
-
-- 
+- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
 
 - 
 
@@ -32,7 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 
-- 
+- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
 
 - 
 
diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py
index a492be314df26..17ab2fba10c07 100644
--- a/pytorch_lightning/callbacks/base.py
+++ b/pytorch_lightning/callbacks/base.py
@@ -17,7 +17,7 @@
 
 """
 import abc
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type
 
 import torch
 from torch.optim import Optimizer
@@ -33,6 +33,21 @@ class Callback(abc.ABC):
     Subclass this class and override any of the relevant hooks
     """
 
+    @property
+    def state_id(self) -> str:
+        """
+        Identifier for the state of the callback. Used to store and retrieve a callback's state from the
+        checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
+        provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
+        multiple instances of that callback.
+        """
+        return self.__class__.__qualname__
+
+    @property
+    def _legacy_state_id(self) -> Type["Callback"]:
+        """State identifier for checkpoints saved prior to version 1.5.0."""
+        return type(self)
+
     def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called before configure sharded model"""
 
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index ffcac8f9073f6..444f4354672a0 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import Any, Callable, Dict, List, Optional, Type
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
 
@@ -263,7 +263,7 @@ def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
         parameters = list(signature(fn).parameters)
         return len(parameters) == 1 and parameters[0] != "args"
 
-    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
+    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
         """Called when saving a model checkpoint."""
         callback_states = {}
         for callback in self.callbacks:
@@ -277,16 +277,15 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
             else:
                 state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
             if state:
-                callback_states[type(callback)] = state
+                callback_states[callback.state_id] = state
         return callback_states
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint."""
-
         # Todo: the `callback_states` are dropped with TPUSpawn as they
         # can't be saved using `xm.save`
         # https://github.com/pytorch/xla/issues/2773
-        callback_states = checkpoint.get("callbacks")
+        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
@@ -303,7 +302,7 @@ def on_load_checkpoint(self, checkpoint):
             )
 
         for callback in self.callbacks:
-            state = callback_states.get(type(callback))
+            state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
             if state:
                 state = deepcopy(state)
                 if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index e9d8dee7c7e55..611946fd53dae 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -308,7 +308,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             structured dictionary: {
                 'epoch':                     training epoch
                 'global_step':               training global step
-                'pytorch-lightning_version': PyTorch Lightning's version
+                'pytorch-lightning_version': The version of PyTorch Lightning that produced this checkpoint
                 'callbacks':                 "callback specific state"[] # if not weights_only
                 'optimizer_states':          "PT optim's state_dict"[]   # if not weights_only
                 'lr_schedulers':             "PT sched's state_dict"[]   # if not weights_only
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index dedc74f021f81..d190feed7e1f7 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -11,9 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 from unittest.mock import call, Mock
 
-from pytorch_lightning import Trainer
+from pytorch_lightning import Callback, Trainer
 from tests.helpers import BoringModel
 
 
@@ -101,3 +102,33 @@ def configure_callbacks(self):
         trainer_fn(model)
         callbacks_after = trainer.callbacks.copy()
         assert callbacks_after == callbacks_after_fit
+
+
+class OldStatefulCallback(Callback):
+    def __init__(self, state):
+        self.state = state
+
+    @property
+    def state_id(self):
+        return type(self)
+
+    def on_save_checkpoint(self, *args):
+        return {"state": self.state}
+
+    def on_load_checkpoint(self, trainer, pl_module, callback_state):
+        self.state = callback_state["state"]
+
+
+def test_resume_callback_state_saved_by_type(tmpdir):
+    """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
+    model = BoringModel()
+    callback = OldStatefulCallback(state=111)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
+    trainer.fit(model)
+    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
+    assert ckpt_path.exists()
+
+    callback = OldStatefulCallback(state=222)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, callbacks=[callback], resume_from_checkpoint=ckpt_path)
+    trainer.fit(model)
+    assert callback.state == 111
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index cec2d78b92512..f7e4968e6135a 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -76,7 +76,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     checkpoint = torch.load(checkpoint_filepath)
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
-    assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
+    assert 4 == len(early_stop_callback.saved_states)
+    assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
 
     # ensure state is reloaded properly (assertion in the callback)
     early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f9a199a7ebf98..0906ed3820705 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -148,7 +148,7 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)
 
-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -259,7 +259,7 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step
 
-        mc_specific_data = chk["callbacks"][type(checkpoint)]
+        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -857,9 +857,7 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
 
     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-
-    ch_type = type(model_checkpoint)
-    assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
+    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]
 
     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1097,7 +1095,7 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"][type(model_checkpoint)] for ckpt in ckpts]
+    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
 
 
diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index bdc19ee15aaad..338de72a31fed 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -76,11 +76,11 @@ def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
     trainer.fit(model)
 
     ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
-    state0 = ckpt["callbacks"][type(callback0)]
-    state1 = ckpt["callbacks"][type(callback1)]
+    state0 = ckpt["callbacks"]["StatefulCallback0"]
+    state1 = ckpt["callbacks"]["StatefulCallback1"]
     assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
-    assert type(checkpoint_callback) in ckpt["callbacks"]
+    assert "ModelCheckpoint" in ckpt["callbacks"]
 
 
 def test_attach_model_callbacks():
diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py
index 00811f736891f..518a401a72037 100644
--- a/tests/trainer/logging_/test_logger_connector.py
+++ b/tests/trainer/logging_/test_logger_connector.py
@@ -26,10 +26,11 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
+from tests.models.test_hooks import get_members
 
 
 def test_fx_validator(tmpdir):
-    funcs_name = sorted(f for f in dir(Callback) if not f.startswith("_"))
+    funcs_name = sorted(get_members(Callback))
 
     callbacks_func = [
         "on_before_backward",
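
For context, here is a minimal sketch (not part of the patch above) of how a user-defined callback might override the new `state_id` property so that two instances of the same callback class keep separate entries in `checkpoint["callbacks"]`. The `CounterCallback` class, its `monitor` argument, and the `Trainer` setup are illustrative assumptions; the hook signatures assume the pre-1.5 Lightning API targeted by this change, where `on_save_checkpoint` returns the state dict.

```python
from pytorch_lightning import Callback, Trainer


class CounterCallback(Callback):
    """Hypothetical callback that counts training batches and checkpoints the count."""

    def __init__(self, monitor: str):
        self.monitor = monitor
        self.count = 0

    @property
    def state_id(self) -> str:
        # Fold the constructor argument into the identifier so that several
        # instances of this callback do not overwrite each other's entry
        # under checkpoint["callbacks"].
        return f"{self.__class__.__qualname__}[monitor={self.monitor}]"

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.count += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # The returned dict is stored under checkpoint["callbacks"][self.state_id].
        return {"count": self.count}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.count = callback_state["count"]


# Both instances can be saved and restored independently because their state ids differ.
trainer = Trainer(callbacks=[CounterCallback("val_loss"), CounterCallback("train_loss")])
```

Callbacks that never need more than one stateful instance can simply keep the default `__qualname__`-based identifier, and checkpoints written before this change still load because `on_load_checkpoint` falls back to `_legacy_state_id` (the callback's type), as exercised by `test_resume_callback_state_saved_by_type` above.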