diff --git a/CHANGELOG.md b/CHANGELOG.md
index a564d5b8094e6..bb37f5542a7bb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
 
+- Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
+
+
 ### Changed
 
 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e0d63a1b0c3af..7642ad95d08bf 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -33,12 +33,11 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
-from pytorch_lightning.utilities.xla_device import tpu_training_and_local_rank_zero
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -473,9 +472,8 @@ def save_function(self, value: Optional[Callable]) -> None:
         )
         self._save_function = value
 
-    @rank_zero_only
-    def _del_model(self, filepath: str) -> None:
-        if self._fs.exists(filepath):
+    def _del_model(self, trainer: 'pl.Trainer', filepath: str) -> None:
+        if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
@@ -493,7 +491,7 @@ def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None:
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        if trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer):
+        if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the trainer
@@ -631,7 +629,7 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:
 
         self.dirpath = ckpt_path
 
-        if (not trainer.fast_dev_run and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))):
+        if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
@@ -694,11 +692,8 @@ def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[
 
         self._save_model(trainer, filepath)
 
-        if (
-            self.last_model_path and self.last_model_path != filepath
-            and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
-        ):
-            self._del_model(self.last_model_path)
+        if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint:
+            self._del_model(trainer, self.last_model_path)
 
         self.last_model_path = filepath
 
@@ -724,9 +719,9 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate
 
         if (
             self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
-            and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
+            and trainer.should_rank_save_checkpoint
         ):
-            self._del_model(self.best_model_path)
+            self._del_model(trainer, self.best_model_path)
 
         self.best_model_path = filepath
 
@@ -773,7 +768,7 @@ def _update_best_and_save(
         self._save_model(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
-            self._del_model(del_filepath)
+            self._del_model(trainer, del_filepath)
 
     def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
         """
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 9ac1e757b2b6d..9a27e6230b201 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -308,3 +308,7 @@ def teardown(self) -> None:
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
         self.barrier("teardown")
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.local_rank == 0
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 0a148a01dbb69..8d27fd4ac6a2f 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -309,3 +309,8 @@ def teardown(self) -> None:
     @classmethod
     def register_plugins(cls, plugin_registry):
         pass
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        """Returns whether the checkpoint should be saved (rank based)"""
+        return self.is_global_zero
diff --git a/pytorch_lightning/trainer/properties.py b/pytorch_lightning/trainer/properties.py
index ff12e5c6e9053..e469d1bc12394 100644
--- a/pytorch_lightning/trainer/properties.py
+++ b/pytorch_lightning/trainer/properties.py
@@ -97,6 +97,10 @@ def world_size(self) -> int:
         # some training types define a world size
         return getattr(self.accelerator.training_type_plugin, "world_size", 1)
 
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.accelerator.training_type_plugin.should_rank_save_checkpoint
+
     @property
     def _distrib_type(self) -> DistributedType:
         return self.accelerator_connector._distrib_type
diff --git a/pytorch_lightning/utilities/xla_device.py b/pytorch_lightning/utilities/xla_device.py
index 6fb31a0a824cc..513972e4bd8a3 100644
--- a/pytorch_lightning/utilities/xla_device.py
+++ b/pytorch_lightning/utilities/xla_device.py
@@ -17,8 +17,6 @@
 import traceback
 from multiprocessing import Process, Queue
 
-import pytorch_lightning as pl
-from pytorch_lightning.utilities.enums import DeviceType
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -105,8 +103,3 @@ def tpu_device_exists() -> bool:
         if XLADeviceUtils._TPU_AVAILABLE:
             os.environ["PL_TPU_AVAILABLE"] = '1'
         return XLADeviceUtils._TPU_AVAILABLE
-
-
-def tpu_training_and_local_rank_zero(trainer: 'pl.Trainer') -> bool:
-    return trainer._device_type == DeviceType.TPU and \
-        trainer.training_type_plugin.local_rank == 0
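
Not part of the patch above: a minimal, self-contained sketch of the rank-based save logic the diff introduces. The `TrainingTypePlugin`, `TPUSpawnPlugin`, and `Trainer` classes below are simplified stand-ins for the Lightning classes touched here (the real `Trainer` reaches the plugin through `trainer.accelerator.training_type_plugin`, and the real plugins derive `is_global_zero`/`local_rank` from the cluster environment rather than constructor arguments); only the `should_rank_save_checkpoint` delegation is reproduced.

```python
# Simplified stand-ins, assuming plugins expose is_global_zero / local_rank as plain attributes.


class TrainingTypePlugin:
    """Base behaviour from the patch: only the global-zero rank saves checkpoints."""

    def __init__(self, is_global_zero: bool = True, local_rank: int = 0):
        self.is_global_zero = is_global_zero
        self.local_rank = local_rank

    @property
    def should_rank_save_checkpoint(self) -> bool:
        """Returns whether the checkpoint should be saved (rank based)."""
        return self.is_global_zero


class TPUSpawnPlugin(TrainingTypePlugin):
    """TPU override from the patch: each host's local rank 0 saves."""

    @property
    def should_rank_save_checkpoint(self) -> bool:
        return self.local_rank == 0


class Trainer:
    """The trainer property simply delegates to the active training type plugin."""

    def __init__(self, plugin: TrainingTypePlugin):
        self.training_type_plugin = plugin

    @property
    def should_rank_save_checkpoint(self) -> bool:
        return self.training_type_plugin.should_rank_save_checkpoint


if __name__ == "__main__":
    # On TPU, a process that is not global rank 0 still saves if its local rank is 0.
    tpu_trainer = Trainer(TPUSpawnPlugin(is_global_zero=False, local_rank=0))
    ddp_trainer = Trainer(TrainingTypePlugin(is_global_zero=False, local_rank=0))
    print(tpu_trainer.should_rank_save_checkpoint)  # True
    print(ddp_trainer.should_rank_save_checkpoint)  # False
```

The plugin override reproduces the condition of the removed `tpu_training_and_local_rank_zero` helper, so `ModelCheckpoint` can consult a single `trainer.should_rank_save_checkpoint` property instead of special-casing TPU at every call site.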