From d37f1cbb14f8fd48c2eabd9c9958d0fb2e7b72fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 21 Apr 2021 09:26:57 +0200 Subject: [PATCH 1/3] wip --- pytorch_lightning/core/saving.py | 3 + pytorch_lightning/trainer/callback_hook.py | 1 - pytorch_lightning/utilities/argparse.py | 10 ++-- pytorch_lightning/utilities/migration/base.py | 56 ++++++++++--------- .../utilities/migration/migrations.py | 29 +++++++--- 5 files changed, 59 insertions(+), 40 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index ffa9b0a1359ee0..faea69a2cced58 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -30,6 +30,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.migration.migrations import upgrade_checkpoint from pytorch_lightning.utilities.parsing import parse_class_init_keys log = logging.getLogger(__name__) @@ -134,6 +135,8 @@ def load_from_checkpoint( else: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) + upgrade_checkpoint(checkpoint) + if hparams_file is not None: extension = hparams_file.split('.')[-1] if extension.lower() == 'csv': diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 0f7d6a6fe50667..e3fb58d57d8faa 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -264,7 +264,6 @@ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: """Called when loading a model checkpoint.""" callback_states = checkpoint.get('callbacks') - version = checkpoint.get('pytorch-lightning_version') # Todo: the `callback_states` are dropped with TPUSpawn as they # can't be saved using `xm.save` # https://github.com/pytorch/xla/issues/2773 diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index 6dbc4636b9481e..d7607c7b68d884 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -286,11 +286,11 @@ def _gpus_allowed_type(x) -> Union[int, str]: return int(x) -def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover - # unused, but here for backward compatibility with old checkpoints that need to be able to - # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 - # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 - pass +# def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover +# # unused, but here for backward compatibility with old checkpoints that need to be able to +# # unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8 +# # see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 +# pass def _int_or_float_type(x) -> Union[int, float]: diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index cdd9f96df2d08d..c3e0920a681c79 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -90,13 +90,16 @@ "1.2.6", "1.2.7", "1.2.8", + "1.2.9", "1.3.0rc0", "1.3.0rc1", - pytorch_lightning.__version__, ] +if pytorch_lightning.__version__ not in version_history: + version_history.append(pytorch_lightning.__version__) -def default_upgrade_rule(checkpoint): + +def default_migration(checkpoint): """ Upgrades to the next version by only replacing the current version with the new one. """ # TODO: find more elegant version for the if below current = get_version(checkpoint) @@ -118,30 +121,31 @@ def set_version(checkpoint: dict, version: str): checkpoint["pytorch-lightning_version"] = version +all_migrations = dict((ver, default_migration) for ver in version_history) + + class Migration: """ Decorator for a function that upgrades a checkpoint from one version to the next. """ - all_migrations = dict.fromkeys(version_history, [default_upgrade_rule]) - - def __init__(self, requires: Optional[str]): - self.required_version = requires - - def __call__(self, fn: callable) -> callable: - @wraps(fn) - def wrapper(ckpt): - current_version = get_version(ckpt) - if self.required_version and current_version != self.required_version: - log.error(f"skipping, {current_version}") - return ckpt - new_ckpt = fn(ckpt) - return new_ckpt - - self.all_migrations[self.required_version].insert(0, wrapper) - return wrapper - - @staticmethod - def migrate(checkpoint: dict) -> dict: - for version_migrations in Migration.all_migrations.values(): - for migration in version_migrations: - checkpoint = migration(checkpoint) - return checkpoint + def __init__(self, target: Optional[str]): + self.target_version = target + + def __call__(self, upgrade_fn: callable) -> callable: + if getattr(upgrade_fn, "_migration_registered", False) and all_migrations[self.target_version] != default_migration: + raise ValueError( + f"Tried to register a new migration {upgrade_fn.__name__}, but" + f" there is already a migration for version {self.target_version}:" + f" {all_migrations[self.target_version].__name__}" + ) + all_migrations[self.target_version] = upgrade_fn + upgrade_fn._migration_registered = True + return upgrade_fn + + +def upgrade_checkpoint(checkpoint: dict) -> dict: + for migration in all_migrations.values(): + if migration is None: + checkpoint = default_migration(checkpoint) + checkpoint = migration(checkpoint) + return checkpoint + diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index f9a94573666615..497a20c1d91690 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -1,28 +1,41 @@ import torch -from pytorch_lightning.utilities.migration.base import Migration, get_version +from pytorch_lightning.utilities.migration.base import Migration, get_version, upgrade_checkpoint, version_history from pytorch_lightning.utilities.migration.patch import pl_legacy_patch -@Migration(requires="1.2.7") +@Migration(target="1.2.8") def upgrade_callback_names(checkpoint: dict) -> dict: if "callbacks" not in checkpoint: return checkpoint - checkpoint["callbacks"] = reversed(checkpoint["callbacks"]) + # checkpoint["callbacks"] = reversed(checkpoint["callbacks"]) print(get_version(checkpoint)) + return checkpoint -@Migration(requires="1.2.8") +@Migration(target="1.2.8") def upgrade_something_else(checkpoint: dict) -> dict: return checkpoint +@Migration(target="1.2.9") +def upgrade_callback_state_identifiers(checkpoint): + if "callbacks" not in checkpoint: + return + callbacks = checkpoint["callbacks"] + print(callbacks) + checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items()) + return checkpoint + + if __name__ == "__main__": + # print(Migration.all_migrations) with pl_legacy_patch(): - checkpoint = torch.load("gpus-default-legacy.ckpt") + ckpt = torch.load("gpus-default-legacy.ckpt") # checkpoint = torch.load("example.ckpt") # checkpoint["pytorch-lightning_version"] = "1.2.6" - # getattr() - checkpoint = Migration.migrate(checkpoint) - print(checkpoint) + + ckpt = upgrade_checkpoint(ckpt) + from pprint import pprint + pprint(ckpt) From 3b44a3967ae688a4cd66c29262e31c37f4859dfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 21 Apr 2021 10:59:43 +0200 Subject: [PATCH 2/3] update --- pytorch_lightning/core/saving.py | 5 ++++- pytorch_lightning/utilities/migration/migrations.py | 8 ++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index faea69a2cced58..5bf67c02cd1da6 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -135,7 +135,8 @@ def load_from_checkpoint( else: checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage) - upgrade_checkpoint(checkpoint) + # convert legacy checkpoints to the new format + checkpoint = upgrade_checkpoint(checkpoint) if hparams_file is not None: extension = hparams_file.split('.')[-1] @@ -151,6 +152,7 @@ def load_from_checkpoint( # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + # TODO: make this a migration: # for past checkpoint need to add the new key if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint: checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {} @@ -174,6 +176,7 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: # 1. (backward compatibility) Try to restore model hparams from checkpoint using old/past keys + # TODO: make this a migration: for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index 497a20c1d91690..f8f23c1bcd3841 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -1,6 +1,6 @@ import torch -from pytorch_lightning.utilities.migration.base import Migration, get_version, upgrade_checkpoint, version_history +from pytorch_lightning.utilities.migration.base import Migration, upgrade_checkpoint from pytorch_lightning.utilities.migration.patch import pl_legacy_patch @@ -8,9 +8,7 @@ def upgrade_callback_names(checkpoint: dict) -> dict: if "callbacks" not in checkpoint: return checkpoint - # checkpoint["callbacks"] = reversed(checkpoint["callbacks"]) - print(get_version(checkpoint)) - + checkpoint["callbacks"] = reversed(checkpoint["callbacks"]) return checkpoint @@ -24,13 +22,11 @@ def upgrade_callback_state_identifiers(checkpoint): if "callbacks" not in checkpoint: return callbacks = checkpoint["callbacks"] - print(callbacks) checkpoint["callbacks"] = dict((callback_type.__name__, state) for callback_type, state in callbacks.items()) return checkpoint if __name__ == "__main__": - # print(Migration.all_migrations) with pl_legacy_patch(): ckpt = torch.load("gpus-default-legacy.ckpt") # checkpoint = torch.load("example.ckpt") From 792e3b3ef85ad20ac997611b0f1c1abf31ce710b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 21 Apr 2021 11:02:36 +0200 Subject: [PATCH 3/3] clean up --- pytorch_lightning/utilities/migration/base.py | 6 +++--- pytorch_lightning/utilities/migration/migrations.py | 8 -------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/migration/base.py b/pytorch_lightning/utilities/migration/base.py index c3e0920a681c79..d50d905ddd354a 100644 --- a/pytorch_lightning/utilities/migration/base.py +++ b/pytorch_lightning/utilities/migration/base.py @@ -113,6 +113,9 @@ def default_migration(checkpoint): return checkpoint +all_migrations = dict((ver, default_migration) for ver in version_history) + + def get_version(checkpoint: dict) -> str: return checkpoint["pytorch-lightning_version"] @@ -121,9 +124,6 @@ def set_version(checkpoint: dict, version: str): checkpoint["pytorch-lightning_version"] = version -all_migrations = dict((ver, default_migration) for ver in version_history) - - class Migration: """ Decorator for a function that upgrades a checkpoint from one version to the next. """ diff --git a/pytorch_lightning/utilities/migration/migrations.py b/pytorch_lightning/utilities/migration/migrations.py index f8f23c1bcd3841..48a246059bcc7f 100644 --- a/pytorch_lightning/utilities/migration/migrations.py +++ b/pytorch_lightning/utilities/migration/migrations.py @@ -4,14 +4,6 @@ from pytorch_lightning.utilities.migration.patch import pl_legacy_patch -@Migration(target="1.2.8") -def upgrade_callback_names(checkpoint: dict) -> dict: - if "callbacks" not in checkpoint: - return checkpoint - checkpoint["callbacks"] = reversed(checkpoint["callbacks"]) - return checkpoint - - @Migration(target="1.2.8") def upgrade_something_else(checkpoint: dict) -> dict: return checkpoint