Commit

Update model_checkpoint.py
ananthsub committed Feb 23, 2021
1 parent 4ed50fe commit aef0c39
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
@@ -172,6 +172,8 @@ def __init__(
    mode: str = "auto",
    period: int = 1,
    prefix: str = "",
    every_n_epochs: int = 1,
    every_n_batches: int = -1,
):
    super().__init__()
    self.monitor = monitor
@@ -180,6 +182,8 @@ def __init__(
    self.save_top_k = save_top_k
    self.save_weights_only = save_weights_only
    self.period = period
    self.every_n_epochs = every_n_epochs
    self.every_n_batches = every_n_batches
    # time_interval (an optional datetime.timedelta) gates the wall-clock check
    # in on_train_batch_end below; it must exist before that hook runs.
    self.time_interval = None
    self._prev_time_check = None
    self._last_global_step_saved = -1
    self.prefix = prefix
    self.current_score = None
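Taken together, the defaults preserve the old behavior: every_n_epochs=1 keeps saving after every validation epoch, while every_n_batches=-1 leaves batch-level saving disabled. Both new hooks below gate on the same 1-based modulo check; a minimal illustrative sketch (the helper name should_fire is hypothetical, not part of this diff):

def should_fire(completed: int, every_n: int) -> bool:
    # completed is a zero-indexed counter (trainer.global_step / current_epoch),
    # so completed + 1 is the number of finished units; every_n < 1 disables saving.
    return every_n >= 1 and (completed + 1) % every_n == 0

assert should_fire(99, every_n=100)      # the 100th batch triggers a save
assert not should_fire(99, every_n=-1)   # the default -1 turns batch-level saving off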
@@ -208,11 +212,38 @@ def on_pretrain_routine_start(self, trainer, pl_module):
    self.__resolve_ckpt_dir(trainer)
    self.save_function = trainer.save_checkpoint

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None:
    if self._should_skip_saving_checkpoint(trainer):
        return

    # Save every every_n_batches training batches; a value < 1 disables this path.
    step = trainer.global_step
    skip_batch = self.every_n_batches < 1 or ((step + 1) % self.every_n_batches != 0)

    # Optionally also save on a wall-clock interval: time_interval is either None
    # (disabled) or a datetime.timedelta set on the callback.
    now = time.monotonic()  # assumes `import time` at the top of the module
    time_interval = self.time_interval
    prev_time_check = self._prev_time_check
    skip_time = (
        time_interval is None
        or prev_time_check is None
        or (now - prev_time_check) < time_interval.total_seconds()
    )
    if skip_batch and skip_time:
        return
    if not skip_time:
        self._prev_time_check = now

    self.save_checkpoint(trainer, pl_module)

def on_validation_end(self, trainer, pl_module):
    """
    Checkpoints can be saved at the end of the validation loop.
    """
    # Save every every_n_epochs validation epochs; a value < 1 disables this hook
    # and also guards the modulo below against division by zero.
    if self._should_skip_saving_checkpoint(trainer) or self.every_n_epochs < 1:
        return
    epoch = trainer.current_epoch
    if (epoch + 1) % self.every_n_epochs == 0:
        self.save_checkpoint(trainer, pl_module)

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
    return {
@@ -227,6 +258,9 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
    self.best_model_score = checkpointed_state["best_model_score"]
    self.best_model_path = checkpointed_state["best_model_path"]

def _should_skip_saving_checkpoint(self, trainer) -> bool:
    return (
        trainer.fast_dev_run  # checkpointing is disabled with fast_dev_run
        or trainer.running_sanity_check  # don't save during the sanity check
        or self.save_top_k == 0  # no models are saved at all
        or self.period < 1  # checkpointing is disabled
        or self._last_global_step_saved == trainer.global_step  # already saved at this step
    )

def save_checkpoint(self, trainer, pl_module):
    """
    Performs the main logic around saving a checkpoint.
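For context, a minimal usage sketch of the two arguments this commit adds. This is hedged: every_n_epochs and every_n_batches come from the diff above, while dirpath, monitor, and the Trainer callback wiring are the surrounding PyTorch Lightning API of this era, and model stands in for a LightningModule defined elsewhere.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save after every 2nd validation epoch, and additionally every 100 training batches.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    monitor="val_loss",
    every_n_epochs=2,
    every_n_batches=100,
)

trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)  # model: a LightningModule instance defined elsewhere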
