From 6fda7d9e47ce9eae6bf7ced12e26085671e2932a Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 2 Jul 2021 16:57:35 +0530 Subject: [PATCH] Fix progress bar updates for Pod Training --- .../plugins/training_type/tpu_spawn.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index d7485d77d8267..0f628fd35a2d9 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -285,6 +285,21 @@ def test_step(self, *args, **kwargs): def predict_step(self, *args, **kwargs): return self.model(*args, **kwargs) + def training_step_end(self, output): + if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: + print(' ', end='', flush=True) + return output + + def validation_step_end(self, output): + if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: + print(' ', end='', flush=True) + return output + + def test_step_end(self, output): + if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1: + print(' ', end='', flush=True) + return output + def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: """Save model/training states as a checkpoint file through state-dump and file-write.