diff --git a/CHANGELOG.md b/CHANGELOG.md index 93f3505eb1eb1..2011626e67bbc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) + ## [1.1.4] - 2021-01-12 diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 3ed5c11fd75d7..639a988bf3856 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -24,6 +24,8 @@ # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed +from typing import Optional, Union + if importlib.util.find_spec('ipywidgets') is not None: from tqdm.auto import tqdm else: @@ -306,7 +308,7 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches)) + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -327,8 +329,7 @@ def on_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(convert_inf(total_batches)) + reset(self.main_progress_bar, total_batches) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -342,7 +343,7 @@ def on_validation_start(self, trainer, pl_module): if not trainer.running_sanity_check: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.total = convert_inf(self.total_val_batches) + reset(self.val_progress_bar, self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -362,7 +363,7 @@ def on_train_end(self, trainer, pl_module): def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() - self.test_progress_bar.total = convert_inf(self.total_test_batches) + reset(self.test_progress_bar, self.total_test_batches) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -387,8 +388,14 @@ def _update_bar(self, bar): bar.update(delta) -def convert_inf(x): +def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: """ The tqdm doesn't support inf values. We have to convert it to None. """ if x == float('inf'): return None return x + + +def reset(bar: tqdm, total: Optional[int] = None) -> None: + """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """ + if not bar.disable: + bar.reset(total=convert_inf(total))