From a5c8d182329a3be88f17f952bee9cc063116c515 Mon Sep 17 00:00:00 2001
From: Iz Beltagy
Date: Sat, 25 Jul 2020 03:57:48 +0000
Subject: [PATCH] fix https://github.com/PyTorchLightning/pytorch-lightning/issues/2700

---
 pytorch_lightning/callbacks/model_checkpoint.py | 17 +++++++----------
 pytorch_lightning/trainer/training_io.py        | 12 +++++++++---
 2 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index f70d8d8d0a5e1..a14d4d2ef3a37 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -262,11 +262,8 @@ def on_train_start(self, trainer, pl_module):
         trainer.ckpt_path = ckpt_path
         trainer.weights_save_path = ckpt_path
 
-    @rank_zero_only
     def on_validation_end(self, trainer, pl_module):
-        # only run on main process
-        if trainer.global_rank != 0:
-            return
+        # run on all processes and rely on trainer.save_checkpoint to save only from global_rank == 0
 
         metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
@@ -306,7 +303,7 @@ def on_validation_end(self, trainer, pl_module):
                         f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
                     )
                 elif self.check_monitor_top_k(current):
-                    self._do_check_save(filepath, current, epoch)
+                    self._do_check_save(trainer, filepath, current, epoch)
                 elif self.verbose > 0:
                     log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
 
@@ -314,10 +311,9 @@ def on_validation_end(self, trainer, pl_module):
             if self.verbose > 0:
                 log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
 
-            assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
             self._save_model(filepath)
 
-    def _do_check_save(self, filepath, current, epoch):
+    def _do_check_save(self, trainer, filepath, current, epoch):
         # remove kth
 
         del_list = []
@@ -345,6 +341,7 @@ def _do_check_save(self, filepath, current, epoch):
                 f' {filepath} as top {self.save_top_k}')
         self._save_model(filepath)
 
-        for cur_path in del_list:
-            if cur_path != filepath:
-                self._del_model(cur_path)
+        if trainer.is_global_zero:
+            for cur_path in del_list:
+                if cur_path != filepath:
+                    self._del_model(cur_path)
diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py
index 3234800a62c51..6c80e98bc69eb 100644
--- a/pytorch_lightning/trainer/training_io.py
+++ b/pytorch_lightning/trainer/training_io.py
@@ -261,13 +261,19 @@ def _atomic_save(self, checkpoint, filepath: str):
                 This points to the file that the checkpoint will be stored in.
         """
         tmp_path = str(filepath) + ".part"
-        torch.save(checkpoint, tmp_path)
-        os.replace(tmp_path, filepath)
+        if self.use_tpu:
+            xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
+            if self.is_global_zero:
+                os.replace(tmp_path, filepath)
+        else:
+            torch.save(checkpoint, tmp_path)
+            os.replace(tmp_path, filepath)
 
     def save_checkpoint(self, filepath, weights_only: bool = False):
         checkpoint = self.dump_checkpoint(weights_only)
 
-        if self.is_global_zero:
+        # with self.use_tpu, all processes should call `xm.save`, not just the one with global_rank == 0
+        if self.is_global_zero or self.use_tpu:
             # do the actual save
             try:
                 self._atomic_save(checkpoint, filepath)
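
Note (not part of the patch): a minimal sketch of the save path the diff above introduces, pulled out of the Trainer for illustration. The helper name `atomic_save` and the `use_tpu` / `is_global_zero` arguments are stand-ins for the trainer attributes used in the patch; it assumes `torch_xla` is importable when the TPU branch is taken.

    import os
    import torch

    def atomic_save(checkpoint, filepath, use_tpu, is_global_zero):
        """Write `checkpoint` to a temp file, then atomically move it into place."""
        tmp_path = str(filepath) + ".part"
        if use_tpu:
            import torch_xla.core.xla_model as xm
            # xm.save synchronizes across processes, so every process must call it;
            # with master_only=True and global_master=True only global rank 0 writes the file.
            xm.save(checkpoint, tmp_path, master_only=True, global_master=True)
            if is_global_zero:
                os.replace(tmp_path, filepath)
        else:
            torch.save(checkpoint, tmp_path)
            os.replace(tmp_path, filepath)

This is why `save_checkpoint` now enters the save path on every TPU process (`self.is_global_zero or self.use_tpu`): if the non-zero ranks skipped `xm.save`, the remaining processes could block at its synchronization point.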