Skip to content

Commit

Permalink
Update the TQDM progress bar on_train_epoch_end (#11069)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
2 people authored and lexierule committed Dec 15, 2021
1 parent af11c11 commit 0bb9ce0
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 20 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))
- Fixed running sanity check with `RichProgressBar` ([#10913](https://github.com/PyTorchLightning/pytorch-lightning/pull/10913))
- Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))
- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))
- Fixed a bug where the TQDM progress bar updated the training progress bar during `trainer.validate` ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))


## [1.5.5] - 2021-12-07
Expand Down
47 changes: 28 additions & 19 deletions pytorch_lightning/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
else:
from tqdm import tqdm as _tqdm

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress.base import ProgressBarBase

_PAD_SIZE = 5
Expand Down Expand Up @@ -206,12 +207,10 @@ def init_test_tqdm(self) -> Tqdm:
return bar

def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.main_progress_bar = Tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
super().on_sanity_check_end(trainer, pl_module)
self.main_progress_bar.close()
self.val_progress_bar.close()

Expand All @@ -233,49 +232,59 @@ def on_train_epoch_start(self, trainer, pl_module):

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
total_batches = self.total_train_batches + self.total_val_batches
total_batches = convert_inf(total_batches)
if self._should_update(self.train_batch_idx, total_batches):
if self._should_update(self.train_batch_idx):
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
self._update_bar(self.main_progress_bar)
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self.main_progress_bar.close()

def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if trainer.sanity_checking:
reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx)
else:
self._update_bar(self.main_progress_bar) # fill up remaining
if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
self._update_bar(self.main_progress_bar) # fill up remaining
self.val_progress_bar = self.init_validation_tqdm()
reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, convert_inf(self.total_val_batches)):
if self._should_update(self.val_batch_idx):
self._update_bar(self.val_progress_bar)
if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
self._update_bar(self.main_progress_bar)

def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
if self.main_progress_bar is not None:
if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING:
self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)
self.main_progress_bar.close()

def on_test_start(self, trainer, pl_module):
super().on_test_start(trainer, pl_module)
self.test_progress_bar = self.init_test_tqdm()
self.test_progress_bar.total = convert_inf(self.total_test_batches)

def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.test_batch_idx, self.total_test_batches):
if self._should_update(self.test_batch_idx):
self._update_bar(self.test_progress_bar)

def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_enabled:
self._update_bar(self.test_progress_bar)

def on_test_end(self, trainer, pl_module):
super().on_test_end(trainer, pl_module)
self.test_progress_bar.close()

def on_predict_epoch_start(self, trainer, pl_module):
Expand All @@ -285,7 +294,7 @@ def on_predict_epoch_start(self, trainer, pl_module):

def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.predict_batch_idx, self.total_predict_batches):
if self._should_update(self.predict_batch_idx):
self._update_bar(self.predict_progress_bar)

def on_predict_end(self, trainer, pl_module):
Expand All @@ -309,8 +318,8 @@ def print(
s = sep.join(map(str, args))
active_progress_bar.write(s, end=end, file=file, nolock=nolock)

def _should_update(self, current, total) -> bool:
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)
def _should_update(self, idx: int) -> bool:
return self.is_enabled and (idx % self.refresh_rate == 0)

def _update_bar(self, bar: Optional[Tqdm]) -> None:
"""Updates the bar by the refresh rate without overshooting."""
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
r"""
.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of
`pytorch_lightning.callbacks.progress.base.get_standard_metrics` and will be removed in v1.7.
`pytorch_lightning.callbacks.progress.base.get_metrics` and will be removed in v1.7.
Implement this to override the default items displayed in the progress bar.
By default it includes the average loss value, split index of BPTT (if used)
Expand Down
63 changes: 63 additions & 0 deletions tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import os
import pickle
import sys
from collections import defaultdict
from typing import Optional, Union
from unittest import mock
from unittest.mock import ANY, call, Mock
Expand Down Expand Up @@ -607,3 +608,65 @@ def test_tqdm_progress_bar_main_bar_resume():
# restarting mid validation epoch is not currently supported
assert bar.val_progress_bar.n == 0
assert bar.val_progress_bar.total == 3


def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir):
    """Check the metrics handed to the TQDM bar at each refresh during fit/validate/test.

    The model logs ``a`` (train), ``b`` (validation) and ``c`` (test) with
    ``prog_bar=True, on_step=False, on_epoch=True``. Every ``get_metrics``
    call made by the progress bar is recorded, and the recorded sequence is
    compared against the expected one for ``trainer.fit``. For
    ``trainer.validate`` and ``trainer.test`` the test expects no
    ``get_metrics`` call at all.
    """

    class MockedProgressBar(TQDMProgressBar):
        # Recorded `get_metrics` calls, grouped by `trainer.state.fn`.
        # The assertions below index this mapping with "fit" / "validate" / "test".
        calls = defaultdict(list)

        def get_metrics(self, trainer, pl_module):
            # Forward the module the hook was called with instead of relying on
            # the `model` variable of the enclosing test function.
            items = super().get_metrics(trainer, pl_module)
            # Drop the entries this test does not care about.
            del items["v_num"]
            del items["loss"]
            # this is equivalent to mocking `set_postfix` as this method gets called every time
            self.calls[trainer.state.fn].append(
                (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)
            )
            return items

    class MyModel(BoringModel):
        def training_step(self, batch, batch_idx):
            self.log("a", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
            return super().training_step(batch, batch_idx)

        def validation_step(self, batch, batch_idx):
            self.log("b", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
            return super().validation_step(batch, batch_idx)

        def test_step(self, batch, batch_idx):
            self.log("c", self.global_step, prog_bar=True, on_step=False, on_epoch=True, reduce_fx=max)
            return super().test_step(batch, batch_idx)

    model = MyModel()
    pbar = MockedProgressBar()
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=2,
        enable_model_summary=False,
        enable_checkpointing=False,
        log_every_n_steps=1,
        callbacks=pbar,
    )

    trainer.fit(model)
    # Each tuple is (stage, current_epoch, global_step, displayed metrics).
    assert pbar.calls["fit"] == [
        ("sanity_check", 0, 0, {"b": 0}),
        ("train", 0, 0, {}),
        ("train", 0, 1, {}),
        ("validate", 0, 1, {"b": 1}),  # validation end
        # epoch end over, `on_epoch=True` metrics are computed
        ("train", 0, 2, {"a": 1, "b": 1}),  # training epoch end
        ("train", 1, 2, {"a": 1, "b": 1}),
        ("train", 1, 3, {"a": 1, "b": 1}),
        ("validate", 1, 3, {"a": 1, "b": 3}),  # validation end
        ("train", 1, 4, {"a": 3, "b": 3}),  # training epoch end
    ]

    # Outside of fitting, the bar is expected not to query metrics at all.
    trainer.validate(model, verbose=False)
    assert pbar.calls["validate"] == []

    trainer.test(model, verbose=False)
    assert pbar.calls["test"] == []

0 comments on commit 0bb9ce0

Please sign in to comment.