diff --git a/CHANGELOG.md b/CHANGELOG.md index bbad7fb1d4be2..8cae35e4f247c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328)) -- +- Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520)) - @@ -136,7 +136,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - - ## [1.5.1] - 2021-11-09 ### Fixed diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 4d41734ed90e6..6a54e973ffcf3 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -94,12 +94,9 @@ def on_trainer_init( " bar pass `enable_progress_bar = False` to the Trainer." 
) - if enable_progress_bar: - self.trainer._progress_bar_callback = self.configure_progress_bar( - progress_bar_refresh_rate, process_position - ) - else: - self.trainer._progress_bar_callback = None + self.trainer._progress_bar_callback = self.configure_progress_bar( + progress_bar_refresh_rate, process_position, enable_progress_bar + ) # configure the ModelSummary callback self._configure_model_summary_callback(enable_model_summary, weights_summary) @@ -215,7 +212,9 @@ def _configure_swa_callbacks(self): if not existing_swa: self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks - def configure_progress_bar(self, refresh_rate=None, process_position=0): + def configure_progress_bar( + self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True + ) -> Optional[ProgressBarBase]: if os.getenv("COLAB_GPU") and refresh_rate is None: # smaller refresh rate on colab causes crashes, choose a higher value refresh_rate = 20 @@ -229,7 +228,12 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0): ) if len(progress_bars) == 1: progress_bar_callback = progress_bars[0] - elif refresh_rate > 0: + if not enable_progress_bar: + raise MisconfigurationException( + "Trainer was configured with `enable_progress_bar=False`" + f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list." 
+ ) + elif refresh_rate > 0 and enable_progress_bar: progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) self.trainer.callbacks.append(progress_bar_callback) else: diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index 99fe02ce21a11..a8371591759d7 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -14,7 +14,7 @@ import os import pickle import sys -from typing import Optional, Union +from typing import Union from unittest import mock from unittest.mock import ANY, call, Mock @@ -32,65 +32,54 @@ @pytest.mark.parametrize( - "callbacks,refresh_rate", + "kwargs", [ - ([], None), - ([], 1), - ([], 2), - ([TQDMProgressBar(refresh_rate=1)], 0), - ([TQDMProgressBar(refresh_rate=2)], 0), - ([TQDMProgressBar(refresh_rate=2)], 1), + # won't print but is still set + {"callbacks": TQDMProgressBar(refresh_rate=0)}, + {"callbacks": TQDMProgressBar()}, + {"progress_bar_refresh_rate": 1}, ], ) -def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]): +def test_tqdm_progress_bar_on(tmpdir, kwargs): """Test different ways the progress bar can be turned on.""" - - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=callbacks, - progress_bar_refresh_rate=refresh_rate, - max_epochs=1, - overfit_batches=5, - ) + if "progress_bar_refresh_rate" in kwargs: + with pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated"): + trainer = Trainer(default_root_dir=tmpdir, **kwargs) + else: + trainer = Trainer(default_root_dir=tmpdir, **kwargs) progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] - # Trainer supports only a single progress bar callback at the moment assert len(progress_bars) == 1 assert progress_bars[0] is trainer.progress_bar_callback -@pytest.mark.parametrize( - "callbacks,refresh_rate,enable_progress_bar", - [([], 0, True), ([], False, True), 
([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)], -) -def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool): +@pytest.mark.parametrize("kwargs", [{"enable_progress_bar": False}, {"progress_bar_refresh_rate": 0}]) +def test_tqdm_progress_bar_off(tmpdir, kwargs): """Test different ways the progress bar can be turned off.""" - - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=callbacks, - progress_bar_refresh_rate=refresh_rate, - enable_progress_bar=enable_progress_bar, - ) - - progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)] - assert 0 == len(progress_bars) - assert not trainer.progress_bar_callback +    ctx = pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated") if "progress_bar_refresh_rate" in kwargs else mock.patch.dict(os.environ, {}) +    with ctx: +        trainer = Trainer(default_root_dir=tmpdir, **kwargs) +    progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] +    assert not len(progress_bars) def test_tqdm_progress_bar_misconfiguration(): """Test that Trainer doesn't accept multiple progress bars.""" + # Trainer supports only a single progress bar callback at the moment callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")] with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"): Trainer(callbacks=callbacks) + with pytest.raises(MisconfigurationException, match=r"enable_progress_bar=False` but found `TQDMProgressBar"): + Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False) + def test_tqdm_progress_bar_totals(tmpdir): """Test that the progress finishes with the correct total steps processed.""" model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=1, max_epochs=1) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) bar = trainer.progress_bar_callback assert float("inf") == bar.total_train_batches assert 0 == 
bar.total_val_batches @@ -209,14 +198,15 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal self.test_batches_seen += 1 progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[progress_bar], - progress_bar_refresh_rate=101, # should not matter if custom callback provided - limit_train_batches=1.0, - num_sanity_val_steps=2, - max_epochs=3, - ) + with pytest.deprecated_call(match=r"progress_bar_refresh_rate=101\)` is deprecated"): + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[progress_bar], + progress_bar_refresh_rate=101, # should not matter if custom callback provided + limit_train_batches=1.0, + num_sanity_val_steps=2, + max_epochs=3, + ) assert trainer.progress_bar_callback.refresh_rate == refresh_rate trainer.fit(model) @@ -276,9 +266,6 @@ def test_tqdm_progress_bar_default_value(tmpdir): trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 1 - trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) - assert trainer.progress_bar_callback.refresh_rate == 1 - @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) def test_tqdm_progress_bar_value_on_colab(tmpdir): @@ -286,10 +273,14 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): trainer = Trainer(default_root_dir=tmpdir) assert trainer.progress_bar_callback.refresh_rate == 20 - trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) - assert trainer.progress_bar_callback.refresh_rate == 20 + trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar()) + assert trainer.progress_bar_callback.refresh_rate == 1 # FIXME: should be 20 + + trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19)) + assert trainer.progress_bar_callback.refresh_rate == 19 - trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19) + with 
pytest.deprecated_call(match=r"progress_bar_refresh_rate=19\)` is deprecated"): + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19) assert trainer.progress_bar_callback.refresh_rate == 19 diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index bad9a717d1629..320ffc3e68f34 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -796,7 +796,7 @@ def val_dataloader(self): max_epochs=1, val_check_interval=val_check_interval, num_sanity_val_steps=0, - progress_bar_refresh_rate=0, + enable_progress_bar=False, ) trainer.fit(model) @@ -834,7 +834,7 @@ def val_dataloader(self): max_epochs=1, val_check_interval=val_check_interval, num_sanity_val_steps=0, - progress_bar_refresh_rate=0, + enable_progress_bar=False, ) with pytest.raises(CustomException): # will stop during validation @@ -885,7 +885,7 @@ def val_dataloader(self): max_epochs=1, val_check_interval=val_check_interval, num_sanity_val_steps=0, - progress_bar_refresh_rate=0, + enable_progress_bar=False, ) trainer.fit(model, ckpt_path=ckpt_path) diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py index 2cb68aa2e95bd..e3c353c3eb063 100644 --- a/tests/trainer/connectors/test_callback_connector.py +++ b/tests/trainer/connectors/test_callback_connector.py @@ -22,6 +22,7 @@ LearningRateMonitor, ModelCheckpoint, ModelSummary, + ProgressBarBase, TQDMProgressBar, ) from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector @@ -143,10 +144,11 @@ def test_attach_model_callbacks(): def _attach_callbacks(trainer_callbacks, model_callbacks): model = LightningModule() model.configure_callbacks = lambda: model_callbacks + has_progress_bar = any(isinstance(cb, ProgressBarBase) for cb in trainer_callbacks + model_callbacks) trainer = Trainer( enable_checkpointing=False, - enable_progress_bar=False, - enable_model_summary=None, + enable_progress_bar=has_progress_bar, + 
enable_model_summary=False, callbacks=trainer_callbacks, ) trainer.model = model