diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 2a30ddce23841..f9bc1309f12fc 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -282,6 +283,26 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)
 
+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
+        return output
+
+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is it required? The way `pytorch_xla.distributed` streams logs
+        # from different VMs to the master worker doesn't work well with tqdm.
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # The print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.