Fix progress bar updates for Pod Training #8258

Merged
3 commits merged on Jul 5, 2021

Changes from all commits
21 changes: 21 additions & 0 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -34,6 +34,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, tpu_distributed
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _TPU_AVAILABLE:
    import torch_xla.core.xla_env_vars as xenv
@@ -282,6 +283,26 @@ def test_step(self, *args, **kwargs):
    def predict_step(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def validation_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def test_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
        self._pod_progress_bar_force_stdout()
        return output

    def _pod_progress_bar_force_stdout(self) -> None:
        # Why is this required? The way `pytorch_xla.distributed` streams logs
        # from the different VMs to the master worker doesn't work well with tqdm.
        # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
        # The print statement seems to force tqdm to flush stdout.
        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
            print()

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        """Save model/training states as a checkpoint file through state-dump and file-write.

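For readers following the diff, here is a minimal standalone sketch (not part of this PR) of the workaround that the new `_pod_progress_bar_force_stdout` hook applies after every training/validation/test step: on a TPU VM pod, only the global rank-zero worker emits an empty `print()`, which forces tqdm to flush stdout so per-step progress output is picked up by the `xla_dist` log streaming instead of being buffered. The literal `"TPUVM_MODE"` string and the helper name `force_progress_bar_flush` are assumptions made so the snippet runs without `torch_xla` installed; the plugin itself reads the variable via `torch_xla.core.xla_env_vars.TPUVM_MODE`.

```python
import os

# Assumed to match the name exposed as torch_xla.core.xla_env_vars.TPUVM_MODE;
# hardcoded here only so this sketch runs without torch_xla installed.
TPUVM_MODE_ENV = "TPUVM_MODE"


def force_progress_bar_flush(tpu_global_core_rank: int) -> None:
    """Emit an empty print on the rank-zero worker when running in TPU VM pod mode.

    The empty print flushes stdout, which nudges tqdm into emitting its buffered
    progress-bar updates so they reach the master worker's streamed logs.
    """
    if tpu_global_core_rank == 0 and int(os.getenv(TPUVM_MODE_ENV, 0)) == 1:
        print()


if __name__ == "__main__":
    # Simulate the rank-zero worker of a pod run: the guard passes and a blank
    # line is printed, forcing a stdout flush.
    os.environ[TPUVM_MODE_ENV] = "1"
    force_progress_bar_flush(tpu_global_core_rank=0)
```

The guard mirrors the one added in the diff: non-zero ranks and single-host TPU VM runs (where `TPUVM_MODE` is unset or `0`) skip the print, so the extra newline only appears where the tqdm buffering problem actually occurs.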