From 23754f4fc894e40990afa8922f11b7911fe02516 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 26 Jan 2022 17:33:07 -0800 Subject: [PATCH 01/18] first commit --- .../plugins/precision/apex_amp.py | 10 +++++----- .../plugins/precision/native_amp.py | 12 ++++++------ .../plugins/precision/precision_plugin.py | 6 ++++++ .../trainer/connectors/checkpoint_connector.py | 18 ++++++++++++++++-- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 1e86ec2633fe9..4699c43b5884b 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -92,9 +92,9 @@ def optimizer_step( if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: optimizer.step(**kwargs) - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if "amp_scaling_state" in checkpoint: - amp.load_state_dict(checkpoint["amp_scaling_state"]) + def state_dict(self) -> Dict[str, Any]: + return amp.state_dict() - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - checkpoint["amp_scaling_state"] = amp.state_dict() + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if state_dict: + amp.load_state_dict(state_dict) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index f6cb28c76c867..a2abed5ca70f5 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -106,10 +106,10 @@ def forward_context(self) -> Generator[None, None, None]: with self.autocast_context_manager(): yield - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - if self.scaler is not None and "native_amp_scaling_state" in checkpoint: - self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"]) - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: - checkpoint["native_amp_scaling_state"] = self.scaler.state_dict() + return self.scaler.state_dict() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + if self.scaler is not None and state_dict: + self.scaler.load_state_dict(state_dict) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index e875fe51f19e7..33d32c3493fde 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -242,3 +242,9 @@ def teardown(self) -> None: It is the right place to release memory and free other resources. """ + + def state_dict(self) -> Dict[str, Any]: + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + pass diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 5c437bfd889b2..79e7a7b15d1e2 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -192,7 +192,17 @@ def restore_training_state(self) -> None: return # restore precision plugin (scaler etc.) 
- self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) + prec_plugin = self.trainer.precision_plugin + prec_plugin.on_load_checkpoint(self._loaded_checkpoint) + if prec_plugin.__class__.__name__ in self._loaded_checkpoint: + prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__]) + + # old checkpoints compatibility + # should we raise error and force user to run utilities/upgrade_checkpoint instead? + if "amp_scaling_state" in self._loaded_checkpoint: + prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"]) + if "native_amp_scaling_state" in self._loaded_checkpoint: + prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"]) # restore loops and their progress self.restore_loops() @@ -372,7 +382,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: lr_schedulers.append(config.scheduler.state_dict()) checkpoint["lr_schedulers"] = lr_schedulers - self.trainer.precision_plugin.on_save_checkpoint(checkpoint) + # precision plugin + prec_plugin = self.trainer.precision_plugin + checkpoint[prec_plugin.__class__.__name__] = self.trainer.precision_plugin.state_dict() # dump hyper-parameters if model.hparams: @@ -389,6 +401,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model.on_save_checkpoint(checkpoint) if self.trainer.datamodule is not None: self.trainer.datamodule.on_save_checkpoint(checkpoint) + if not weights_only: + self.trainer.precision_plugin.on_save_checkpoint(checkpoint) # TODO: remove this in v1.8. environment = self.trainer._accelerator_connector.cluster_environment From a61f2e3602b104130f0d2d61aee016f51c7d24fd Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 26 Jan 2022 17:43:14 -0800 Subject: [PATCH 02/18] import --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 33d32c3493fde..0381b780db7a6 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. 
import contextlib from functools import partial -from typing import Any, Callable, Generator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import torch from torch import Tensor From 50d47d5ab727fc6586935356082194941a2a1f8d Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 3 Feb 2022 12:11:25 -0800 Subject: [PATCH 03/18] some of the updates --- pytorch_lightning/plugins/precision/apex_amp.py | 3 +-- pytorch_lightning/plugins/precision/native_amp.py | 4 +++- .../plugins/precision/precision_plugin.py | 10 ++++++++++ .../trainer/connectors/checkpoint_connector.py | 11 ++++++----- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 4699c43b5884b..ff7c470ac077a 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -96,5 +96,4 @@ def state_dict(self) -> Dict[str, Any]: return amp.state_dict() def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - if state_dict: - amp.load_state_dict(state_dict) + amp.load_state_dict(state_dict) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index a2abed5ca70f5..d99d0b7e677fe 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -109,7 +109,9 @@ def forward_context(self) -> Generator[None, None, None]: def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() + else: + return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - if self.scaler is not None and state_dict: + if self.scaler is not None: self.scaler.load_state_dict(state_dict) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 0381b780db7a6..70d08fcdf0896 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -244,7 +244,17 @@ def teardown(self) -> None: """ def state_dict(self) -> Dict[str, Any]: + """Called when saving a checkpoint, implement to generate and save precision plugin state. + + Returns: + A dictionary containing precision plugin state. + """ return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin state_dict. + + Args: + state_dict: the precision plugin state returned by ``state_dict``. + """ pass diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 79e7a7b15d1e2..cc7fb12238ace 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -194,8 +194,8 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) 
prec_plugin = self.trainer.precision_plugin prec_plugin.on_load_checkpoint(self._loaded_checkpoint) - if prec_plugin.__class__.__name__ in self._loaded_checkpoint: - prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__]) + if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint: + prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__]) # old checkpoints compatibility # should we raise error and force user to run utilities/upgrade_checkpoint instead? @@ -384,7 +384,10 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # precision plugin prec_plugin = self.trainer.precision_plugin - checkpoint[prec_plugin.__class__.__name__] = self.trainer.precision_plugin.state_dict() + prec_plugin_state_dict = prec_plugin.state_dict() + if prec_plugin_state_dict: + checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict + prec_plugin.on_save_checkpoint(checkpoint) # dump hyper-parameters if model.hparams: @@ -401,8 +404,6 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: model.on_save_checkpoint(checkpoint) if self.trainer.datamodule is not None: self.trainer.datamodule.on_save_checkpoint(checkpoint) - if not weights_only: - self.trainer.precision_plugin.on_save_checkpoint(checkpoint) # TODO: remove this in v1.8. environment = self.trainer._accelerator_connector.cluster_environment From 7a3dacbff40612445269b1f352f5b9aca930ddbf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Feb 2022 20:12:44 +0000 Subject: [PATCH 04/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/precision/precision_plugin.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 70d08fcdf0896..131bfffc785a7 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -252,7 +252,8 @@ def state_dict(self) -> Dict[str, Any]: return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin state_dict. + """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin + state_dict. Args: state_dict: the precision plugin state returned by ``state_dict``. 
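The two hooks documented above follow the usual stateful-component contract: ``state_dict`` returns a plain dictionary of whatever the plugin needs to resume, and ``load_state_dict`` receives that same dictionary back when the checkpoint is restored. A minimal sketch of a custom plugin using the new hooks (illustrative only; ``ClipValuePrecisionPlugin`` and its ``clip_value`` attribute are hypothetical and not part of these patches):

from typing import Any, Dict

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class ClipValuePrecisionPlugin(PrecisionPlugin):
    """Hypothetical plugin carrying a single float of resumable state."""

    def __init__(self) -> None:
        super().__init__()
        self.clip_value = 1.0  # example of state worth persisting across restarts

    def state_dict(self) -> Dict[str, Any]:
        # called while dumping a checkpoint; must return picklable data
        return {"clip_value": self.clip_value}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # receives exactly what state_dict() returned for this plugin class
        self.clip_value = state_dict["clip_value"]
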
From 36f161253c8e5d3244fee85ba55b1fda434c214f Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 3 Feb 2022 12:17:46 -0800 Subject: [PATCH 05/18] clean --- pytorch_lightning/plugins/precision/native_amp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index d99d0b7e677fe..15123719d1067 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -109,8 +109,7 @@ def forward_context(self) -> Generator[None, None, None]: def state_dict(self) -> Dict[str, Any]: if self.scaler is not None: return self.scaler.state_dict() - else: - return {} + return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: From 80e31e0e951044867c9b014bf2af862259517e74 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 7 Feb 2022 12:07:48 -0800 Subject: [PATCH 06/18] updates --- .../plugins/precision/apex_amp.py | 13 ++++++++++ .../plugins/precision/native_amp.py | 13 ++++++++++ .../connectors/checkpoint_connector.py | 26 +++++++++++-------- 3 files changed, 41 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index ff7c470ac077a..bb15a7a45212f 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -21,6 +21,7 @@ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation from pytorch_lightning.utilities.types import _PARAMETERS if _APEX_AVAILABLE: @@ -97,3 +98,15 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: amp.load_state_dict(state_dict) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + rank_zero_deprecation( + "`ApexMixedPrecisionPlugin.on_load_checkpoint` is deprecated in v1.6 and will be removed in v1.8." + " Use `ApexMixedPrecisionPlugin.load_state_dict` instead." + ) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + rank_zero_deprecation( + "`ApexMixedPrecisionPlugin.on_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8." + " Use `ApexMixedPrecisionPlugin.state_dict` instead." 
+ ) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 15123719d1067..55a5812f2b638 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -23,6 +23,7 @@ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation if _TORCH_GREATER_EQUAL_1_10: from torch import autocast as new_autocast @@ -114,3 +115,15 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + rank_zero_deprecation( + "`NativeMixedPrecisionPlugin.on_load_checkpoint` is deprecated in v1.6 and will be removed in v1.8." + " Use `NativeMixedPrecisionPlugin.load_state_dict` instead." + ) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + rank_zero_deprecation( + "`NativeMixedPrecisionPlugin.on_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8." + " Use `NativeMixedPrecisionPlugin.state_dict` instead." + ) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index cc7fb12238ace..dc7a65931376f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -23,6 +23,7 @@ import pytorch_lightning as pl from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.plugins.environments import SLURMEnvironment +from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -192,17 +193,7 @@ def restore_training_state(self) -> None: return # restore precision plugin (scaler etc.) - prec_plugin = self.trainer.precision_plugin - prec_plugin.on_load_checkpoint(self._loaded_checkpoint) - if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint: - prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__]) - - # old checkpoints compatibility - # should we raise error and force user to run utilities/upgrade_checkpoint instead? 
- if "amp_scaling_state" in self._loaded_checkpoint: - prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"]) - if "native_amp_scaling_state" in self._loaded_checkpoint: - prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"]) + self.restore_precision_plugin_state() # restore loops and their progress self.restore_loops() @@ -212,6 +203,19 @@ def restore_training_state(self) -> None: # restore optimizers and schedulers state self.restore_optimizers_and_schedulers() + def restore_precision_plugin_state(self) -> None: + """Restore the precision plugin state from the pre-loaded checkpoint.""" + prec_plugin = self.trainer.precision_plugin + prec_plugin.on_load_checkpoint(self._loaded_checkpoint) + if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint: + prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__]) + + # old checkpoints compatibility + if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin): + prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"]) + if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, NativeMixedPrecisionPlugin): + prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"]) + def restore_callbacks(self) -> None: """Restores all callbacks from the pre-loaded checkpoint.""" if not self._loaded_checkpoint: From bb69624c32e12c31fa4d636e75c71293da4d02e2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Feb 2022 20:09:23 +0000 Subject: [PATCH 07/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/trainer/connectors/checkpoint_connector.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index dc7a65931376f..89d6621959680 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -213,7 +213,9 @@ def restore_precision_plugin_state(self) -> None: # old checkpoints compatibility if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin): prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"]) - if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, NativeMixedPrecisionPlugin): + if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance( + prec_plugin, NativeMixedPrecisionPlugin + ): prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"]) def restore_callbacks(self) -> None: From abd98206573779e71aa615dab28561bd19579045 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 7 Feb 2022 12:57:32 -0800 Subject: [PATCH 08/18] update --- tests/models/test_hooks.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5f20d7bb4115a..4521438626abf 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -21,6 +21,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer +from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from tests.helpers import 
BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -495,10 +496,8 @@ def training_step(self, batch, batch_idx): "state_dict": ANY, "loops": ANY, } - if kwargs.get("amp_backend") == "native": - saved_ckpt["native_amp_scaling_state"] = ANY - elif kwargs.get("amp_backend") == "apex": - saved_ckpt["amp_scaling_state"] = ANY + if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex": + saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY device = torch.device("cuda:0" if "gpus" in kwargs else "cpu") expected = [ dict(name="Callback.on_init_start", args=(trainer,)), From e49474e4b2e0357cc684a48f80e7f0104b7f0302 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 7 Feb 2022 12:59:28 -0800 Subject: [PATCH 09/18] clean --- tests/models/test_hooks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 4521438626abf..afc28b49c7813 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -21,7 +21,6 @@ from torch.utils.data import DataLoader from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, Trainer -from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf From a9370d6ec4bdc37fbf1fba0dcbd0345923d74aef Mon Sep 17 00:00:00 2001 From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com> Date: Tue, 8 Feb 2022 10:34:48 -0800 Subject: [PATCH 10/18] docstring Update pytorch_lightning/plugins/precision/precision_plugin.py Co-authored-by: Rohit Gupta --- pytorch_lightning/plugins/precision/precision_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 131bfffc785a7..9f820b8b70154 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -244,7 +244,7 @@ def teardown(self) -> None: """ def state_dict(self) -> Dict[str, Any]: - """Called when saving a checkpoint, implement to generate and save precision plugin state. + """Called when saving a checkpoint, implement to generate precision plugin state_dict. Returns: A dictionary containing precision plugin state. From 1f803267a65d7fb16bca54d6a6d21f6a7d12c4a0 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Tue, 8 Feb 2022 11:09:58 -0800 Subject: [PATCH 11/18] docstring updates --- CHANGELOG.md | 3 +++ pytorch_lightning/trainer/connectors/checkpoint_connector.py | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa7c4f9b056bc..6671a8f91a7d6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added a `MisconfigurationException` if user provided `opt_idx` in scheduler config doesn't match with actual optimizer index of its respective optimizer ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) +- Added `_Stateful` support for `PrecisionPlugin` ([#11638](https://github.com/PyTorchLightning/pytorch-lightning/pull/11638)) + + ### Changed - Set the `prog_bar` flag to False in `LightningModule.log_grad_norm` ([#11472](https://github.com/PyTorchLightning/pytorch-lightning/pull/11472)) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 89d6621959680..3877a68c82bca 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -340,9 +340,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'callbacks': "callback specific state"[] # if not weights_only 'optimizer_states': "PT optim's state_dict"[] # if not weights_only 'lr_schedulers': "PT sched's state_dict"[] # if not weights_only - 'native_amp_scaling_state': PT amp's state_dict # if not weights_only and use native amp - 'amp_scaling_state': Apex's state_dict # if not weights_only and use apex amp 'state_dict': Model's state_dict (e.g. network weights) + precision_plugin.__class__.__qualname__: precision plugin state_dict # if not weights_only CHECKPOINT_HYPER_PARAMS_NAME: CHECKPOINT_HYPER_PARAMS_KEY: CHECKPOINT_HYPER_PARAMS_TYPE: From fef413ad57bf5cabc94fdb31caa37470ff97806f Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 9 Feb 2022 11:36:37 -0800 Subject: [PATCH 12/18] deprecation test --- tests/deprecated_api/test_remove_1-8.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index fa0d982478759..d6015ba883b85 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -19,6 +19,8 @@ from torch import optim from pytorch_lightning import Callback, Trainer +from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin +from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin @@ -351,3 +353,24 @@ def test_v1_8_0_deprecated_lightning_optimizers(): match="Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8" ): assert trainer.lightning_optimizers == {} + +def test_v1_8_0_deprecated_precplugin_checkpointhooks(): + apex_amp = ApexMixedPrecisionPlugin() + with pytest.deprecated_call( + match="is deprecated in v1.6 and will be removed in v1.8." + ): + apex_amp.on_save_checkpoint({}) + with pytest.deprecated_call( + match="is deprecated in v1.6 and will be removed in v1.8." + ): + apex_amp.on_load_checkpoint({}) + + native_amp = NativeMixedPrecisionPlugin(1, "a") + with pytest.deprecated_call( + match="is deprecated in v1.6 and will be removed in v1.8." + ): + native_amp.on_save_checkpoint({}) + with pytest.deprecated_call( + match="is deprecated in v1.6 and will be removed in v1.8." 
+ ): + native_amp.on_load_checkpoint({}) From a924f24e1ba7d16e5baf6470d8b8b87e32da7764 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Feb 2022 19:40:56 +0000 Subject: [PATCH 13/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/deprecated_api/test_remove_1-8.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index 244abcfe683f4..b0eb903a85657 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -364,7 +364,7 @@ def test_v1_8_0_deprecated_lightning_optimizers(): ): assert trainer.lightning_optimizers == {} - + def test_v1_8_0_remove_on_batch_start_end(tmpdir): class TestCallback(Callback): def on_batch_start(self, *args, **kwargs): @@ -503,24 +503,16 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs): ): trainer.fit(model) - + def test_v1_8_0_deprecated_precplugin_checkpointhooks(): apex_amp = ApexMixedPrecisionPlugin() - with pytest.deprecated_call( - match="is deprecated in v1.6 and will be removed in v1.8." - ): + with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): apex_amp.on_save_checkpoint({}) - with pytest.deprecated_call( - match="is deprecated in v1.6 and will be removed in v1.8." - ): + with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): apex_amp.on_load_checkpoint({}) native_amp = NativeMixedPrecisionPlugin(1, "a") - with pytest.deprecated_call( - match="is deprecated in v1.6 and will be removed in v1.8." - ): + with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): native_amp.on_save_checkpoint({}) - with pytest.deprecated_call( - match="is deprecated in v1.6 and will be removed in v1.8." 
- ): + with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): native_amp.on_load_checkpoint({}) From 95082cfe262caee34af946de2c338a619d2dc817 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 9 Feb 2022 12:37:20 -0800 Subject: [PATCH 14/18] update --- tests/deprecated_api/test_remove_1-8.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index b0eb903a85657..c53282c263334 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -504,13 +504,16 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs): trainer.fit(model) -def test_v1_8_0_deprecated_precplugin_checkpointhooks(): +@RunIf(amp_apex=True) +def test_v1_8_0_deprecated_apexamp_checkpointhooks(): apex_amp = ApexMixedPrecisionPlugin() with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): apex_amp.on_save_checkpoint({}) with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): apex_amp.on_load_checkpoint({}) + +def test_v1_8_0_deprecated_nativeamp_checkpointhooks(): native_amp = NativeMixedPrecisionPlugin(1, "a") with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): native_amp.on_save_checkpoint({}) From 10859a2d5d0a8786b087353c1a67a80ee1e39dd3 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 9 Feb 2022 16:09:34 -0800 Subject: [PATCH 15/18] docstring warning for now --- .../plugins/precision/apex_amp.py | 12 ++++-------- .../plugins/precision/native_amp.py | 12 ++++-------- tests/deprecated_api/test_remove_1-8.py | 19 ------------------- 3 files changed, 8 insertions(+), 35 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index cdef0d735dcb5..fdaf06f48e77f 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -101,13 +101,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: amp.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - rank_zero_deprecation( - "`ApexMixedPrecisionPlugin.on_load_checkpoint` is deprecated in v1.6 and will be removed in v1.8." - " Use `ApexMixedPrecisionPlugin.load_state_dict` instead." - ) + """"``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict`` instead""" def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - rank_zero_deprecation( - "`ApexMixedPrecisionPlugin.on_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8." - " Use `ApexMixedPrecisionPlugin.state_dict` instead." - ) + """"``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. 
+ Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead""" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 9d07d7b5ee510..5fed3357cf810 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -119,13 +119,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.scaler.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - rank_zero_deprecation( - "`NativeMixedPrecisionPlugin.on_load_checkpoint` is deprecated in v1.6 and will be removed in v1.8." - " Use `NativeMixedPrecisionPlugin.load_state_dict` instead." - ) + """"``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict`` instead""" def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - rank_zero_deprecation( - "`NativeMixedPrecisionPlugin.on_save_checkpoint` is deprecated in v1.6 and will be removed in v1.8." - " Use `NativeMixedPrecisionPlugin.state_dict` instead." - ) + """"``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. + Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead""" diff --git a/tests/deprecated_api/test_remove_1-8.py b/tests/deprecated_api/test_remove_1-8.py index c53282c263334..1f7a92d0745c9 100644 --- a/tests/deprecated_api/test_remove_1-8.py +++ b/tests/deprecated_api/test_remove_1-8.py @@ -19,8 +19,6 @@ from torch import optim from pytorch_lightning import Callback, Trainer -from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin -from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin @@ -502,20 +500,3 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs): " and will be removed in v1.8" ): trainer.fit(model) - - -@RunIf(amp_apex=True) -def test_v1_8_0_deprecated_apexamp_checkpointhooks(): - apex_amp = ApexMixedPrecisionPlugin() - with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): - apex_amp.on_save_checkpoint({}) - with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): - apex_amp.on_load_checkpoint({}) - - -def test_v1_8_0_deprecated_nativeamp_checkpointhooks(): - native_amp = NativeMixedPrecisionPlugin(1, "a") - with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): - native_amp.on_save_checkpoint({}) - with pytest.deprecated_call(match="is deprecated in v1.6 and will be removed in v1.8."): - native_amp.on_load_checkpoint({}) From a711bf98d371815b0d93cb44a47b2f23dc446fd6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Feb 2022 00:10:56 +0000 Subject: [PATCH 16/18] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/precision/apex_amp.py | 13 +++++++++---- pytorch_lightning/plugins/precision/native_amp.py | 13 +++++++++---- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git 
a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index fdaf06f48e77f..ff981aed33f9c 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -101,9 +101,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: amp.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """"``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. - Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict`` instead""" + """ "``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + + Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict`` + instead + """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """"``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. - Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead""" + """ "``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. + + Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead + """ diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 5fed3357cf810..250e119b2554c 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -119,9 +119,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.scaler.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """"``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. - Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict`` instead""" + """ "``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + + Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict`` + instead + """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """"``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. - Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead""" + """ "``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. 
+ + Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead + """ From 8c93a43ee3a5d0fa863bc593dde884d41ad7a92b Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 9 Feb 2022 16:13:41 -0800 Subject: [PATCH 17/18] clean imports --- pytorch_lightning/plugins/precision/apex_amp.py | 1 - pytorch_lightning/plugins/precision/native_amp.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index ff981aed33f9c..47171c404a726 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -21,7 +21,6 @@ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation from pytorch_lightning.utilities.types import _PARAMETERS if _APEX_AVAILABLE: diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 250e119b2554c..7ca6a49b985da 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -23,7 +23,6 @@ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation if _TORCH_GREATER_EQUAL_1_10: from torch import autocast as new_autocast From f764b562f6d568c4e6e3c11e6738335a068f9257 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 9 Feb 2022 17:15:44 -0800 Subject: [PATCH 18/18] clean --- pytorch_lightning/plugins/precision/apex_amp.py | 4 ++-- pytorch_lightning/plugins/precision/native_amp.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 47171c404a726..07a7737b6bb5f 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -100,14 +100,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: amp.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ "``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + """``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict`` instead """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ "``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. + """``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. 
Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead """ diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 7ca6a49b985da..ac3a16621416f 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -118,14 +118,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.scaler.load_state_dict(state_dict) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ "``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. + """``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6. Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict`` instead """ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - """ "``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. + """``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6. Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead """
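
With the series applied, a checkpoint stores precision-plugin state under the plugin class ``__qualname__`` (for example ``"NativeMixedPrecisionPlugin"``) instead of the old ``native_amp_scaling_state``/``amp_scaling_state`` keys, while ``restore_precision_plugin_state`` keeps accepting both layouts. A rough standalone sketch of that lookup, mirroring the connector logic rather than reproducing it, with ``"example.ckpt"`` as a placeholder path:

import torch

from pytorch_lightning.plugins.precision import NativeMixedPrecisionPlugin

checkpoint = torch.load("example.ckpt", map_location="cpu")  # placeholder checkpoint path

# constructor arguments follow the (precision, device) signature used in the tests above
plugin = NativeMixedPrecisionPlugin(16, "cuda")
key = plugin.__class__.__qualname__  # new-style key: "NativeMixedPrecisionPlugin"

if key in checkpoint:
    plugin.load_state_dict(checkpoint[key])
elif "native_amp_scaling_state" in checkpoint:
    # pre-1.6 checkpoints kept the GradScaler state under a fixed legacy key
    plugin.load_state_dict(checkpoint["native_amp_scaling_state"])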