diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf20998955453..21840c2d3500a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -108,6 +108,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a `_Stateful` support for `LightningDataModule` ([#11637](https://github.com/PyTorchLightning/pytorch-lightning/pull/11637))
 
 
+- Added `_Stateful` support for `PrecisionPlugin` ([#11638](https://github.com/PyTorchLightning/pytorch-lightning/pull/11638))
+
+
 - Added `Accelerator.is_available` to check device availability ([#11797](https://github.com/PyTorchLightning/pytorch-lightning/pull/11797))
 
 
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index 9e8b35289bce5..07a7737b6bb5f 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -93,9 +93,21 @@ def optimizer_step(
             return optimizer.step(**kwargs)
         return closure_result
 
+    def state_dict(self) -> Dict[str, Any]:
+        return amp.state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        amp.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if "amp_scaling_state" in checkpoint:
-            amp.load_state_dict(checkpoint["amp_scaling_state"])
+        """``ApexMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.load_state_dict``
+        instead
+        """
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        checkpoint["amp_scaling_state"] = amp.state_dict()
+        """``ApexMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save ApexMixedPrecisionPlugin state with ``ApexMixedPrecisionPlugin.state_dict`` instead
+        """
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index 2c9174e5517cb..ac3a16621416f 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -108,10 +108,24 @@ def forward_context(self) -> Generator[None, None, None]:
         with self.autocast_context_manager():
             yield
 
+    def state_dict(self) -> Dict[str, Any]:
+        if self.scaler is not None:
+            return self.scaler.state_dict()
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        if self.scaler is not None:
+            self.scaler.load_state_dict(state_dict)
+
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
-            self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])
+        """``NativeMixedPrecisionPlugin.on_load_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-restore NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.load_state_dict``
+        instead
+        """
 
     def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
-        if self.scaler is not None:
-            checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
+        """``NativeMixedPrecisionPlugin.on_save_checkpoint`` is deprecated in v1.6.
+
+        Lightning will auto-save NativeMixedPrecisionPlugin state with ``NativeMixedPrecisionPlugin.state_dict`` instead
+        """
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 6bdfe12dc2a2d..c876dd5c909d2 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from functools import partial
-from typing import Any, Callable, Generator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from torch import Tensor
@@ -242,3 +242,20 @@ def teardown(self) -> None:
 
         It is the right place to release memory and free other resources.
         """
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate precision plugin state_dict.
+
+        Returns:
+            A dictionary containing precision plugin state.
+        """
+        return {}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin
+        state_dict.
+
+        Args:
+            state_dict: the precision plugin state returned by ``state_dict``.
+        """
+        pass
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3560678baaa02..e28608c5a7fbf 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -23,6 +23,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -197,7 +198,7 @@ def restore_training_state(self) -> None:
             return
 
         # restore precision plugin (scaler etc.)
-        self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        self.restore_precision_plugin_state()
 
         # restore loops and their progress
         self.restore_loops()
@@ -207,6 +208,21 @@ def restore_training_state(self) -> None:
         # restore optimizers and schedulers state
         self.restore_optimizers_and_schedulers()
 
+    def restore_precision_plugin_state(self) -> None:
+        """Restore the precision plugin state from the pre-loaded checkpoint."""
+        prec_plugin = self.trainer.precision_plugin
+        prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
+        if prec_plugin.__class__.__qualname__ in self._loaded_checkpoint:
+            prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__qualname__])
+
+        # old checkpoints compatibility
+        if "amp_scaling_state" in self._loaded_checkpoint and isinstance(prec_plugin, ApexMixedPrecisionPlugin):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["amp_scaling_state"])
+        if "native_amp_scaling_state" in self._loaded_checkpoint and isinstance(
+            prec_plugin, NativeMixedPrecisionPlugin
+        ):
+            prec_plugin.load_state_dict(self._loaded_checkpoint["native_amp_scaling_state"])
+
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
         if not self._loaded_checkpoint:
@@ -319,9 +335,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 'callbacks':             "callback specific state"[] # if not weights_only
                 'optimizer_states':      "PT optim's state_dict"[]   # if not weights_only
                 'lr_schedulers':         "PT sched's state_dict"[]   # if not weights_only
-                'native_amp_scaling_state': PT amp's state_dict      # if not weights_only and use native amp
-                'amp_scaling_state':     Apex's state_dict           # if not weights_only and use apex amp
                 'state_dict':            Model's state_dict (e.g. network weights)
+                precision_plugin.__class__.__qualname__:  precision plugin state_dict # if not weights_only
                 CHECKPOINT_HYPER_PARAMS_NAME:
                 CHECKPOINT_HYPER_PARAMS_KEY:
                 CHECKPOINT_HYPER_PARAMS_TYPE:
@@ -367,7 +382,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
                 lr_schedulers.append(config.scheduler.state_dict())
             checkpoint["lr_schedulers"] = lr_schedulers
 
-            self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
+            # precision plugin
+            prec_plugin = self.trainer.precision_plugin
+            prec_plugin_state_dict = prec_plugin.state_dict()
+            if prec_plugin_state_dict:
+                checkpoint[prec_plugin.__class__.__qualname__] = prec_plugin_state_dict
+            prec_plugin.on_save_checkpoint(checkpoint)
 
         # dump hyper-parameters
         if model.hparams:
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 00ccaa3ec7c6c..c7233ea362bcb 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -491,10 +491,8 @@ def training_step(self, batch, batch_idx):
         "state_dict": ANY,
         "loops": ANY,
     }
-    if kwargs.get("amp_backend") == "native":
-        saved_ckpt["native_amp_scaling_state"] = ANY
-    elif kwargs.get("amp_backend") == "apex":
-        saved_ckpt["amp_scaling_state"] = ANY
+    if kwargs.get("amp_backend") == "native" or kwargs.get("amp_backend") == "apex":
+        saved_ckpt[trainer.precision_plugin.__class__.__qualname__] = ANY
     device = torch.device("cuda:0" if "gpus" in kwargs else "cpu")
     expected = [
         dict(name="Callback.on_init_start", args=(trainer,)),
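
Note (not part of the diff above): a minimal sketch of how a custom precision plugin could opt into the new checkpointing protocol. The class name ``MyPrecisionPlugin`` and its ``loss_scale`` attribute are hypothetical, invented for illustration; what the sketch relies on from this diff is that ``dump_checkpoint`` stores whatever ``state_dict()`` returns under the plugin's ``__class__.__qualname__`` key, and ``restore_precision_plugin_state`` passes that dict back through ``load_state_dict()`` when resuming from a checkpoint.

from typing import Any, Dict

from pytorch_lightning.plugins.precision import PrecisionPlugin


class MyPrecisionPlugin(PrecisionPlugin):  # hypothetical user plugin, for illustration only
    def __init__(self) -> None:
        super().__init__()
        self.loss_scale = 1.0  # invented piece of state worth checkpointing

    def state_dict(self) -> Dict[str, Any]:
        # saved by the checkpoint connector under ``MyPrecisionPlugin.__qualname__`` (only if non-empty)
        return {"loss_scale": self.loss_scale}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # called on restore with the dict previously returned by ``state_dict``
        self.loss_scale = state_dict["loss_scale"]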