From 005b0f45a3721e05100f41349e01383ffb0f0942 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Fri, 13 Aug 2021 17:35:31 +0100 Subject: [PATCH] Introduce CheckpointIO Plugin (#8743) --- CHANGELOG.md | 3 + docs/source/advanced/checkpoint_io.rst | 52 +++++++++++ docs/source/api_references.rst | 12 +++ docs/source/index.rst | 1 + pytorch_lightning/plugins/__init__.py | 4 + pytorch_lightning/plugins/io/__init__.py | 15 ++++ .../plugins/io/checkpoint_plugin.py | 57 ++++++++++++ pytorch_lightning/plugins/io/torch_plugin.py | 55 ++++++++++++ .../plugins/training_type/ddp.py | 10 ++- .../plugins/training_type/ddp_spawn.py | 10 ++- .../plugins/training_type/deepspeed.py | 15 +++- pytorch_lightning/plugins/training_type/dp.py | 9 +- .../plugins/training_type/fully_sharded.py | 19 ++-- .../plugins/training_type/horovod.py | 9 +- .../plugins/training_type/ipu.py | 8 +- .../plugins/training_type/parallel.py | 4 +- .../plugins/training_type/single_device.py | 9 +- .../plugins/training_type/single_tpu.py | 18 +++- .../plugins/training_type/tpu_spawn.py | 11 ++- .../training_type/training_type_plugin.py | 30 ++++--- .../connectors/accelerator_connector.py | 24 ++++-- pytorch_lightning/utilities/types.py | 2 + tests/accelerators/test_cpu.py | 3 +- tests/plugins/test_checkpoint_io_plugin.py | 86 +++++++++++++++++++ 24 files changed, 420 insertions(+), 46 deletions(-) create mode 100644 docs/source/advanced/checkpoint_io.rst create mode 100644 pytorch_lightning/plugins/io/__init__.py create mode 100644 pytorch_lightning/plugins/io/checkpoint_plugin.py create mode 100644 pytorch_lightning/plugins/io/torch_plugin.py create mode 100644 tests/plugins/test_checkpoint_io_plugin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 384fb6a20e1a0b..43a0fc8cdb9029 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366)) +- Added `CheckpointIO` to expose checkpoint IO from training type plugin ([#8743](https://github.com/PyTorchLightning/pytorch-lightning/pull/8743)) + + ### Changed - Parsing of the `gpus` Trainer argument has changed: `gpus="n"` (str) no longer selects the GPU index n and instead selects the first n devices. ([#8770](https://github.com/PyTorchLightning/pytorch-lightning/pull/8770)) diff --git a/docs/source/advanced/checkpoint_io.rst b/docs/source/advanced/checkpoint_io.rst new file mode 100644 index 00000000000000..6eabfae99b07b1 --- /dev/null +++ b/docs/source/advanced/checkpoint_io.rst @@ -0,0 +1,52 @@ +Custom Checkpointing IO +======================= + +.. warning:: The Checkpoint IO API is experimental and subject to change. + +Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic +that is managed by the ``TrainingTypePlugin``. + +``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a `Trainer`` object or a``TrainingTypePlugin`` as shown below. + +.. 
code-block:: python + + from pathlib import Path + from typing import Any, Dict, Optional, Union + + from pytorch_lightning import Trainer + from pytorch_lightning.callbacks import ModelCheckpoint + from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin + + + class CustomCheckpointIO(CheckpointIO): + def save_checkpoint( + self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None + ) -> None: + ... + + def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]: + ... + + + custom_checkpoint_io = CustomCheckpointIO() + + # Pass into the Trainer object + model = MyModel() + trainer = Trainer( + plugins=[custom_checkpoint_io], + callbacks=ModelCheckpoint(save_last=True), + ) + trainer.fit(model) + + # pass into TrainingTypePlugin + model = MyModel() + device = torch.device("cpu") + trainer = Trainer( + plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io), + callbacks=ModelCheckpoint(save_last=True), + ) + trainer.fit(model) + +.. note:: + + Some ``TrainingTypePlugins`` do not support custom ``CheckpointIO`` as as checkpointing logic is not modifiable. diff --git a/docs/source/api_references.rst b/docs/source/api_references.rst index 919aece8a72d3c..49b5556d7a922e 100644 --- a/docs/source/api_references.rst +++ b/docs/source/api_references.rst @@ -127,6 +127,18 @@ Cluster Environments KubeflowEnvironment SLURMEnvironment +Checkpoint IO Plugins +^^^^^^^^^^^^^^^^^^^^^ + +.. currentmodule:: pytorch_lightning.plugins.io + +.. autosummary:: + :toctree: api + :nosignatures: + :template: classtemplate.rst + + CheckpointIO + TorchCheckpointIO Profiler API ------------ diff --git a/docs/source/index.rst b/docs/source/index.rst index 5e8e4352b57d8e..105a1dfff54bf5 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -55,6 +55,7 @@ PyTorch Lightning Documentation advanced/multi_gpu advanced/advanced_gpu common/weights_loading + advanced/checkpoint_io common/optimizers advanced/profiler advanced/sequences diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index d6434e84adae79..a69065fa74f733 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,4 +1,6 @@ from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.plugins_registry import ( # noqa: F401 call_training_type_register_plugins, TrainingTypePluginsRegistry, @@ -29,6 +31,8 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin __all__ = [ + "CheckpointIO", + "TorchCheckpointIO", "ApexMixedPrecisionPlugin", "DataParallelPlugin", "DDP2Plugin", diff --git a/pytorch_lightning/plugins/io/__init__.py b/pytorch_lightning/plugins/io/__init__.py new file mode 100644 index 00000000000000..232f582c1a5208 --- /dev/null +++ b/pytorch_lightning/plugins/io/__init__.py @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO # noqa: F401 +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO # noqa: F401 diff --git a/pytorch_lightning/plugins/io/checkpoint_plugin.py b/pytorch_lightning/plugins/io/checkpoint_plugin.py new file mode 100644 index 00000000000000..575399af48df34 --- /dev/null +++ b/pytorch_lightning/plugins/io/checkpoint_plugin.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from pytorch_lightning.utilities.types import _PATH + + +class CheckpointIO(ABC): + """ + Interface to save/load checkpoints as they are saved through the ``TrainingTypePlugin``. + + Typically most plugins either use the Torch based IO Plugin; ``TorchCheckpointIO`` but may + require particular handling depending on the plugin. + + In addition, you can pass a custom ``CheckpointIO`` by extending this class and passing it + to the Trainer, i.e ``Trainer(plugins=[MyCustomCheckpointIO()])``. + + .. note:: + + For some plugins, it is not possible to use a custom checkpoint plugin as checkpointing logic is not + modifiable. + + """ + + @abstractmethod + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + """Save model/training states as a checkpoint file through state-dump and file-write. + + Args: + checkpoint: dict containing model and trainer state + path: write-target path + storage_options: Optional parameters when saving the model/training states. + """ + + @abstractmethod + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + """ + Load checkpoint from a path when resuming or loading ckpt for test/validate/predict stages. + + Args: + path: Path to checkpoint + storage_options: Optional parameters when loading the model/training states. + + Returns: The loaded checkpoint. + """ diff --git a/pytorch_lightning/plugins/io/torch_plugin.py b/pytorch_lightning/plugins/io/torch_plugin.py new file mode 100644 index 00000000000000..e95f3d3b226f7b --- /dev/null +++ b/pytorch_lightning/plugins/io/torch_plugin.py @@ -0,0 +1,55 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional + +import pytorch_lightning as pl +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.cloud_io import atomic_save +from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities.types import _PATH + + +class TorchCheckpointIO(CheckpointIO): + """ + CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` + to save and load checkpoints respectively, common for most use cases. + """ + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + try: + # write the checkpoint dictionary on the file + atomic_save(checkpoint, path) + except AttributeError as err: + # todo (sean): is this try catch necessary still? + # https://github.com/PyTorchLightning/pytorch-lightning/pull/431 + key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY + checkpoint.pop(key, None) + rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") + atomic_save(checkpoint, path) + + def load_checkpoint( + self, path: _PATH, map_location: Optional[Callable] = lambda storage, loc: storage + ) -> Dict[str, Any]: + """ + Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. + + Args: + path: Path to checkpoint + map_location: a function, :class:`torch.device`, string or a dict specifying how to remap storage + locations. + + Returns: The loaded checkpoint. 
+ """ + return pl_load(path, map_location=map_location) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 003d567b35fc05..8348a565fa486b 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -32,6 +32,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import ( _HYDRA_AVAILABLE, @@ -75,14 +76,19 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + ) self.interactive_ddp_procs = [] if num_nodes is not None: rank_zero_deprecation( diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index c39d35e8d1b6a1..759743ad4a8257 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -25,6 +25,7 @@ from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import ( @@ -62,14 +63,19 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, num_nodes: Optional[int] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, sync_batchnorm: Optional[bool] = None, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[callable] = None, ddp_comm_wrapper: Optional[callable] = None, **kwargs: Any, ): - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + ) if num_nodes is not None: rank_zero_deprecation( "Argument `num_nodes` in `DDPSpawnPlugin` is deprecated in v1.4, and will be removed in v1.6. 
" diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 8b518de6dcdf44..940fe6cf4032e0 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -27,6 +27,7 @@ from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config from pytorch_lightning.trainer.states import TrainerFn @@ -274,8 +275,11 @@ def __init__( pin_memory = cpu_offload_use_pin_memory super().__init__( - parallel_devices=parallel_devices, num_nodes=num_nodes, cluster_environment=cluster_environment + parallel_devices=parallel_devices, + num_nodes=num_nodes, + cluster_environment=cluster_environment, ) + self.config = self._load_config(config) if self.config is None: # User has not overridden config, set defaults @@ -679,6 +683,7 @@ def save_checkpoint(self, checkpoint: Dict, filepath: str) -> None: filepath: write-target file's path """ if self.zero_stage_3 and self._multi_device and self.is_global_zero: + # todo (sean): Add link to docs once docs are merged. warning_cache.warn( "When saving the DeepSpeed Stage 3 checkpoint, " "each worker will save a shard of the checkpoint within a directory. " @@ -818,3 +823,11 @@ def register_plugins(cls, plugin_registry: Dict) -> None: offload_params_device="nvme", offload_optimizer_device="nvme", ) + + @property + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: + raise MisconfigurationException("DeepSpeed currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py index beedac2942ac6f..551324416cce93 100644 --- a/pytorch_lightning/plugins/training_type/dp.py +++ b/pytorch_lightning/plugins/training_type/dp.py @@ -17,6 +17,7 @@ from torch.nn import DataParallel from pytorch_lightning.overrides.data_parallel import LightningParallelModule +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.model_helpers import is_overridden @@ -29,8 +30,12 @@ class DataParallelPlugin(ParallelPlugin): device and each gets a split of the data. 
""" - def __init__(self, parallel_devices: Optional[List[torch.device]]): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + def __init__( + self, + parallel_devices: Optional[List[torch.device]], + checkpoint_io: Optional[CheckpointIO] = None, + ): + super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) @property def global_rank(self) -> int: diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 1acac25e96db4f..29c74439dd5eeb 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Dict, Generator, List, Optional import torch -from torch import Tensor from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -40,7 +40,8 @@ def __init__( min_num_params: int = 1e8, state_dict_to_cpu: bool = True, parallel_devices: Optional[List[torch.device]] = None, - cluster_environment: ClusterEnvironment = None, + cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): """ Plugin for Fully Sharded Data Parallel provided by FairScale. @@ -89,7 +90,11 @@ def __init__( (Defautl: True). """ - super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + ) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu self.flatten_parameters = flatten_parameters @@ -169,12 +174,6 @@ def model_to_device(self) -> None: # ensure we update the device type in the lightning module self.lightning_module.to(self.root_device) - def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: - # Currently it is same as default TrainingTypePlugin, i.e. return - # the full state dict for FSDP, in the future, we will provide sharded - # state dict. 
- return super().lightning_module_state_dict() - @property def setup_optimizers_in_pre_dispatch(self) -> bool: # Setup optimizers after the Fully Sharded Model has been made diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 34fe429d893620..e5eb8bf9723ea3 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HOROVOD_AVAILABLE from pytorch_lightning.utilities.distributed import distributed_available @@ -33,8 +34,12 @@ class HorovodPlugin(ParallelPlugin): """Plugin for Horovod distributed training integration.""" - def __init__(self, parallel_devices: Optional[List[torch.device]] = None): - super().__init__(parallel_devices=parallel_devices, cluster_environment=None) + def __init__( + self, + parallel_devices: Optional[List[torch.device]] = None, + checkpoint_io: Optional[CheckpointIO] = None, + ): + super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io) rank_zero_only.rank = self.global_rank @property diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 82a0d3db68e645..4e711ddb406eb2 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -23,6 +23,7 @@ from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.trainer.supporters import CombinedLoader @@ -67,6 +68,7 @@ def __init__( autoreport_dir: Optional[str] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, training_opts: Optional["poptorch.Options"] = None, inference_opts: Optional["poptorch.Options"] = None, ) -> None: @@ -83,7 +85,11 @@ def __init__( inference_opts: Optional ``poptorch.Options`` to override the default created options for validation/testing and predicting. """ - super().__init__(parallel_devices, cluster_environment) + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + checkpoint_io=checkpoint_io, + ) if not _POPTORCH_AVAILABLE or not poptorch.ipuHardwareIsAvailable(): raise MisconfigurationException( "The IPU Accelerator requires IPU devices to run. 
" diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index 4186d697f21ac8..71aae1bb71a918 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -22,6 +22,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp @@ -34,8 +35,9 @@ def __init__( self, parallel_devices: Optional[List[torch.device]] = None, cluster_environment: Optional[ClusterEnvironment] = None, + checkpoint_io: Optional[CheckpointIO] = None, ): - super().__init__() + super().__init__(checkpoint_io) self.parallel_devices = parallel_devices self.cluster_environment = cluster_environment diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 5399cffe19f682..c92fead861c192 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -15,6 +15,7 @@ import torch +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin from pytorch_lightning.utilities import _XLA_AVAILABLE @@ -22,8 +23,12 @@ class SingleDevicePlugin(TrainingTypePlugin): """Plugin that handles communication on a single device.""" - def __init__(self, device: torch.device): - super().__init__() + def __init__( + self, + device: torch.device, + checkpoint_io: Optional[CheckpointIO] = None, + ): + super().__init__(checkpoint_io) self.device: torch.device = device self.global_rank = 0 self.local_rank = 0 diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index d83bd7ed8ba171..b6f7d4000da944 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -15,9 +15,11 @@ from typing import Any, Dict from pytorch_lightning.core.decorators import parameter_validation +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.exceptions import MisconfigurationException if _TPU_AVAILABLE: import torch_xla.core.xla_model as xm @@ -29,10 +31,14 @@ class SingleTPUPlugin(SingleDevicePlugin): """Plugin for training on a single TPU device.""" - def __init__(self, device: int, debug: bool = False): + def __init__( + self, + device: int, + debug: bool = False, + ): device = xm.xla_device(device) - super().__init__(device) + super().__init__(device=device) self.debug = debug self.tpu_local_core_rank = 0 @@ -74,3 +80,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: def teardown(self) -> None: # TPU teardown os.environ.pop("PT_XLA_DEBUG", None) + + @property + def checkpoint_io(self) -> CheckpointIO: + 
return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: + raise MisconfigurationException("TPU Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index faec805773cb79..ee4cd9934d6506 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -25,6 +25,7 @@ import pytorch_lightning as pl from pytorch_lightning.core.decorators import parameter_validation from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader from pytorch_lightning.trainer.states import TrainerFn @@ -53,7 +54,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): """Plugin for training multiple TPU devices using the :func:`torch.multiprocessing.spawn` method.""" def __init__(self, parallel_devices: Optional[List[int]] = None, debug: bool = False, **_: Any) -> None: - super().__init__(parallel_devices) + super().__init__(parallel_devices=parallel_devices) self.debug = debug self.tpu_local_core_rank = 0 self.tpu_global_core_rank = 0 @@ -345,3 +346,11 @@ def should_rank_save_checkpoint(self) -> bool: @classmethod def register_plugins(cls, plugin_registry: Dict) -> None: plugin_registry.register("tpu_spawn_debug", cls, description="TPUSpawn Plugin with `debug` as True", debug=True) + + @property + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: + raise MisconfigurationException("TPU Spawn Plugin currently does not support custom checkpoint plugins.") diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cdff37fd9bcb2f..09363c3dc826b6 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -24,10 +24,9 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.base_plugin import Plugin -from pytorch_lightning.utilities import rank_zero_warn -from pytorch_lightning.utilities.cloud_io import atomic_save -from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT TBroadcast = TypeVar("T") @@ -38,11 +37,21 @@ class TrainingTypePlugin(Plugin, ABC): Base class for all training type plugins that change the behaviour of the training, validation and test-loop. 
""" - def __init__(self) -> None: + def __init__(self, checkpoint_io: Optional[CheckpointIO] = None) -> None: self._model: Optional[Module] = None self._results: Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]] = None + checkpoint_io = checkpoint_io if checkpoint_io is not None else TorchCheckpointIO() + self._checkpoint_io = checkpoint_io self._call_configure_sharded_model_hook = True + @property + def checkpoint_io(self) -> CheckpointIO: + return self._checkpoint_io + + @checkpoint_io.setter + def checkpoint_io(self, plugin: CheckpointIO) -> None: + self._checkpoint_io = plugin + def connect(self, model: Module) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" self.model = model @@ -146,7 +155,7 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: return self._results def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: - return pl_load(checkpoint_path, map_location=(lambda storage, loc: storage)) + return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) @@ -282,15 +291,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """ # dump states as a checkpoint dictionary object checkpoint = self.on_save(checkpoint) - if self.is_global_zero: - try: - # write the checkpoint dictionary on the file - atomic_save(checkpoint, filepath) - except AttributeError as err: - key = pl.LightningModule.CHECKPOINT_HYPER_PARAMS_KEY - checkpoint.pop(key, None) - rank_zero_warn(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}") - atomic_save(checkpoint, filepath) + if self.should_rank_save_checkpoint: + return self.checkpoint_io.save_checkpoint(checkpoint, filepath) @contextlib.contextmanager def model_sharded_context(self) -> Generator: diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 999e5eecf3a6d9..ec3b56489e32a9 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -26,6 +26,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator from pytorch_lightning.plugins import ( ApexMixedPrecisionPlugin, + CheckpointIO, DataParallelPlugin, DDP2Plugin, DDPFullyShardedPlugin, @@ -134,6 +135,7 @@ def __init__( self._precision_plugin: Optional[PrecisionPlugin] = None self._training_type_plugin: Optional[TrainingTypePlugin] = None self._cluster_environment: Optional[ClusterEnvironment] = None + self._checkpoint_io: Optional[CheckpointIO] = None plugins = plugins if plugins is not None else [] @@ -274,6 +276,7 @@ def _set_devices_if_none(self) -> None: def handle_given_plugins(self) -> None: training_type = None + checkpoint = None precision = None cluster_environment = None @@ -299,18 +302,25 @@ def handle_given_plugins(self) -> None: else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." - f" Found more than 1 training type plugin: {type(plug).__name__}" + "You can only specify one training type plugin." + f" Available: {type(training_type).__name__}, given: {type(plug).__name__}" ) elif isinstance(plug, PrecisionPlugin): if precision is None: precision = plug else: raise MisconfigurationException( - "You can only specify one precision and one training type plugin." 
- f" Found more than 1 precision plugin: {type(plug).__name__}" + "You can only specify one precision plugin." + f" Available: {type(precision).__name__}, given: {type(plug).__name__}" + ) + elif isinstance(plug, CheckpointIO): + if checkpoint is None: + checkpoint = plug + else: + raise MisconfigurationException( + "You can only specify one checkpoint plugin." + f" Available: {type(checkpoint).__name__}, given: {type(plug).__name__}" ) - elif isinstance(plug, ClusterEnvironment): if cluster_environment is None: cluster_environment = plug @@ -325,6 +335,7 @@ def handle_given_plugins(self) -> None: self._training_type_plugin = training_type self._precision_plugin = precision + self._checkpoint_io = checkpoint self._cluster_environment = cluster_environment or self.select_cluster_environment() @property @@ -341,6 +352,9 @@ def training_type_plugin(self) -> TrainingTypePlugin: if self._training_type_plugin is None: self._training_type_plugin = self.select_training_type_plugin() self._training_type_plugin = self.resolve_training_type_plugin(self._training_type_plugin) + # attach checkpoint plugin to the training type plugin + if self._checkpoint_io is not None: + self._training_type_plugin.checkpoint_io = self._checkpoint_io self._training_type_plugin_resolved = True return self._training_type_plugin diff --git a/pytorch_lightning/utilities/types.py b/pytorch_lightning/utilities/types.py index 774f44ceaeeec7..69cac5edf784e8 100644 --- a/pytorch_lightning/utilities/types.py +++ b/pytorch_lightning/utilities/types.py @@ -17,6 +17,7 @@ - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`) """ from numbers import Number +from pathlib import Path from typing import Any, Dict, Iterator, List, Mapping, Sequence, Type, Union import torch @@ -31,6 +32,7 @@ _EVALUATE_OUTPUT = List[Dict[str, float]] # 1 dict per DataLoader _PREDICT_OUTPUT = Union[List[Any], List[List[Any]]] _PARAMETERS = Iterator[torch.nn.Parameter] +_PATH = Union[str, Path] TRAIN_DATALOADERS = Union[ DataLoader, Sequence[DataLoader], diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 1a584ed4447584..a4a4d6d62d5d3f 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -9,6 +9,7 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin +from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -201,7 +202,7 @@ def load_checkpoint_file(self, checkpoint_path: Union[str, Path]) -> Dict[str, A checkpoint_path = os.path.join(tmpdir, "model.pt") trainer.save_checkpoint(checkpoint_path) - plugin = TestPlugin(torch.device("cpu")) + plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO()) accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin()) assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch diff --git a/tests/plugins/test_checkpoint_io_plugin.py b/tests/plugins/test_checkpoint_io_plugin.py new file mode 100644 index 00000000000000..ef43b8b14b1468 --- /dev/null +++ b/tests/plugins/test_checkpoint_io_plugin.py @@ -0,0 +1,86 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional +from unittest.mock import MagicMock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import CheckpointIO, DeepSpeedPlugin, SingleDevicePlugin, TPUSpawnPlugin +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.types import _PATH +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + + +class CustomCheckpointIO(CheckpointIO): + save_checkpoint_called: bool = False + load_checkpoint_file_called: bool = False + + def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None: + self.save_checkpoint_called = True + torch.save(checkpoint, path) + + def load_checkpoint(self, path: _PATH, storage_options: Optional[Any] = None) -> Dict[str, Any]: + self.load_checkpoint_file_called = True + return torch.load(path) + + +def test_checkpoint_plugin_called(tmpdir): + """ + Ensure that the custom checkpoint IO plugin and torch checkpoint IO plugin is called when saving/loading. + """ + checkpoint_plugin = CustomCheckpointIO() + checkpoint_plugin = MagicMock(wraps=checkpoint_plugin, spec=CustomCheckpointIO) + + ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) + + model = BoringModel() + device = torch.device("cpu") + trainer = Trainer( + default_root_dir=tmpdir, + plugins=SingleDevicePlugin(device, checkpoint_io=checkpoint_plugin), + callbacks=ck, + max_epochs=1, + ) + trainer.fit(model) + assert checkpoint_plugin.save_checkpoint.call_count == 3 + trainer.test(model, ckpt_path=ck.last_model_path) + checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt") + + checkpoint_plugin.reset_mock() + ck = ModelCheckpoint(dirpath=tmpdir, save_last=True) + + model = BoringModel() + device = torch.device("cpu") + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[SingleDevicePlugin(device), checkpoint_plugin], + callbacks=ck, + max_epochs=1, + ) + trainer.fit(model) + assert checkpoint_plugin.save_checkpoint.call_count == 3 + + trainer.test(model, ckpt_path=ck.last_model_path) + checkpoint_plugin.load_checkpoint.assert_called_once() + checkpoint_plugin.load_checkpoint.assert_called_with(tmpdir / "last.ckpt") + + +@pytest.mark.parametrize("plugin_cls", [pytest.param(DeepSpeedPlugin, marks=RunIf(deepspeed=True)), TPUSpawnPlugin]) +def test_no_checkpoint_io_plugin_support(plugin_cls): + with pytest.raises(MisconfigurationException, match="currently does not support custom checkpoint plugins"): + plugin_cls().checkpoint_io = CustomCheckpointIO()