Skip to content

Commit

Permalink
fix checkpointing on TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
matt-peters authored and Borda committed Aug 7, 2020
1 parent 234e2b5 commit fdac289
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
19 changes: 9 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@ def on_train_start(self, trainer, pl_module):
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
os.makedirs(self.dirpath, exist_ok=True)

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.global_rank != 0:
return

# To get checkpointing working on TPU, need to call _save_model
# for all ranks, to avoid deadlocks. Assuming save_function is mapped
# to trainer.save_checkpoint, this will also work on GPU as save_checkpoint
# handles rank==0 vs rank!=0 logic. If the user provides a custom
# save_function, they are responsible for adding rank==0 vs rank!=0 logic.
metrics = trainer.callback_metrics
epoch = trainer.current_epoch

Expand Down Expand Up @@ -336,8 +336,6 @@ def on_validation_end(self, trainer, pl_module):
else:
if self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath, trainer, pl_module)

def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
Expand Down Expand Up @@ -368,6 +366,7 @@ def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath, trainer, pl_module)

for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
if trainer.global_rank == 0:
for cur_path in del_list:
if cur_path != filepath:
self._del_model(cur_path)
34 changes: 23 additions & 11 deletions pytorch_lightning/trainer/training_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,29 +269,41 @@ def _atomic_save(self, checkpoint, filepath: str):
This points to the file that the checkpoint will be stored in.
"""
tmp_path = str(filepath) + ".part"
if self.use_tpu:
xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
if xm.is_master_ordinal(local=False):
os.replace(tmp_path, filepath)
# Can't use the new zipfile serialization for 1.6.0 because there's a bug in
# torch.hub.load_state_dict_from_url() that prevents it from loading the new files.
# More details can be found here: https://github.com/pytorch/pytorch/issues/42239
if LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
elif LooseVersion(torch.__version__).version[:3] == [1, 6, 0]:
torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False)
else:
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, filepath)

def save_checkpoint(self, filepath, weights_only: bool = False):
    """Dump the current trainer state to ``filepath``.

    Args:
        filepath: destination path for the checkpoint file.
        weights_only: if ``True``, forwarded to ``dump_checkpoint`` so the
            checkpoint contains only model weights (no optimizer/trainer state).

    The actual write is delegated to ``self._atomic_save``. Which ranks call it
    depends on the backend — see the comment above the dispatch below.
    """

    def _do_save(chkpt):
        # Do the actual save; on a non-picklable attribute, drop the stored
        # hyperparameters and retry once so at least the weights are saved.
        try:
            self._atomic_save(chkpt, filepath)
        except AttributeError as err:
            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in chkpt:
                del chkpt[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
            rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
                           f' An attribute is not picklable {err}')
            self._atomic_save(chkpt, filepath)

    checkpoint = self.dump_checkpoint(weights_only)

    # self._atomic_save has different behavior for XLA vs
    # non-XLA. In XLA, it has a barrier and internal logic to only
    # save for rank==0, so need to call for all ranks. For non-XLA,
    # it doesn't have rank==0 logic so only call for rank==0
    if self.use_tpu:
        _do_save(checkpoint)
    elif self.is_global_zero:
        _do_save(checkpoint)

def restore(self, checkpoint_path: str, on_gpu: bool):
"""
Expand Down

0 comments on commit fdac289

Please sign in to comment.