diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 3a333f0d4cce18..34293f21fbca2a 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -26,9 +26,8 @@

 import numpy as np
 import torch

-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem

@@ -185,9 +184,11 @@ def __init__(

         self.kth_value, self.mode = mode_dict[mode]

+    @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
             self._fs.rm(filepath)
+            rank_zero_info(f"Removed checkpoint: {filepath}")

     def _save_model(self, filepath: str, trainer, pl_module):
@@ -195,7 +196,8 @@ def _save_model(self, filepath: str, trainer, pl_module):
         trainer.dev_debugger.track_checkpointing_history(filepath)

         # make paths
-        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

         # delegate the saving to the model
         if self.save_function is not None:
@@ -272,7 +274,6 @@ def format_checkpoint_name(
         ckpt_name = f"{filename}.ckpt"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

-    @rank_zero_only
     def on_pretrain_routine_start(self, trainer, pl_module):
         """
         Determines model checkpoint save directory at runtime. References attributes from the
@@ -312,10 +313,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):

         self.dirpath = ckpt_path

-        assert (
-            trainer.global_rank == 0
-        ), "tried to make a checkpoint from non global_rank=0"
-        self._fs.makedirs(self.dirpath, exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(self.dirpath, exist_ok=True)

     def __warn_deprecated_monitor_key(self):
         using_result_obj = os.environ.get("PL_USING_RESULT_OBJ", None)
@@ -333,12 +332,7 @@ def __warn_deprecated_monitor_key(self):
                 f" Remove `ModelCheckpoint(monitor='{self.monitor}')` to fix."
             )

-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
-
         if trainer.running_sanity_check:
             return

@@ -398,17 +392,14 @@ def on_validation_end(self, trainer, pl_module):
             elif self.check_monitor_top_k(current):
                 self._do_check_save(filepath, current, epoch, trainer, pl_module)
             elif self.verbose:
-                log.info(
+                rank_zero_info(
                     f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
                 )

         else:
             if self.verbose:
-                log.info(f"Epoch {epoch:d}: saving model to {filepath}")
+                rank_zero_info(f"Epoch {epoch:d}: saving model to {filepath}")

-            assert (
-                trainer.global_rank == 0
-            ), "tried to make a checkpoint from non global_rank=0"
             self._save_model(filepath, trainer, pl_module)

         if self.save_last:
@@ -417,16 +408,15 @@ def on_validation_end(self, trainer, pl_module):
             )
             filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
             self._save_model(filepath, trainer, pl_module)
-            if self.last_model_path and self.last_model_path != filepath:
+            if (
+                self.last_model_path
+                and self.last_model_path != filepath
+                and trainer.is_global_zero
+            ):
                 self._del_model(self.last_model_path)

     def _do_check_save(
-        self,
-        filepath: str,
-        current: torch.Tensor,
-        epoch: int,
-        trainer,
-        pl_module,
+        self, filepath: str, current: torch.Tensor, epoch: int, trainer, pl_module
     ):

         # remove kth
@@ -450,20 +440,19 @@ def _do_check_save(
         self.best_model_score = self.best_k_models[self.best_model_path]

         if self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} reached"
                 f" {current:0.5f} (best {self.best_model_score:0.5f}),"
                 f" saving model to {filepath} as top {self.save_top_k}"
             )
         self._save_model(filepath, trainer, pl_module)

-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.is_global_zero:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)

-    def on_save_checkpoint(
-        self, trainer, pl_module
-    ) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
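
Note on the pattern above: the diff drops the hard `assert trainer.global_rank == 0` checks and the `@rank_zero_only` decorators on the hooks, and instead lets the hooks run on every rank while guarding only the filesystem side effects (makedirs, checkpoint deletion) behind `trainer.is_global_zero` or `@rank_zero_only`. The snippet below is a minimal illustrative sketch of that guard pattern, not the library implementation; the environment-variable rank lookup and the `delete_checkpoint` helper are assumptions made for the example.

# Illustrative sketch of the rank-zero guard pattern applied in this diff.
# Assumption: the process rank is read from RANK/LOCAL_RANK env vars; the real
# pytorch_lightning.utilities.rank_zero_only resolves the rank internally.
import os
from functools import wraps


def rank_zero_only(fn):
    """Run the wrapped function only on the process with global rank 0."""
    @wraps(fn)
    def wrapped(*args, **kwargs):
        rank = int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0")))
        if rank == 0:
            return fn(*args, **kwargs)
        return None  # no-op on every other rank
    return wrapped


@rank_zero_only
def delete_checkpoint(filepath: str) -> None:
    # Hypothetical helper mirroring the guarded _del_model: only rank 0
    # touches the filesystem, so ranks cannot race on the same file.
    if os.path.exists(filepath):
        os.remove(filepath)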