diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index d7485d77d8267..0f628fd35a2d9 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -285,6 +285,23 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
+    def _pass_through_flush(self, output):
+        # Shared body for the *_step_end hooks below: on global rank 0, when
+        # the TPUVM_MODE env var parses to 1, print a single flushed space.
+        # NOTE(review): presumably a keep-alive/flush workaround for stalled
+        # stdout on TPU VM — confirm against the PR description. `output` is
+        # always returned unchanged.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print(' ', end='', flush=True)
+        return output
+
+    def training_step_end(self, output):
+        return self._pass_through_flush(output)
+
+    def validation_step_end(self, output):
+        return self._pass_through_flush(output)
+
+    def test_step_end(self, output):
+        return self._pass_through_flush(output)
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
 