Commit

fix checkpointing on TPU
matt-peters committed Jul 27, 2020
1 parent 3f2c102 commit ad0c2da
Showing 2 changed files with 31 additions and 18 deletions.
pytorch_lightning/callbacks/model_checkpoint.py (19 changes: 9 additions, 10 deletions)
@@ -266,12 +266,12 @@ def on_train_start(self, trainer, pl_module):
         trainer.ckpt_path = ckpt_path
         trainer.weights_save_path = ckpt_path
 
-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
-
+        # To get checkpointing working on TPU, need to call _save_model
+        # for all ranks, to avoid deadlocks. Assuming save_function is mapped
+        # to trainer.save_checkpoint, this will also work on GPU as save_checkpoint
+        # handles rank==0 vs rank!=0 logic. If the user provides a custom
+        # save_function, they are responsible for adding rank==0 vs rank!=0 logic.
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
 
@@ -326,8 +326,6 @@ def on_validation_end(self, trainer, pl_module):
         else:
             if self.verbose > 0:
                 log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
-
-            assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
             self._save_model(filepath, trainer, pl_module)
 
     def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
@@ -358,6 +356,7 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
                     f' {filepath} as top {self.save_top_k}')
         self._save_model(filepath, trainer, pl_module)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.global_rank == 0:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)
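
The comment added to on_validation_end is the heart of this file's change: on TPU every process has to reach the save call, because (per the comment in training_io.py below) the XLA save contains a barrier, so a rank that returns early leaves the other ranks blocked. A minimal standalone sketch of that pattern, assuming torch_xla is installed and the function runs in every spawned TPU process; the function name and the print are illustrative, not part of the Lightning API:

import torch_xla.core.xla_model as xm

def save_on_every_rank(model, filepath, global_rank):
    # Every process calls xm.save(); with master_only=True and
    # global_master=True only the global master writes the file, but the
    # call must still be entered by all ranks, so guarding it with
    # `if global_rank != 0: return` would hang the master.
    xm.save(model.state_dict(), filepath, master_only=True, global_master=True)

    # Plain filesystem bookkeeping is not a collective, so it stays behind
    # an explicit rank check, mirroring the new guard in _do_check_save.
    if global_rank == 0:
        print(f"checkpoint written to {filepath}")
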
pytorch_lightning/trainer/training_io.py (30 changes: 22 additions, 8 deletions)
@@ -261,22 +261,36 @@ def _atomic_save(self, checkpoint, filepath: str):
         This points to the file that the checkpoint will be stored in.
         """
         tmp_path = str(filepath) + ".part"
-        torch.save(checkpoint, tmp_path)
-        os.replace(tmp_path, filepath)
+        if self.use_tpu:
+            xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
+            if xm.is_master_ordinal(local=False):
+                os.replace(tmp_path, filepath)
+        else:
+            torch.save(checkpoint, tmp_path)
+            os.replace(tmp_path, filepath)
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
-        checkpoint = self.dump_checkpoint(weights_only)
-
-        if self.is_global_zero:
+        def _do_save(chkpt):
             # do the actual save
             try:
-                self._atomic_save(checkpoint, filepath)
+                self._atomic_save(chkpt, filepath)
             except AttributeError as err:
                 if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
-                    del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
+                    del chkpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
                 rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
                                f' An attribute is not picklable {err}')
-                self._atomic_save(checkpoint, filepath)
+                self._atomic_save(chkpt, filepath)
+
+        checkpoint = self.dump_checkpoint(weights_only)
+
+        # self._atomic_save has different behavior for XLA vs
+        # non-XLA. In XLA, it has a barrier and internal logic to only
+        # save for rank==0, so need to call for all ranks. For non-XLA,
+        # it doesn't have rank==0 logic so only call for rank==0
+        if self.use_tpu:
+            _do_save(checkpoint)
+        elif self.is_global_zero:
+            _do_save(checkpoint)
 
     def restore(self, checkpoint_path: str, on_gpu: bool):
         """
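
The new _atomic_save keeps the write-to-a-temp-file-then-rename pattern but routes TPU saves through xm.save, which, as the comment in save_checkpoint notes, already contains a barrier and only writes from the global master. A standalone sketch of the same pattern under those assumptions; the atomic_save name, the use_tpu flag, and the guarded import are illustrative, not the Trainer API:

import os
import torch

try:
    import torch_xla.core.xla_model as xm
except ImportError:  # not a TPU host
    xm = None

def atomic_save(checkpoint: dict, filepath: str, use_tpu: bool = False) -> None:
    # Write to a ".part" file first and rename afterwards, so an interrupted
    # save never leaves a truncated checkpoint at the final path.
    tmp_path = str(filepath) + ".part"
    if use_tpu and xm is not None:
        # Must be reached by every process: only the global master writes,
        # but all ranks participate in the call (hence no rank guard here).
        xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
        if xm.is_master_ordinal(local=False):
            os.replace(tmp_path, filepath)
    else:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, filepath)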
