[2 / 3] improvements to saving and loading callback state (#7187)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
4 people authored Aug 24, 2021
1 parent 376734a commit b9443a0
Showing 11 changed files with 180 additions and 27 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807))


- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))
- Added `state_key` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


- Progress tracking
@@ -60,6 +60,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored CheckpointConnector to offload validation logic to the checkpoint IO plugin ([#9045](https://github.com/PyTorchLightning/pytorch-lightning/pull/9045))


- Added support for saving and loading state of multiple callbacks of the same type ([#7187](https://github.com/PyTorchLightning/pytorch-lightning/pull/7187))


- Added DeepSpeed Stage 1 support ([#8974](https://github.com/PyTorchLightning/pytorch-lightning/pull/8974))


59 changes: 56 additions & 3 deletions docs/source/extensions/callbacks.rst
@@ -113,16 +113,69 @@ Lightning has a few built-in callbacks.

----------

.. _Persisting Callback State:

Persisting State
----------------

Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback's state as part of model checkpoint files using the callback hooks
:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
However, you must follow two constraints:
Note that the returned state must be able to be pickled.

When your callback is only ever used as a single instance, implementing the two hooks above is enough to
persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then
the callback must define a :attr:`~pytorch_lightning.callbacks.Callback.state_key` property in order for Lightning
to be able to distinguish the different states when loading the callback state. This concept is best illustrated by
the following example.

.. testcode::

    class Counter(Callback):
        def __init__(self, what="epochs", verbose=True):
            self.what = what
            self.verbose = verbose
            self.state = {"epochs": 0, "batches": 0}

        @property
        def state_key(self):
            # note: we do not include `verbose` here on purpose
            return self._generate_state_key(what=self.what)

        def on_train_epoch_end(self, *args, **kwargs):
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            if self.what == "batches":
                self.state["batches"] += 1

        def on_load_checkpoint(self, trainer, pl_module, callback_state):
            self.state.update(callback_state)

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            return self.state.copy()


    # two callbacks of the same type are being used
    trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

.. code-block::

    {
        "state_dict": ...,
        "callbacks": {
            "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
            "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
            ...
        }
    }
1. Your returned state must be able to be pickled.
2. You can only use one instance of that class in the Trainer callbacks list. We don't support persisting state for multiple callbacks of the same class.
The implementation of a :attr:`~pytorch_lightning.callbacks.Callback.state_key` is essential here. If it were missing,
Lightning would not be able to disambiguate the states of these two callbacks, because the default
:attr:`~pytorch_lightning.callbacks.Callback.state_key` uses only the class name as the key, in this case ``Counter``.
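
For illustration only, here is a minimal sketch of how the per-instance keys could be inspected after training
with the Trainer above (the checkpoint filename is hypothetical):

.. code-block:: python

    import torch

    # load a checkpoint written by the Trainer above (hypothetical path)
    checkpoint = torch.load("example.ckpt")

    # every stateful callback is stored under its ``state_key``
    print(sorted(checkpoint["callbacks"]))
    # ["Counter{'what': 'batches'}", "Counter{'what': 'epochs'}", ...]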


Best Practices
20 changes: 15 additions & 5 deletions pytorch_lightning/callbacks/base.py
@@ -34,20 +34,30 @@ class Callback(abc.ABC):
"""

@property
def state_id(self) -> str:
def state_key(self) -> str:
"""
Identifier for the state of the callback. Used to store and retrieve a callback's state from the
checkpoint dictionary by ``checkpoint["callbacks"][state_id]``. Implementations of a callback need to
provide a unique state id if 1) the callback has state and 2) it is desired to maintain the state of
checkpoint dictionary by ``checkpoint["callbacks"][state_key]``. Implementations of a callback need to
provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of
multiple instances of that callback.
"""
return self.__class__.__qualname__
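# note: by default the key is simply the class name, e.g. "ModelCheckpoint" for the built-in checkpoint callback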

@property
def _legacy_state_id(self) -> Type["Callback"]:
"""State identifier for checkpoints saved prior to version 1.5.0."""
def _legacy_state_key(self) -> Type["Callback"]:
"""State key for checkpoints saved prior to version 1.5.0."""
return type(self)
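# note: checkpoints written before v1.5.0 stored callback states keyed by the callback's type,
# so the type object itself serves as the fallback lookup key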

def _generate_state_key(self, **kwargs: Any) -> str:
"""
Formats a set of key-value pairs into a state key string with the callback class name prefixed.
Useful for defining a :attr:`state_key`.
Args:
**kwargs: A set of key-value pairs. Must be serializable to :class:`str`.
"""
return f"{self.__class__.__qualname__}{repr(kwargs)}"

def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""Called before configure sharded model"""

11 changes: 11 additions & 0 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -75,6 +75,13 @@ class EarlyStopping(Callback):
>>> from pytorch_lightning.callbacks import EarlyStopping
>>> early_stopping = EarlyStopping('val_loss')
>>> trainer = Trainer(callbacks=[early_stopping])
.. tip:: Saving and restoring multiple early stopping callbacks at the same time is supported, provided the
    instances differ in at least one of the following arguments:

    *monitor, mode*

    Read more: :ref:`Persisting Callback State`
"""
mode_dict = {"min": torch.lt, "max": torch.gt}

@@ -120,6 +127,10 @@ def __init__(
)
self.monitor = monitor or "early_stop_on"

@property
def state_key(self) -> str:
return self._generate_state_key(monitor=self.monitor, mode=self.mode)
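# example (illustrative): EarlyStopping(monitor="val_loss") produces the state key
# "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"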

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._check_on_train_epoch_end is None:
# if the user runs validation multiple times per training epoch, we try to check after
17 changes: 17 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -194,6 +194,12 @@ class ModelCheckpoint(Callback):
trainer.fit(model)
checkpoint_callback.best_model_path
.. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported, provided the
    instances differ in at least one of the following arguments:

    *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*

    Read more: :ref:`Persisting Callback State`
"""

CHECKPOINT_JOIN_CHAR = "-"
@@ -248,6 +254,17 @@ def __init__(
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval, period)
self.__validate_init_configuration()

@property
def state_key(self) -> str:
return self._generate_state_key(
monitor=self.monitor,
mode=self.mode,
every_n_train_steps=self._every_n_train_steps,
every_n_epochs=self._every_n_epochs,
train_time_interval=self._train_time_interval,
save_on_train_epoch_end=self._save_on_train_epoch_end,
)
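# example (illustrative): ModelCheckpoint(monitor="val_loss") produces the state key
# "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,
#  'train_time_interval': None, 'save_on_train_epoch_end': None}"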

def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
"""
When pretrain routine starts we build the ckpt dir on the fly
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -242,7 +242,7 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
for callback in self.callbacks:
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[callback.state_id] = state
callback_states[callback.state_key] = state
return callback_states

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
@@ -267,7 +267,7 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
)

for callback in self.callbacks:
state = callback_states.get(callback.state_id, callback_states.get(callback._legacy_state_id))
state = callback_states.get(callback.state_key, callback_states.get(callback._legacy_state_key))
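# falling back to ``_legacy_state_key`` (the callback's type) keeps checkpoints written
# before v1.5.0, which stored callback states keyed by type, loadable here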
if state:
state = deepcopy(state)
callback.on_load_checkpoint(self, self.lightning_module, state)
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -109,7 +109,7 @@ def __init__(self, state):
self.state = state

@property
def state_id(self):
def state_key(self):
return type(self)

def on_save_checkpoint(self, *args):
@@ -120,7 +120,7 @@ def on_load_checkpoint(self, trainer, pl_module, callback_state):


def test_resume_callback_state_saved_by_type(tmpdir):
"""Test that a legacy checkpoint that didn't use a state identifier before can still be loaded."""
"""Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
model = BoringModel()
callback = OldStatefulCallback(state=111)
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, callbacks=[callback])
8 changes: 7 additions & 1 deletion tests/callbacks/test_early_stopping.py
@@ -33,6 +33,11 @@
_logger = logging.getLogger(__name__)


def test_early_stopping_state_key():
early_stopping = EarlyStopping(monitor="val_loss")
assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"


class EarlyStoppingTestRestore(EarlyStopping):
# this class has to be defined outside the test function, otherwise we get pickle error
def __init__(self, expected_state, *args, **kwargs):
@@ -77,7 +82,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
assert 4 == len(early_stop_callback.saved_states)
assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
assert checkpoint["callbacks"][es_name] == early_stop_callback_state

# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
4 changes: 2 additions & 2 deletions tests/callbacks/test_lambda_function.py
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import partial

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, LambdaCallback
from tests.helpers.boring_model import BoringModel
from tests.models.test_hooks import get_members


def test_lambda_call(tmpdir):
@@ -32,7 +32,7 @@ def on_train_epoch_start(self):
def call(hook, *_, **__):
checker.add(hook)

hooks = {m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)}
hooks = get_members(Callback)
hooks_args = {h: partial(call, h) for h in hooks}
hooks_args["on_save_checkpoint"] = lambda *_: [checker.add("on_save_checkpoint")]

34 changes: 30 additions & 4 deletions tests/checkpointing/test_model_checkpoint.py
@@ -43,6 +43,15 @@
from tests.helpers.runif import RunIf


def test_model_checkpoint_state_key():
checkpoint_callback = ModelCheckpoint(monitor="val_loss")
expected_key = (
"ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': None}"
)
assert checkpoint_callback.state_key == expected_key


class LogInTwoMethods(BoringModel):
def training_step(self, batch, batch_idx):
out = super().training_step(batch, batch_idx)
@@ -148,7 +157,10 @@ def on_validation_epoch_end(self):
assert chk["epoch"] == epoch + 1
assert chk["global_step"] == limit_train_batches * (epoch + 1)

mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
mc_specific_data = chk["callbacks"][
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""):
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk["global_step"] == expected_global_step

mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
mc_specific_data = chk["callbacks"][
f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': False}"
]
assert mc_specific_data["dirpath"] == checkpoint.dirpath
assert mc_specific_data["monitor"] == monitor
assert mc_specific_data["current_score"] == score
@@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]

ckpt_id = (
"ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
)
assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

# it is easier to load the model objects than to iterate over the raw dict of tensors
model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1095,7 +1115,13 @@ def training_step(self, *args):
trainer.fit(TestModel())
assert model_checkpoint.current_score == 0.3
ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
ckpts = [
ckpt["callbacks"][
"ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}"
]
for ckpt in ckpts
]
assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]


41 changes: 34 additions & 7 deletions tests/trainer/connectors/test_callback_connector.py
@@ -64,28 +64,55 @@ def on_save_checkpoint(self, *args):


class StatefulCallback1(Callback):
def __init__(self, unique=None, other=None):
self._unique = unique
self._other = other

@property
def state_key(self):
return self._generate_state_key(unique=self._unique)

def on_save_checkpoint(self, *args):
return {"content1": 1}
return {"content1": self._unique}


def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
"""Test that all callback states get saved even if the ModelCheckpoint is not given as last."""
"""
Test that all callback states get saved even if the ModelCheckpoint is not given as last
and when there are multiple callbacks of the same type.
"""

callback0 = StatefulCallback0()
callback1 = StatefulCallback1()
callback1 = StatefulCallback1(unique="one")
callback2 = StatefulCallback1(unique="two", other=2)
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, callbacks=[callback0, checkpoint_callback, callback1]
default_root_dir=tmpdir,
max_steps=1,
limit_val_batches=1,
callbacks=[
callback0,
# checkpoint callback does not have to be at the end
checkpoint_callback,
# callback1 and callback2 have the same type
callback1,
callback2,
],
)
trainer.fit(model)

ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
state0 = ckpt["callbacks"]["StatefulCallback0"]
state1 = ckpt["callbacks"]["StatefulCallback1"]
state1 = ckpt["callbacks"]["StatefulCallback1{'unique': 'one'}"]
state2 = ckpt["callbacks"]["StatefulCallback1{'unique': 'two'}"]
assert "content0" in state0 and state0["content0"] == 0
assert "content1" in state1 and state1["content1"] == 1
assert "ModelCheckpoint" in ckpt["callbacks"]
assert "content1" in state1 and state1["content1"] == "one"
assert "content1" in state2 and state2["content1"] == "two"
assert (
"ModelCheckpoint{'monitor': None, 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
" 'train_time_interval': None, 'save_on_train_epoch_end': True}" in ckpt["callbacks"]
)


def test_attach_model_callbacks():
