diff --git a/CHANGELOG.md b/CHANGELOG.md index d3ca05789fa6b..313a790ce5468 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807)) -- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) +- Added `state_key` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886)) - Progress tracking @@ -60,6 +60,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Refactored CheckpointConnector to offload validating logic to the checkpoitn IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045)) +- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187)) + + - Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974)) diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst index 2c9ee612ceb22..b007fd479b0d0 100644 --- a/docs/source/extensions/callbacks.rst +++ b/docs/source/extensions/callbacks.rst @@ -113,16 +113,69 @@ Lightning has a few built-in callbacks. ---------- +.. _Persisting Callback State: + Persisting State ---------------- Some callbacks require internal state in order to function properly. You can optionally choose to persist your callback's state as part of model checkpoint files using the callback hooks :meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`. -However, you must follow two constraints: +Note that the returned state must be able to be pickled. + +When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough +to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then +the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning +to be able to distinguish the different states when loading the callback state. This concept is best illustrated by +the following example. + +.. testcode:: + + class Counter(Callback): + def __init__(self, what="epochs", verbose=True): + self.what = what + self.verbose = verbose + self.state = {"epochs": 0, "batches": 0} + + @property + def state_key(self): + # note: we do not include `verbose` here on purpose + return self._generate_state_key(what=self.what) + + def on_train_epoch_end(self, *args, **kwargs): + if self.what == "epochs": + self.state["epochs"] += 1 + + def on_train_batch_end(self, *args, **kwargs): + if self.what == "batches": + self.state["batches"] += 1 + + def on_load_checkpoint(self, trainer, pl_module, callback_state): + self.state.update(callback_state) + + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + return self.state.copy() + + + # two callbacks of the same type are being used + trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")]) + +A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information: + +.. code-block:: + + { + "state_dict": ..., + "callbacks": { + "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0}, + "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2}, + ... + } + } -1. Your returned state must be able to be pickled. -2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class. +The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. If it were missing, +Lightning would not be able to disambiguate the state for these two callbacks, and :attr:`~pytorch_lightning.callbacks.Callback.state_key` +by default only defines the class name as the key, e.g., here ``Counter``. Best Practices diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index c38613b0e3159..fdb22a44ed307 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -34,20 +34,30 @@ class Callback(abc.ABC): """ @property - def state_id(self) -> str: + def state_key(self) -> str: """ Identifier for the state of the callback. Used to store and retrieve a callback's state from the - checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to - provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of + checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to + provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback. """ return self.__class__.__qualname__ @property - def _legacy_state_id(self) -> Type["Callback"]: - """State identifier for checkpoints saved prior to version 1.5.0.""" + def _legacy_state_key(self) -> Type["Callback"]: + """State key for checkpoints saved prior to version 1.5.0.""" return type(self) + def _generate_state_key(self, **kwargs: Any) -> str: + """ + Formats a set of key-value pairs into a state key string with the callback class name prefixed. + Useful for defining a :attr:`state_key`. + + Args: + **kwargs: A set of key-value pairs. Must be serializable to :class:`str`. + """ + return f"{self.__class__.__qualname__}{repr(kwargs)}" + def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Called before configure sharded model""" diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ad7beec6927d7..77683ad2819f3 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -75,6 +75,13 @@ class EarlyStopping(Callback): >>> from pytorch_lightning.callbacks import EarlyStopping >>> early_stopping = EarlyStopping('val_loss') >>> trainer = Trainer(callbacks=[early_stopping]) + + .. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported under variation in the + following arguments: + + *monitor, mode* + + Read more: :ref:`Persisting Callback State` """ mode_dict = {"min": torch.lt, "max": torch.gt} @@ -120,6 +127,10 @@ def __init__( ) self.monitor = monitor or "early_stop_on" + @property + def state_key(self) -> str: + return self._generate_state_key(monitor=self.monitor, mode=self.mode) + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._check_on_train_epoch_end is None: # if the user runs validation multiple times per training epoch, we try to check after diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1144af7e32e9f..414a92af6a66c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -194,6 +194,12 @@ class ModelCheckpoint(Callback): trainer.fit(model) checkpoint_callback.best_model_path + .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the + following arguments: + + *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* + + Read more: :ref:`Persisting Callback State` """ CHECKPOINT_JOIN_CHAR = "-" @@ -248,6 +254,17 @@ def __init__( self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period) self.__validate_init_configuration() + @property + def state_key(self) -> str: + return self._generate_state_key( + monitor=self.monitor, + mode=self.mode, + every_n_train_steps=self._every_n_train_steps, + every_n_epochs=self._every_n_epochs, + train_time_interval=self._train_time_interval, + save_on_train_epoch_end=self._save_on_train_epoch_end, + ) + def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """ When pretrain routine starts we build the ckpt dir on the fly diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 2261538d7cbcc..36a3e9abb7b7a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -242,7 +242,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: for callback in self.callbacks: state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint) if state: - callback_states[callback.state_id] = state + callback_states[callback.state_key] = state return callback_states def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: @@ -267,7 +267,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: ) for callback in self.callbacks: - state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id)) + state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key)) if state: state = deepcopy(state) callback.on_load_checkpoint(self, self.lightning_module, state) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index d190feed7e1f7..dc191f4853cc1 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -109,7 +109,7 @@ def __init__(self, state): self.state = state @property - def state_id(self): + def state_key(self): return type(self) def on_save_checkpoint(self, *args): @@ -120,7 +120,7 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state): def test_resume_callback_state_saved_by_type(tmpdir): - """Test that a legacy checkpoint that didn't use a state identifier before can still be loaded.""" + """Test that a legacy checkpoint that didn't use a state key before can still be loaded.""" model = BoringModel() callback = OldStatefulCallback(state=111) trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback]) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 4c3b990dd1b13..ad343cdf329f5 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -33,6 +33,11 @@ _logger = logging.getLogger(__name__) +def test_early_stopping_state_key(): + early_stopping = EarlyStopping(monitor="val_loss") + assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}" + + class EarlyStoppingTestRestore(EarlyStopping): # this class has to be defined outside the test function, otherwise we get pickle error def __init__(self, expected_state, *args, **kwargs): @@ -77,7 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir): # the checkpoint saves "epoch + 1" early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1] assert 4 == len(early_stop_callback.saved_states) - assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state + es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}" + assert checkpoint["callbacks"][es_name] == early_stop_callback_state # ensure state is reloaded properly (assertion in the callback) early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss") diff --git a/tests/callbacks/test_lambda_function.py b/tests/callbacks/test_lambda_function.py index 82f64d676c774..88752d56bf697 100644 --- a/tests/callbacks/test_lambda_function.py +++ b/tests/callbacks/test_lambda_function.py @@ -11,12 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import inspect from functools import partial from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import Callback, LambdaCallback from tests.helpers.boring_model import BoringModel +from tests.models.test_hooks import get_members def test_lambda_call(tmpdir): @@ -32,7 +32,7 @@ def on_train_epoch_start(self): def call(hook, *_, **__): checker.add(hook) - hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)} + hooks = get_members(Callback) hooks_args = {h: partial(call, h) for h in hooks} hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")] diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 314ed899c588a..f49fa16598fd2 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -43,6 +43,15 @@ from tests.helpers.runif import RunIf +def test_model_checkpoint_state_key(): + early_stopping = ModelCheckpoint(monitor="val_loss") + expected_id = ( + "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': None}" + ) + assert early_stopping.state_key == expected_id + + class LogInTwoMethods(BoringModel): def training_step(self, batch, batch_idx): out = super().training_step(batch, batch_idx) @@ -148,7 +157,10 @@ def on_validation_epoch_end(self): assert chk["epoch"] == epoch + 1 assert chk["global_step"] == limit_train_batches * (epoch + 1) - mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + mc_specific_data = chk["callbacks"][ + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""): expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num) assert chk["global_step"] == expected_global_step - mc_specific_data = chk["callbacks"]["ModelCheckpoint"] + mc_specific_data = chk["callbacks"][ + f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': False}" + ] assert mc_specific_data["dirpath"] == checkpoint.dirpath assert mc_specific_data["monitor"] == monitor assert mc_specific_data["current_score"] == score @@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"] assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"] - assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"] + + ckpt_id = ( + "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + ) + assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id] # it is easier to load the model objects than to iterate over the raw dict of tensors model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch) @@ -1095,7 +1115,13 @@ def training_step(self, *args): trainer.fit(TestModel()) assert model_checkpoint.current_score == 0.3 ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()] - ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts] + ckpts = [ + ckpt["callbacks"][ + "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" + ] + for ckpt in ckpts + ] assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3] diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 43158865f9e75..455e08dc10ad5 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -64,28 +64,55 @@ def on_save_checkpoint(self, *args): class StatefulCallback1(Callback): + def __init__(self, unique=None, other=None): + self._unique = unique + self._other = other + + @property + def state_key(self): + return self._generate_state_key(unique=self._unique) + def on_save_checkpoint(self, *args): - return {"content1": 1} + return {"content1": self._unique} def test_all_callback_states_saved_before_checkpoint_callback(tmpdir): - """Test that all callback states get saved even if the ModelCheckpoint is not given as last.""" + """ + Test that all callback states get saved even if the ModelCheckpoint is not given as last + and when there are multiple callbacks of the same type. + """ callback0 = StatefulCallback0() - callback1 = StatefulCallback1() + callback1 = StatefulCallback1(unique="one") + callback2 = StatefulCallback1(unique="two", other=2) checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states") model = BoringModel() trainer = Trainer( - default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, callbacks=[callback0, checkpoint_callback, callback1] + default_root_dir=tmpdir, + max_steps=1, + limit_val_batches=1, + callbacks=[ + callback0, + # checkpoint callback does not have to be at the end + checkpoint_callback, + # callback2 and callback3 have the same type + callback1, + callback2, + ], ) trainer.fit(model) ckpt = torch.load(str(tmpdir / "all_states.ckpt")) state0 = ckpt["callbacks"]["StatefulCallback0"] - state1 = ckpt["callbacks"]["StatefulCallback1"] + state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"] + state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"] assert "content0" in state0 and state0["content0"] == 0 - assert "content1" in state1 and state1["content1"] == 1 - assert "ModelCheckpoint" in ckpt["callbacks"] + assert "content1" in state1 and state1["content1"] == "one" + assert "content1" in state2 and state2["content1"] == "two" + assert ( + "ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1," + " 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"] + ) def test_attach_model_callbacks():