
Add should_rank_save_checkpoint property to Training Plugins #7684

Merged
merged 10 commits on May 25, 2021
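Summary of the change: the checkpoint callback previously special-cased TPU training through the `tpu_training_and_local_rank_zero` helper; this PR moves the "which rank writes checkpoints" decision onto the training type plugin and exposes it on the `Trainer` as `should_save_checkpoint`. A minimal sketch of the resulting delegation, using stand-in constructors and attributes rather than the real Lightning classes; only the property names mirror the diff below:

```python
# Simplified sketch of the property chain introduced by this PR.
# The constructors and rank attributes are stand-ins for illustration;
# only `should_save_checkpoint`, `is_global_zero`, and `is_local_zero`
# correspond to the diff.

class TrainingTypePlugin:
    def __init__(self, global_rank: int = 0, local_rank: int = 0) -> None:
        self.global_rank = global_rank
        self.local_rank = local_rank

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0

    @property
    def should_save_checkpoint(self) -> bool:
        """Default: only the globally rank-zero process saves checkpoints."""
        return self.is_global_zero


class TPUSpawnPlugin(TrainingTypePlugin):
    @property
    def is_local_zero(self) -> bool:
        return self.local_rank == 0

    @property
    def should_save_checkpoint(self) -> bool:
        # On TPU spawn, the process with local rank 0 writes the checkpoint.
        return self.is_local_zero


class Trainer:
    def __init__(self, training_type_plugin: TrainingTypePlugin) -> None:
        self.training_type_plugin = training_type_plugin

    @property
    def should_save_checkpoint(self) -> bool:
        # The Trainer simply forwards to the active training type plugin
        # (via the accelerator in the real code).
        return self.training_type_plugin.should_save_checkpoint


# ModelCheckpoint-style call sites now ask the trainer instead of checking
# `trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer)`:
trainer = Trainer(TPUSpawnPlugin(global_rank=3, local_rank=0))
assert trainer.should_save_checkpoint  # local rank 0 saves under TPU spawn
```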
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))


- Added `should_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))


### Changed

- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563))
12 changes: 4 additions & 8 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -38,7 +38,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.xla_device import tpu_training_and_local_rank_zero

log = logging.getLogger(__name__)
warning_cache = WarningCache()
@@ -493,7 +492,7 @@ def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None:
trainer.dev_debugger.track_checkpointing_history(filepath)

# make paths
if trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer):
if trainer.should_save_checkpoint:
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the trainer
@@ -631,7 +630,7 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:

self.dirpath = ckpt_path

if (not trainer.fast_dev_run and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))):
if not trainer.fast_dev_run and trainer.should_save_checkpoint:
self._fs.makedirs(self.dirpath, exist_ok=True)

def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
@@ -694,10 +693,7 @@ def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[

self._save_model(trainer, filepath)

if (
self.last_model_path and self.last_model_path != filepath
and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
):
if self.last_model_path and self.last_model_path != filepath and trainer.should_save_checkpoint:
self._del_model(self.last_model_path)

self.last_model_path = filepath
@@ -724,7 +720,7 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate

if (
self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
and trainer.should_save_checkpoint
):
self._del_model(self.best_model_path)

8 changes: 8 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -74,6 +74,10 @@ def world_size(self) -> int:
def root_device(self) -> torch.device:
return xm.xla_device()

@property
def is_local_zero(self) -> bool:
return self.local_rank == 0

@staticmethod
def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
if not isinstance(dataloaders, list):
@@ -308,3 +312,7 @@ def teardown(self) -> None:
# TPU teardown
os.environ.pop("PT_XLA_DEBUG", None)
self.barrier("teardown")

@property
def should_save_checkpoint(self) -> bool:
return self.is_local_zero
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -309,3 +309,8 @@ def teardown(self) -> None:
@classmethod
def register_plugins(cls, plugin_registry):
pass

@property
def should_save_checkpoint(self) -> bool:
"""Returns whether the checkpoint should be saved (rank based)"""
return self.is_global_zero
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -97,6 +97,10 @@ def world_size(self) -> int:
# some training types define a world size
return getattr(self.accelerator.training_type_plugin, "world_size", 1)

@property
def should_save_checkpoint(self) -> bool:
return self.accelerator.training_type_plugin.should_save_checkpoint

@property
def _distrib_type(self) -> DistributedType:
return self.accelerator_connector._distrib_type
7 changes: 0 additions & 7 deletions pytorch_lightning/utilities/xla_device.py
@@ -17,8 +17,6 @@
import traceback
from multiprocessing import Process, Queue

import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import DeviceType
from pytorch_lightning.utilities.imports import _XLA_AVAILABLE

if _XLA_AVAILABLE:
@@ -105,8 +103,3 @@ def tpu_device_exists() -> bool:
if XLADeviceUtils._TPU_AVAILABLE:
os.environ["PL_TPU_AVAILABLE"] = '1'
return XLADeviceUtils._TPU_AVAILABLE


def tpu_training_and_local_rank_zero(trainer: 'pl.Trainer') -> bool:
return trainer._device_type == DeviceType.TPU and \
trainer.training_type_plugin.local_rank == 0
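
A possible follow-on use, not part of this diff: because the rank gate is now a plugin property, a custom training type plugin can override it instead of patching `ModelCheckpoint`. A hypothetical sketch, assuming the 1.3-era plugin API and the property name used in this diff (`should_save_checkpoint`):

```python
# Hypothetical example: a DDP variant in which every rank passes the
# checkpoint gate, e.g. to write to node-local scratch storage.
# `SaveOnAllRanksDDPPlugin` is illustrative and not part of this PR.
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin


class SaveOnAllRanksDDPPlugin(DDPPlugin):

    @property
    def should_save_checkpoint(self) -> bool:
        # Every process saves; checkpoint paths would need to be made
        # rank-unique before enabling something like this.
        return True


# Usage sketch (commented out so the module imports without GPUs):
# trainer = pl.Trainer(gpus=2, accelerator="ddp", plugins=[SaveOnAllRanksDDPPlugin()])
```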