From f1d0b655314c5fa7702e2a9068ba05a3093dbaad Mon Sep 17 00:00:00 2001
From: Kaushik B
Date: Sat, 3 Jul 2021 01:41:29 +0530
Subject: [PATCH] Add _pod_progress_bar_force_stdout

---
 .../plugins/training_type/tpu_spawn.py | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 0f628fd35a2d9..f9bc1309f12fc 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -152,9 +153,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         # set warning rank
         rank_zero_only.rank = self.global_rank
 
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
-
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()
 
@@ -285,21 +283,26 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
-    def training_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
-    def validation_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
-    def test_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output
 
+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is this required? The way `pytorch_xla.distributed` streams logs
+        # from the different VMs to the master worker doesn't work well with tqdm.
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # The print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
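
Note for reviewers unfamiliar with the TPU-pod logging quirk: below is a minimal,
self-contained sketch of the flush pattern the new helper encapsulates. It is not
Lightning's implementation; `pod_progress_bar_force_stdout`, `run_steps`, and the
hard-coded "TPUVM_MODE" env-var name are illustrative stand-ins (the real helper
lives on the plugin, reads the rank from torch_xla, and uses `xenv.TPUVM_MODE`).
It runs without a TPU or torch_xla, only tqdm is needed.

    # Hypothetical stand-alone sketch of the flush trick this patch factors out.
    import os

    from tqdm import tqdm

    # Assumption: the env var behind torch_xla's `xenv.TPUVM_MODE` symbol.
    TPUVM_MODE = "TPUVM_MODE"


    def pod_progress_bar_force_stdout(global_core_rank: int) -> None:
        # On a TPU pod, xla_dist streams every worker VM's logs back to the
        # master worker, which interacts badly with tqdm's in-place
        # carriage-return redraws. Printing a newline on the master worker
        # (rank 0) under TPU VM mode forces the buffered progress-bar output
        # through to stdout.
        if global_core_rank == 0 and int(os.getenv(TPUVM_MODE, 0)) == 1:
            print()


    def run_steps(global_core_rank: int, num_steps: int = 5) -> None:
        # Stand-in for a training loop: call the helper at each step end,
        # mirroring the hook points the patch uses
        # (training_step_end / validation_step_end / test_step_end).
        for _ in tqdm(range(num_steps), disable=global_core_rank != 0):
            pod_progress_bar_force_stdout(global_core_rank)


    if __name__ == "__main__":
        os.environ.setdefault(TPUVM_MODE, "1")  # simulate a TPU VM pod
        run_steps(global_core_rank=0)

Factoring the guard into one helper, as the patch does, also keeps the rank and
env-var check in a single place instead of three copies across the step-end hooks.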