diff --git a/CHANGELOG.md b/CHANGELOG.md index c2bf09068ff61..ea1bfed1659c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `DDPSpawnPlugin` no longer overrides the `post_dispatch` plugin hook ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) +- Integrate the progress bar implementation with progress tracking ([#11213](https://github.com/PyTorchLightning/pytorch-lightning/pull/11213)) + + - The `LightningModule.{add_to_queue,get_from_queue}` hooks no longer get a `torch.multiprocessing.SimpleQueue` and instead receive a list based queue ([#10034](https://github.com/PyTorchLightning/pytorch-lightning/pull/10034)) diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index cc66bb4a3a8e5..4808bd39629e1 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -47,52 +47,57 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx): """ def __init__(self): - - self._trainer = None - self._train_batch_idx = 0 - self._val_batch_idx = 0 - self._test_batch_idx = 0 - self._predict_batch_idx = 0 + self._trainer: Optional["pl.Trainer"] = None @property - def trainer(self): + def trainer(self) -> "pl.Trainer": return self._trainer @property def train_batch_idx(self) -> int: - """The current batch index being processed during training. + """The number of batches processed during training. Use this to update your progress bar. """ - return self._train_batch_idx + if self.trainer is None: + return 0 + return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed @property def val_batch_idx(self) -> int: - """The current batch index being processed during validation. + """The number of batches processed during validation. Use this to update your progress bar. """ - return self._val_batch_idx + if self.trainer is None: + return 0 + if self.trainer.state.fn == "fit": + return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.current.processed + return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed @property def test_batch_idx(self) -> int: - """The current batch index being processed during testing. + """The number of batches processed during testing. Use this to update your progress bar. """ - return self._test_batch_idx + if self.trainer is None: + return 0 + return self.trainer.test_loop.epoch_loop.batch_progress.current.processed @property def predict_batch_idx(self) -> int: - """The current batch index being processed during predicting. + """The number of batches processed during prediction. Use this to update your progress bar. """ - return self._predict_batch_idx + if self.trainer is None: + return 0 + return self.trainer.predict_loop.epoch_loop.batch_progress.current.processed @property def total_train_batches(self) -> int: - """The total number of training batches during training, which may change from epoch to epoch. + """The total number of training batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the training dataloader is of infinite size. @@ -101,7 +106,7 @@ def total_train_batches(self) -> int: @property def total_val_batches(self) -> int: - """The total number of validation batches during validation, which may change from epoch to epoch. 
+ """The total number of validation batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the validation dataloader is of infinite size. @@ -115,7 +120,7 @@ def total_val_batches(self) -> int: @property def total_test_batches(self) -> int: - """The total number of testing batches during testing, which may change from epoch to epoch. + """The total number of testing batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the test dataloader is of infinite size. @@ -124,7 +129,7 @@ def total_test_batches(self) -> int: @property def total_predict_batches(self) -> int: - """The total number of predicting batches during testing, which may change from epoch to epoch. + """The total number of prediction batches, which may change from epoch to epoch. Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the predict dataloader is of infinite size. @@ -155,33 +160,6 @@ def print(self, *args, **kwargs): def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None: self._trainer = trainer - def on_train_start(self, trainer, pl_module): - self._train_batch_idx = 0 - - def on_train_epoch_start(self, trainer, pl_module): - self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - self._train_batch_idx += 1 - - def on_validation_start(self, trainer, pl_module): - self._val_batch_idx = 0 - - def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._val_batch_idx += 1 - - def on_test_start(self, trainer, pl_module): - self._test_batch_idx = 0 - - def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._test_batch_idx += 1 - - def on_predict_epoch_start(self, trainer, pl_module): - self._predict_batch_idx = 0 - - def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self._predict_batch_idx += 1 - def get_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]: r""" Combines progress bar metrics collected from the trainer with standard metrics from get_standard_metrics. 
diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index f983ccdaab12e..46b5437934013 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -303,35 +303,28 @@ def refresh(self) -> None: self.progress.refresh() def on_train_start(self, trainer, pl_module): - super().on_train_start(trainer, pl_module) self._init_progress(trainer) def on_predict_start(self, trainer, pl_module): - super().on_predict_start(trainer, pl_module) self._init_progress(trainer) def on_test_start(self, trainer, pl_module): - super().on_test_start(trainer, pl_module) self._init_progress(trainer) def on_validation_start(self, trainer, pl_module): - super().on_validation_start(trainer, pl_module) self._init_progress(trainer) def on_sanity_check_start(self, trainer, pl_module): - super().on_sanity_check_start(trainer, pl_module) self._init_progress(trainer) self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) self.refresh() def on_sanity_check_end(self, trainer, pl_module): - super().on_sanity_check_end(trainer, pl_module) if self.progress is not None: self.progress.update(self.val_sanity_progress_bar_id, advance=0, visible=False) self.refresh() def on_train_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches if total_train_batches != float("inf"): @@ -354,7 +347,6 @@ def on_train_epoch_start(self, trainer, pl_module): self.refresh() def on_validation_epoch_start(self, trainer, pl_module): - super().on_validation_epoch_start(trainer, pl_module) if self.total_val_batches > 0: total_val_batches = self.total_val_batches if self.total_train_batches != float("inf") and hasattr(trainer, "val_check_batch"): @@ -379,7 +371,6 @@ def _should_update(self, current: int, total: int) -> bool: return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def on_validation_epoch_end(self, trainer, pl_module): - super().on_validation_epoch_end(trainer, pl_module) if self.val_progress_bar_id is not None: self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) @@ -388,18 +379,15 @@ def on_test_epoch_start(self, trainer, pl_module): self.refresh() def on_predict_epoch_start(self, trainer, pl_module): - super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) self._update_metrics(trainer, pl_module) self.refresh() def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) elif self.val_progress_bar_id is not None: @@ -410,12 +398,10 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_test_batch_end(trainer, 
pl_module, outputs, batch, batch_idx, dataloader_idx) self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) self.refresh() diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index ea667c9ddf968..6b42b144cf2f9 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -135,6 +135,15 @@ def is_enabled(self) -> bool: def is_disabled(self) -> bool: return not self.is_enabled + @property + def _val_processed(self): + if self.trainer is None: + return 0 + if self.trainer.state.fn == "fit": + # use total in case validation runs more than once per training epoch + return self.trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.total.processed + return self.trainer.validate_loop.epoch_loop.batch_progress.current.processed + def disable(self) -> None: self._enabled = False @@ -216,11 +225,9 @@ def on_sanity_check_end(self, trainer, pl_module): self.val_progress_bar.close() def on_train_start(self, trainer, pl_module): - super().on_train_start(trainer, pl_module) self.main_progress_bar = self.init_train_tqdm() def on_train_epoch_start(self, trainer, pl_module): - super().on_train_epoch_start(trainer, pl_module) total_train_batches = self.total_train_batches total_val_batches = self.total_val_batches if total_train_batches != float("inf") and total_val_batches != float("inf"): @@ -228,43 +235,37 @@ def on_train_epoch_start(self, trainer, pl_module): val_checks_per_epoch = total_train_batches // trainer.val_check_batch total_val_batches = total_val_batches * val_checks_per_epoch total_batches = total_train_batches + total_val_batches - reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx) + self.main_progress_bar.total = convert_inf(total_batches) self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) if self._should_update(self.train_batch_idx): - self._update_bar(self.main_progress_bar) + _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.is_enabled: - self._update_bar(self.main_progress_bar) + _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) + if not self.main_progress_bar.disable: self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: self.main_progress_bar.close() def on_validation_start(self, trainer, pl_module): - super().on_validation_start(trainer, pl_module) if trainer.sanity_checking: - reset(self.val_progress_bar, total=sum(trainer.num_sanity_val_batches), current=self.val_batch_idx) + self.val_progress_bar.total = sum(trainer.num_sanity_val_batches) else: - if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: - self._update_bar(self.main_progress_bar) # fill up remaining 
self.val_progress_bar = self.init_validation_tqdm() - reset(self.val_progress_bar, total=self.total_val_batches, current=self.val_batch_idx) + self.val_progress_bar.total = convert_inf(self.total_val_batches) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx): - self._update_bar(self.val_progress_bar) + _update_n(self.val_progress_bar, self.val_batch_idx) if trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: - self._update_bar(self.main_progress_bar) + _update_n(self.main_progress_bar, self.train_batch_idx + self._val_processed) def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.is_enabled: - self._update_bar(self.val_progress_bar) + _update_n(self.val_progress_bar, self._val_processed) def on_validation_end(self, trainer, pl_module): if self.main_progress_bar is not None and trainer.state.fn == pl.trainer.states.TrainerFn.FITTING: @@ -272,31 +273,26 @@ def on_validation_end(self, trainer, pl_module): self.val_progress_bar.close() def on_test_start(self, trainer, pl_module): - super().on_test_start(trainer, pl_module) self.test_progress_bar = self.init_test_tqdm() self.test_progress_bar.total = convert_inf(self.total_test_batches) def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.test_batch_idx): - self._update_bar(self.test_progress_bar) + _update_n(self.test_progress_bar, self.test_batch_idx) def on_test_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self.is_enabled: - self._update_bar(self.test_progress_bar) + _update_n(self.test_progress_bar, self.test_batch_idx) def on_test_end(self, trainer, pl_module): self.test_progress_bar.close() def on_predict_epoch_start(self, trainer, pl_module): - super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar = self.init_predict_tqdm() self.predict_progress_bar.total = convert_inf(self.total_predict_batches) def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.predict_batch_idx): - self._update_bar(self.predict_progress_bar) + _update_n(self.predict_progress_bar, self.predict_batch_idx) def on_predict_end(self, trainer, pl_module): self.predict_progress_bar.close() @@ -320,19 +316,7 @@ def print( active_progress_bar.write(s, end=end, file=file, nolock=nolock) def _should_update(self, idx: int) -> bool: - return self.is_enabled and (idx % self.refresh_rate == 0) - - def _update_bar(self, bar: Optional[Tqdm]) -> None: - """Updates the bar by the refresh rate without overshooting.""" - if bar is None: - return - if bar.total is not None: - delta = min(self.refresh_rate, bar.total - bar.n) - else: - # infinite / unknown size - delta = self.refresh_rate - if delta > 0: - bar.update(delta) + return self.refresh_rate and idx % self.refresh_rate == 0 @staticmethod def _resolve_refresh_rate(refresh_rate: int) -> int: @@ -353,8 +337,7 @@ def convert_inf(x: Optional[Union[int, float]]) -> Optional[Union[int, float]]: return x -def reset(bar: Tqdm, total: Optional[int] = None, current: int = 0) -> None: - """Resets the tqdm bar to the 
desired position and sets a new total, unless it is disabled.""" +def _update_n(bar: _tqdm, value: int) -> None: if not bar.disable: - bar.reset(total=convert_inf(total)) - bar.n = current + bar.n = value + bar.refresh() diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index e22bb62126188..00f325976182a 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -17,7 +17,7 @@ from collections import defaultdict from typing import Union from unittest import mock -from unittest.mock import ANY, call, Mock +from unittest.mock import ANY, call import pytest import torch @@ -172,27 +172,16 @@ class CurrentProgressBar(TQDMProgressBar): val_batches_seen = 0 test_batches_seen = 0 - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - super().on_train_batch_start(trainer, pl_module, batch, batch_idx) - assert self.train_batch_idx == trainer.fit_loop.batch_idx - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1 - if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: - assert self.main_progress_bar.n == self.train_batch_idx self.train_batches_seen += 1 def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if not self.is_disabled and self.val_batch_idx % self.refresh_rate == 0: - assert self.val_progress_bar.n == self.val_batch_idx self.val_batches_seen += 1 def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - if not self.is_disabled and self.test_batch_idx % self.refresh_rate == 0: - assert self.test_progress_bar.n == self.test_batch_idx self.test_batches_seen += 1 progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) @@ -282,40 +271,40 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 19 -class MockedUpdateProgressBars(TQDMProgressBar): - """Mocks the update method once bars get initializied.""" - - def _mock_bar_update(self, bar): - bar.update = Mock(wraps=bar.update) - return bar +class MockTqdm(Tqdm): + def __init__(self, *args, **kwargs): + self.n_values = [] + super().__init__(*args, **kwargs) + self.__n = 0 + # again to reset additions from `super().__init__` + self.n_values = [] - def init_train_tqdm(self): - bar = super().init_train_tqdm() - return self._mock_bar_update(bar) + @property + def n(self): + return self.__n - def init_validation_tqdm(self): - bar = super().init_validation_tqdm() - return self._mock_bar_update(bar) - - def init_test_tqdm(self): - bar = super().init_test_tqdm() - return self._mock_bar_update(bar) + @n.setter + def n(self, value): + self.__n = value + # track the changes in the `n` value + if not len(self.n_values) or value != self.n_values[-1]: + self.n_values.append(value) @pytest.mark.parametrize( - "train_batches,val_batches,refresh_rate,train_deltas,val_deltas", + "train_batches,val_batches,refresh_rate,train_updates,val_updates", [ - [2, 3, 1, [1, 1, 1, 1, 1], [1, 1, 1]], - [0, 0, 3, [], []], - [1, 0, 3, [1], []], + [2, 3, 1, [1, 2, 3, 4, 5], [1, 2, 3]], + [0, 0, 3, None, None], + [1, 0, 3, [1], None], [1, 1, 3, [2], [1]], - [5, 0, 3, 
[3, 2], []], - [5, 2, 3, [3, 3, 1], [2]], - [5, 2, 6, [6, 1], [2]], + [5, 0, 3, [3, 5], None], + [5, 2, 3, [3, 7], [2]], + [5, 2, 6, [7], [2]], ], ) def test_main_progress_bar_update_amount( - tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_deltas: list, val_deltas: list + tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_updates, val_updates ): """Test that the main progress updates with the correct amount together with the val progress. @@ -323,7 +312,7 @@ def test_main_progress_bar_update_amount( rate. """ model = BoringModel() - progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) + progress_bar = TQDMProgressBar(refresh_rate=refresh_rate) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -333,18 +322,19 @@ def test_main_progress_bar_update_amount( logger=False, enable_checkpointing=False, ) - trainer.fit(model) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) if train_batches > 0: - progress_bar.main_progress_bar.update.assert_has_calls([call(delta) for delta in train_deltas]) + assert progress_bar.main_progress_bar.n_values == train_updates if val_batches > 0: - progress_bar.val_progress_bar.update.assert_has_calls([call(delta) for delta in val_deltas]) + assert progress_bar.val_progress_bar.n_values == val_updates -@pytest.mark.parametrize("test_batches,refresh_rate,test_deltas", [[1, 3, [1]], [3, 1, [1, 1, 1]], [5, 3, [3, 2]]]) -def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, test_deltas: list): +@pytest.mark.parametrize("test_batches,refresh_rate,updates", [[1, 3, [1]], [3, 1, [1, 2, 3]], [5, 3, [3, 5]]]) +def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate: int, updates: list): """Test that test progress updates with the correct amount.""" model = BoringModel() - progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate) + progress_bar = TQDMProgressBar(refresh_rate=refresh_rate) trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, @@ -353,8 +343,9 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate logger=False, enable_checkpointing=False, ) - trainer.test(model) - progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas]) + with mock.patch("pytorch_lightning.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.test(model) + assert progress_bar.test_progress_bar.n_values == updates def test_tensor_to_float_conversion(tmpdir): @@ -568,36 +559,6 @@ def get_metrics(self, trainer: Trainer, model: LightningModule): assert "v_num" not in standard_metrics.keys() -def test_tqdm_progress_bar_main_bar_resume(): - """Test that the progress bar can resume its counters based on the Trainer state.""" - bar = TQDMProgressBar() - trainer = Mock() - model = Mock() - - trainer.sanity_checking = False - trainer.check_val_every_n_epoch = 1 - trainer.current_epoch = 1 - trainer.num_training_batches = 5 - trainer.val_check_batch = 5 - trainer.num_val_batches = [3] - trainer.fit_loop.epoch_loop.batch_progress.current.completed = 3 - - bar.setup(trainer, model) - bar.on_train_start(trainer, model) - bar.on_train_epoch_start(trainer, model) - - assert bar.main_progress_bar.n == 3 - assert bar.main_progress_bar.total == 8 - - # bar.on_train_epoch_end(trainer, model) - bar.on_validation_start(trainer, model) - bar.on_validation_epoch_start(trainer, model) - - # restarting mid validation epoch is not currently supported - 
assert bar.val_progress_bar.n == 0 - assert bar.val_progress_bar.total == 3 - - def test_tqdm_progress_bar_correct_value_epoch_end(tmpdir): class MockedProgressBar(TQDMProgressBar): calls = defaultdict(list)
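The test rework follows directly from the new `_update_n` helper: bars are now driven by assigning absolute positions to `n` rather than by `update(delta)` calls, so the fixtures assert the sequence of `n` values instead of a sequence of deltas. Here is a standalone sketch that replays one parametrized case, `[5, 0, 3, [3, 5], None]`, with a plain-Python stand-in for `MockTqdm`; `bar.refresh()` and the validation contribution are omitted for brevity:

```python
from typing import List


class RecordingBar:
    """Stand-in for `MockTqdm`: records each distinct value assigned to `n`."""

    def __init__(self) -> None:
        self._n = 0
        self.n_values: List[int] = []

    @property
    def n(self) -> int:
        return self._n

    @n.setter
    def n(self, value: int) -> None:
        self._n = value
        # track changes in the `n` value, skipping consecutive duplicates
        if not self.n_values or value != self.n_values[-1]:
            self.n_values.append(value)


def _update_n(bar: RecordingBar, value: int) -> None:
    bar.n = value  # absolute position, not an `update(delta)` call


refresh_rate, train_batches = 3, 5
bar = RecordingBar()
for batch_idx in range(1, train_batches + 1):  # batch count after each `on_train_batch_end`
    if batch_idx % refresh_rate == 0:          # mirrors `_should_update`
        _update_n(bar, batch_idx)
_update_n(bar, train_batches)                  # `on_train_epoch_end` flushes the final position

assert bar.n_values == [3, 5]  # the `train_updates` fixture for this case
```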