From dc0866b0a647da61462ccd66270f3ab3c73a3f37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 8 Nov 2020 16:52:48 +0100 Subject: [PATCH 1/5] reset --- pytorch_lightning/callbacks/progress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index de0c91f6983bd..c1b2e6321f488 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -307,7 +307,7 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches)) + self.val_progress_bar.reset(convert_inf(sum(trainer.num_sanity_val_batches))) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -342,7 +342,7 @@ def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) if not trainer.running_sanity_check: self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.total = convert_inf(self.total_val_batches) + self.val_progress_bar.reset(convert_inf(self.total_val_batches)) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -362,7 +362,7 @@ def on_train_end(self, trainer, pl_module): def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() - self.test_progress_bar.total = convert_inf(self.total_test_batches) + self.test_progress_bar.reset(convert_inf(self.total_test_batches)) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) From 0b256c4df25c8526dd6c7a3799e02e05c9aeb09a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 9 Jan 2021 09:58:01 +0100 Subject: [PATCH 2/5] fix reset --- pytorch_lightning/callbacks/progress.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index a98db59942958..3e86b7d2db13a 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -24,6 +24,8 @@ # check if ipywidgets is installed before importing tqdm.auto # to ensure it won't fail and a progress bar is displayed +from typing import Optional, Union + if importlib.util.find_spec('ipywidgets') is not None: from tqdm.auto import tqdm else: @@ -306,7 +308,7 @@ def init_test_tqdm(self) -> tqdm: def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) self.val_progress_bar = self.init_sanity_tqdm() - self.val_progress_bar.reset(convert_inf(sum(trainer.num_sanity_val_batches))) + reset(self.val_progress_bar, sum(trainer.num_sanity_val_batches)) self.main_progress_bar = tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, trainer, pl_module): @@ -327,8 +329,7 @@ def on_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - if not self.main_progress_bar.disable: - self.main_progress_bar.reset(convert_inf(total_batches)) + reset(self.main_progress_bar, total_batches) self.main_progress_bar.set_description(f'Epoch {trainer.current_epoch}') def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -342,7 +343,7 @@ def on_validation_start(self, trainer, pl_module): if not trainer.running_sanity_check: self._update_bar(self.main_progress_bar) # fill up remaining self.val_progress_bar = self.init_validation_tqdm() - self.val_progress_bar.reset(convert_inf(self.total_val_batches)) + reset(self.val_progress_bar, self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -362,7 +363,7 @@ def on_train_end(self, trainer, pl_module): def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() - self.test_progress_bar.reset(convert_inf(self.total_test_batches)) + reset(self.test_progress_bar, self.total_test_batches) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) @@ -387,8 +388,14 @@ def _update_bar(self, bar): bar.update(delta) -def convert_inf(x): +def convert_inf(x: Union[int, float]) -> Optional[int]: """ The tqdm doesn't support inf values. We have to convert it to None. """ if x == float('inf'): return None return x + + +def reset(bar: tqdm, total: Optional[int]) -> None: + """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """ + if not bar.disable: + bar.reset(total=convert_inf(total)) From b669c4ee7ed474d08fae0d86395820d63964ebd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 9 Jan 2021 10:06:32 +0100 Subject: [PATCH 3/5] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c32b93cc0dec..dd26330065bf3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) +- Fixed a visual bug in the progress bar display at the very first iteration ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) ## [1.1.3] - 2021-01-05 From fc5b9ed26d7da2fd0e2f667cc93143838be2bb17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 9 Jan 2021 10:58:40 +0100 Subject: [PATCH 4/5] update chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd26330065bf3..79f3ab348e20d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195)) -- Fixed a visual bug in the progress bar display at the very first iteration ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) +- Fixed a visual bug in the progress bar display initialization ([#4579](https://github.com/PyTorchLightning/pytorch-lightning/pull/4579)) ## [1.1.3] - 2021-01-05 From d561c9230d8d7ba2a5bc86ee40e0319ede349748 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 11 Jan 2021 20:23:30 +0100 Subject: [PATCH 5/5] typing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/callbacks/progress.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 3e86b7d2db13a..639a988bf3856 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -388,14 +388,14 @@ def _update_bar(self, bar): bar.update(delta) -def convert_inf(x: Union[int, float]) -> Optional[int]: +def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: """ The tqdm doesn't support inf values. We have to convert it to None. """ if x == float('inf'): return None return x -def reset(bar: tqdm, total: Optional[int]) -> None: +def reset(bar: tqdm, total: Optional[int] = None) -> None: """ Resets the tqdm bar to 0 progress with a new total, unless it is disabled. """ if not bar.disable: bar.reset(total=convert_inf(total))