Remove rank_zero_only decorators from model checkpoint entry points
ananthsub committed Sep 21, 2020
1 parent 2775389 commit 433e158
Showing 1 changed file with 21 additions and 32 deletions.
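
Instead of gating whole callback hooks with @rank_zero_only, the diff below runs the hooks on every rank and wraps only the filesystem operations (directory creation, checkpoint deletion) in trainer.is_global_zero checks, with log.info calls swapped for rank_zero_info so messages are still emitted only from rank 0. A minimal sketch of the two patterns, assuming a toy TrainerStub and SketchCheckpoint that are illustrative only and not part of this commit:

# Minimal sketch (not part of the commit): contrasts decorator-based gating
# with the targeted trainer.is_global_zero checks this diff moves to.
# SketchCheckpoint, TrainerStub and _write_file are illustrative names only.
import os


class TrainerStub:
    """Stand-in exposing only the trainer attributes used below."""

    def __init__(self, global_rank: int):
        self.global_rank = global_rank

    @property
    def is_global_zero(self) -> bool:
        return self.global_rank == 0


class SketchCheckpoint:
    # Old pattern: the whole hook was decorated with @rank_zero_only, so
    # non-zero ranks skipped it entirely:
    #
    #     @rank_zero_only
    #     def on_validation_end(self, trainer, pl_module):
    #         ...
    #
    # New pattern: every rank runs the bookkeeping; only the filesystem side
    # effects are guarded by trainer.is_global_zero.
    def on_validation_end(self, trainer, pl_module):
        filepath = os.path.join("checkpoints", "epoch=0.ckpt")
        self.best_model_path = filepath      # bookkeeping runs on all ranks
        if trainer.is_global_zero:           # disk I/O happens on rank 0 only
            self._write_file(filepath)

    def _write_file(self, filepath: str):
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        with open(filepath, "wb") as f:
            f.write(b"fake checkpoint bytes")


if __name__ == "__main__":
    cb = SketchCheckpoint()
    cb.on_validation_end(TrainerStub(global_rank=1), pl_module=None)  # no write
    cb.on_validation_end(TrainerStub(global_rank=0), pl_module=None)  # writes the file
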
53 changes: 21 additions & 32 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -26,9 +26,8 @@
 
 import numpy as np
 import torch
-from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 

@@ -185,17 +184,20 @@ def __init__(
 
         self.kth_value, self.mode = mode_dict[mode]
 
-    @rank_zero_only
     def _del_model(self, filepath: str):
         if self._fs.exists(filepath):
             self._fs.rm(filepath)
+            rank_zero_info(f"Removed checkpoint: {filepath}")
 
     def _save_model(self, filepath: str, trainer, pl_module):
 
         # in debugging, track when we save checkpoints
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the model
         if self.save_function is not None:
@@ -272,7 +274,6 @@ def format_checkpoint_name(
         ckpt_name = f"{filename}.ckpt"
         return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
 
-    @rank_zero_only
     def on_pretrain_routine_start(self, trainer, pl_module):
         """
         Determines model checkpoint save directory at runtime. References attributes from the
@@ -312,10 +313,8 @@ def on_pretrain_routine_start(self, trainer, pl_module):
 
         self.dirpath = ckpt_path
 
-        assert (
-            trainer.global_rank == 0
-        ), "tried to make a checkpoint from non global_rank=0"
-        self._fs.makedirs(self.dirpath, exist_ok=True)
+        if trainer.is_global_zero:
+            self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def __warn_deprecated_monitor_key(self):
         using_result_obj = os.environ.get("PL_USING_RESULT_OBJ", None)
@@ -333,12 +332,7 @@ def __warn_deprecated_monitor_key(self):
                 f" Remove `ModelCheckpoint(monitor='{self.monitor}')` to fix."
             )
 
-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
-
         if trainer.running_sanity_check:
             return
 
@@ -398,17 +392,14 @@ def on_validation_end(self, trainer, pl_module):
             elif self.check_monitor_top_k(current):
                 self._do_check_save(filepath, current, epoch, trainer, pl_module)
             elif self.verbose:
-                log.info(
+                rank_zero_info(
                     f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
                 )
 
         else:
             if self.verbose:
-                log.info(f"Epoch {epoch:d}: saving model to {filepath}")
+                rank_zero_info(f"Epoch {epoch:d}: saving model to {filepath}")
 
-            assert (
-                trainer.global_rank == 0
-            ), "tried to make a checkpoint from non global_rank=0"
             self._save_model(filepath, trainer, pl_module)
 
         if self.save_last:
@@ -417,16 +408,15 @@ def on_validation_end(self, trainer, pl_module):
             )
             filepath = os.path.join(self.dirpath, f"{filename}.ckpt")
             self._save_model(filepath, trainer, pl_module)
-            if self.last_model_path and self.last_model_path != filepath:
+            if (
+                self.last_model_path
+                and self.last_model_path != filepath
+                and trainer.is_global_zero
+            ):
                 self._del_model(self.last_model_path)
 
     def _do_check_save(
-        self,
-        filepath: str,
-        current: torch.Tensor,
-        epoch: int,
-        trainer,
-        pl_module,
+        self, filepath: str, current: torch.Tensor, epoch: int, trainer, pl_module
     ):
         # remove kth
 
@@ -450,20 +440,19 @@ def _do_check_save(
         self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose:
-            log.info(
+            rank_zero_info(
                 f"Epoch {epoch:d}: {self.monitor} reached"
                 f" {current:0.5f} (best {self.best_model_score:0.5f}),"
                 f" saving model to {filepath} as top {self.save_top_k}"
            )
         self._save_model(filepath, trainer, pl_module)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.is_global_zero:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)
 
-    def on_save_checkpoint(
-        self, trainer, pl_module
-    ) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
         return {
             "best_model_score": self.best_model_score,
             "best_model_path": self.best_model_path,
