Commit a5c8d18
ibeltagy committed Jul 25, 2020
1 parent 209278b commit a5c8d18
Showing 2 changed files with 16 additions and 13 deletions.
17 changes: 7 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -262,11 +262,8 @@ def on_train_start(self, trainer, pl_module):
         trainer.ckpt_path = ckpt_path
         trainer.weights_save_path = ckpt_path
 
-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
+        # run on all processes and rely on trainer.save_checkpoint to save only from global_rank == 0
 
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
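
For readers skimming the hunk above: a minimal before/after sketch, our own illustration rather than the actual Lightning implementation (`trainer` is a stand-in object and the function names are invented). Removing `@rank_zero_only` means every process now enters the hook, and the decision of who actually writes is deferred to `trainer.save_checkpoint`:

# illustrative sketch only; not the Lightning source

# before: non-zero ranks bail out early, so on TPU they never reach the
# collective save call inside save_checkpoint
def on_validation_end_before(trainer, filepath):
    if trainer.global_rank != 0:
        return
    trainer.save_checkpoint(filepath)

# after: every process runs the hook; save_checkpoint decides which
# process writes (and, on TPU, all processes join the xm.save call)
def on_validation_end_after(trainer, filepath):
    trainer.save_checkpoint(filepath)
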
@@ -306,18 +303,17 @@ def on_validation_end(self, trainer, pl_module):
                     f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
                 )
             elif self.check_monitor_top_k(current):
-                self._do_check_save(filepath, current, epoch)
+                self._do_check_save(trainer, filepath, current, epoch)
             elif self.verbose > 0:
                 log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
 
         else:
             if self.verbose > 0:
                 log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
 
-            assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
             self._save_model(filepath)
 
-    def _do_check_save(self, filepath, current, epoch):
+    def _do_check_save(self, trainer, filepath, current, epoch):
         # remove kth
 
         del_list = []
@@ -345,6 +341,7 @@ def _do_check_save(self, filepath, current, epoch):
                 f' {filepath} as top {self.save_top_k}')
         self._save_model(filepath)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.is_global_zero:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)
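
The new `trainer.is_global_zero` guard keeps checkpoint deletion on a single process. A hedged sketch of the race it avoids, assuming the ranks share one filesystem and that `_del_model` ultimately unlinks a file (the names below are ours, not Lightning's):

import os

def delete_stale_checkpoints(del_list, keep_path, is_global_zero):
    # with on_validation_end now running on every rank, an unguarded loop
    # would have several processes unlink the same files and race into
    # FileNotFoundError; restricting deletion to global rank 0 avoids this
    if not is_global_zero:
        return
    for cur_path in del_list:
        if cur_path != keep_path and os.path.exists(cur_path):
            os.remove(cur_path)
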
12 changes: 9 additions & 3 deletions pytorch_lightning/trainer/training_io.py
@@ -261,13 +261,19 @@ def _atomic_save(self, checkpoint, filepath: str):
                 This points to the file that the checkpoint will be stored in.
         """
         tmp_path = str(filepath) + ".part"
-        torch.save(checkpoint, tmp_path)
-        os.replace(tmp_path, filepath)
+        if self.use_tpu:
+            xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
+            if self.is_global_zero:
+                os.replace(tmp_path, filepath)
+        else:
+            torch.save(checkpoint, tmp_path)
+            os.replace(tmp_path, filepath)

     def save_checkpoint(self, filepath, weights_only: bool = False):
         checkpoint = self.dump_checkpoint(weights_only)
 
-        if self.is_global_zero:
+        # with use_tpu, all processes should call `xm.save`, not just the one with global_rank == 0
+        if self.is_global_zero or self.use_tpu:
             # do the actual save
             try:
                 self._atomic_save(checkpoint, filepath)
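
A self-contained sketch of the save path after this change, under the assumption that `xm.save` behaves as documented in torch_xla: it is a collective, so every TPU process must call it; `master_only=True` makes only the master write, and `global_master=True` selects one global master rather than one per host. The standalone function below is our own illustration, not the Trainer method:

import os
import torch

def atomic_save(checkpoint, filepath, use_tpu=False, is_global_zero=True):
    # sketch of the pattern in the diff, not the Trainer method itself
    tmp_path = str(filepath) + ".part"
    if use_tpu:
        import torch_xla.core.xla_model as xm  # requires a TPU environment
        # collective call: all processes must reach this line together
        xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
        if is_global_zero:
            os.replace(tmp_path, filepath)  # atomic rename, done once
    else:
        torch.save(checkpoint, tmp_path)    # write to a temporary file first
        os.replace(tmp_path, filepath)      # never leaves a half-written file
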
