
Add should_rank_save_checkpoint property to Training Plugins #7684

Merged
merged 10 commits on May 25, 2021
Address comments
kaushikb11 committed May 24, 2021
commit 696b7254a77b49adb1553fba388a13493b492b2d
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -41,7 +41,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
 
 
-- Added `should_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
+- Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
 
 
 ### Changed
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -492,7 +492,7 @@ def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None:
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        if trainer.should_save_checkpoint:
+        if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the trainer
@@ -630,7 +630,7 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:
 
         self.dirpath = ckpt_path
 
-        if not trainer.fast_dev_run and trainer.should_save_checkpoint:
+        if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
@@ -693,7 +693,7 @@ def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[
 
         self._save_model(trainer, filepath)
 
-        if self.last_model_path and self.last_model_path != filepath and trainer.should_save_checkpoint:
+        if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint:
             self._del_model(self.last_model_path)
 
         self.last_model_path = filepath
@@ -720,7 +720,7 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate
 
         if (
             self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
-            and trainer.should_save_checkpoint
+            and trainer.should_rank_save_checkpoint
         ):
             self._del_model(self.best_model_path)
 
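The callback change boils down to gating every filesystem side effect on the rank that the active training type plugin designates. A minimal sketch of that pattern, outside the PR's actual code (the `make_checkpoint_dir` helper and the duck-typed `trainer` argument are illustrative only):

```python
import os
from typing import Any


def make_checkpoint_dir(trainer: Any, filepath: str) -> None:
    """Create the checkpoint directory only on the rank that the
    training type plugin designates for checkpoint writes."""
    # `should_rank_save_checkpoint` proxies to the active training type plugin,
    # so e.g. TPU spawn can pick local rank 0 while DDP keeps global rank 0.
    if trainer.should_rank_save_checkpoint:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
```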
8 changes: 2 additions & 6 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -74,10 +74,6 @@ def world_size(self) -> int:
     def root_device(self) -> torch.device:
         return xm.xla_device()
 
-    @property
-    def is_local_zero(self) -> bool:
-        return self.local_rank == 0
-
     @staticmethod
     def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
         if not isinstance(dataloaders, list):
@@ -314,5 +310,5 @@ def teardown(self) -> None:
         self.barrier("teardown")
 
     @property
-    def should_save_checkpoint(self) -> bool:
-        return self.is_local_zero
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.local_rank == 0
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -311,6 +311,6 @@ def register_plugins(cls, plugin_registry):
         pass
 
     @property
-    def should_save_checkpoint(self) -> bool:
+    def should_rank_save_checkpoint(self) -> bool:
        """Returns whether the checkpoint should be saved (rank based)"""
        return self.is_global_zero
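The base plugin keeps the previous default (only global rank zero writes), while subclasses such as the TPU spawn plugin above override it to local rank zero. As a hedged sketch of how a user-defined plugin could do the same, assuming node-local (non-shared) checkpoint storage (the subclass name and use case are hypothetical, not part of this PR):

```python
from pytorch_lightning.plugins import DDPPlugin


class NodeLocalCheckpointDDPPlugin(DDPPlugin):
    """Hypothetical DDP variant for node-local storage: let local rank 0
    on every node handle checkpoint directory creation and cleanup."""

    @property
    def should_rank_save_checkpoint(self) -> bool:
        # The base TrainingTypePlugin returns `self.is_global_zero`;
        # here each node's first process writes instead.
        return self.local_rank == 0
```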
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/properties.py
@@ -98,8 +98,8 @@ def world_size(self) -> int:
         return getattr(self.accelerator.training_type_plugin, "world_size", 1)
 
     @property
-    def should_save_checkpoint(self) -> bool:
-        return self.accelerator.training_type_plugin.should_save_checkpoint
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.accelerator.training_type_plugin.should_rank_save_checkpoint
 
     @property
     def _distrib_type(self) -> DistributedType:
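Since the Trainer simply forwards the plugin property, custom code that writes artifacts alongside checkpoints can reuse the same guard. A small illustrative callback (the callback class and output file name are made up for the example, not part of the PR):

```python
import pytorch_lightning as pl


class MarkerOnSaveRank(pl.Callback):
    """Hypothetical callback: write a small side file only on the rank
    that the training type plugin designates for checkpoint writes."""

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Mirrors ModelCheckpoint's guard so the file is written once per
        # checkpoint-saving rank rather than by every process.
        if trainer.should_rank_save_checkpoint:
            with open("last_validated_epoch.txt", "w") as f:
                f.write(f"{trainer.current_epoch}\n")
```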