Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix visual progress bar bug / properly reset progress bar #4579

Merged
merged 11 commits into from
Jan 14, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579))


## [1.1.4] - 2021-01-12

Expand Down
19 changes: 13 additions & 6 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
from typing import Optional, Union

if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
else:
Expand Down Expand Up @@ -306,7 +308,7 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches))
self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
Expand All @@ -327,8 +329,7 @@ def on_epoch_start(self, trainer, pl_module):
val_checks_per_epoch = total_train_batches // trainer.val_check_batch
total_val_batches = total_val_batches * val_checks_per_epoch
total_batches = total_train_batches + total_val_batches
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(convert_inf(total_batches))
reset(self.main_progress_bar, total_batches)
self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}')

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
Expand All @@ -342,7 +343,7 @@ def on_validation_start(self, trainer, pl_module):
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)
reset(self.val_progress_bar, self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Expand All @@ -362,7 +363,7 @@ def on_train_end(self, trainer, pl_module):
def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)
reset(self.test_progress_bar, self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Expand All @@ -387,8 +388,14 @@ def _update_bar(self, bar):
bar.update(delta)


def convert_inf(x):
def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]:
""" The tqdm doesn't support inf values. We have to convert it to None. """
if x == float('inf'):
return None
return x


def reset(bar: tqdm, total: Optional[int] = None) -> None:
""" Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """
if not bar.disable:
bar.reset(total=convert_inf(total))