Commit
Add _pod_progress_bar_force_stdout
kaushikb11 committed Jul 2, 2021
1 parent 6fda7d9 commit f1d0b65
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
+from pytorch_lightning.utilities.types import STEP_OUTPUT

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_env_vars as xenv
@@ -152,9 +153,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         # set warning rank
         rank_zero_only.rank = self.global_rank

-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
-
         if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
             trainer.progress_bar_callback.disable()

@@ -285,21 +283,26 @@ def test_step(self, *args, **kwargs):
     def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

-    def training_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output

-    def validation_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output

-    def test_step_end(self, output):
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
-            print(' ', end='', flush=True)
+    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
+        self._pod_progress_bar_force_stdout()
         return output

+    def _pod_progress_bar_force_stdout(self) -> None:
+        # Why is it required? The way `pytorch_xla.distributed` streams logs
+        # from different vms to the master worker doesn't work well with tqdm
+        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
+        # The print statement seems to force tqdm to flush stdout.
+        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+            print()
+
     def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
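For context, a minimal standalone sketch of the same pattern outside the plugin class. The names pod_progress_bar_force_stdout, global_rank, and the literal "TPUVM_MODE" key are illustrative stand-ins for the helper, self.tpu_global_core_rank, and xenv.TPUVM_MODE used in the diff above; this is not the library's API.

import os

def pod_progress_bar_force_stdout(global_rank: int) -> None:
    # Hypothetical helper: on the rank-0 worker of a TPU pod VM, print an
    # empty line so tqdm's buffered progress-bar output is pushed through
    # the xla_dist log-streaming pipeline.
    if global_rank == 0 and int(os.getenv("TPUVM_MODE", 0)) == 1:
        print()

def training_step_end(output, global_rank: int = 0):
    # Mirror the *_step_end hooks in the diff: flush after every step,
    # then return the step output unchanged.
    pod_progress_bar_force_stdout(global_rank)
    return output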
