Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Avoid deprecated progress_bar_refresh_rate usage #10520

Merged
merged 8 commits into from
Nov 15, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461))


-
- Fixed an issue where `trainer.progress_bar_callback` was not set when `enable_progress_bar=False` even though a progress bar instance was passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


## [1.5.1] - 2021-11-09
Expand Down
20 changes: 12 additions & 8 deletions pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,9 @@ def on_trainer_init(
" bar pass `enable_progress_bar = False` to the Trainer."
)

if enable_progress_bar:
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position
)
else:
self.trainer._progress_bar_callback = None
self.trainer._progress_bar_callback = self.configure_progress_bar(
progress_bar_refresh_rate, process_position, enable_progress_bar
)

# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary, weights_summary)
Expand Down Expand Up @@ -215,7 +212,9 @@ def _configure_swa_callbacks(self):
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks

def configure_progress_bar(self, refresh_rate=None, process_position=0):
def configure_progress_bar(
self, refresh_rate: Optional[int] = None, process_position: int = 0, enable_progress_bar: bool = True
) -> Optional[ProgressBarBase]:
if os.getenv("COLAB_GPU") and refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value
refresh_rate = 20
Expand All @@ -229,7 +228,12 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0):
)
if len(progress_bars) == 1:
progress_bar_callback = progress_bars[0]
elif refresh_rate > 0:
if not enable_progress_bar:
raise MisconfigurationException(
"Trainer was configured with `enable_progress_bar=False`"
f" but found `{progress_bar_callback.__class__.__name__}` in callbacks list."
)
elif refresh_rate > 0 and enable_progress_bar:
progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position)
self.trainer.callbacks.append(progress_bar_callback)
else:
Expand Down
89 changes: 40 additions & 49 deletions tests/callbacks/test_tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import pickle
import sys
from typing import Optional, Union
from typing import Union
from unittest import mock
from unittest.mock import ANY, call, Mock

Expand All @@ -32,65 +32,54 @@


@pytest.mark.parametrize(
"callbacks,refresh_rate",
"kwargs",
[
([], None),
([], 1),
([], 2),
([TQDMProgressBar(refresh_rate=1)], 0),
([TQDMProgressBar(refresh_rate=2)], 0),
([TQDMProgressBar(refresh_rate=2)], 1),
# won't print but is still set
{"callbacks": TQDMProgressBar(refresh_rate=0)},
{"callbacks": TQDMProgressBar()},
{"progress_bar_refresh_rate": 1},
],
)
def test_tqdm_progress_bar_on(tmpdir, callbacks: list, refresh_rate: Optional[int]):
def test_tqdm_progress_bar_on(tmpdir, kwargs):
"""Test different ways the progress bar can be turned on."""

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
max_epochs=1,
overfit_batches=5,
)
if "progress_bar_refresh_rate" in kwargs:
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated"):
trainer = Trainer(default_root_dir=tmpdir, **kwargs)
else:
trainer = Trainer(default_root_dir=tmpdir, **kwargs)

progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
# Trainer supports only a single progress bar callback at the moment
assert len(progress_bars) == 1
assert progress_bars[0] is trainer.progress_bar_callback


@pytest.mark.parametrize(
"callbacks,refresh_rate,enable_progress_bar",
[([], 0, True), ([], False, True), ([ModelCheckpoint(dirpath="../trainer")], 0, True), ([], 1, False)],
)
def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool, int], enable_progress_bar: bool):
@pytest.mark.parametrize("kwargs", [{"enable_progress_bar": False}, {"progress_bar_refresh_rate": 0}])
def test_tqdm_progress_bar_off(tmpdir, kwargs):
"""Test different ways the progress bar can be turned off."""

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
progress_bar_refresh_rate=refresh_rate,
enable_progress_bar=enable_progress_bar,
)

progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)]
assert 0 == len(progress_bars)
assert not trainer.progress_bar_callback
if "progress_bar_refresh_rate" in kwargs:
pytest.deprecated_call(match=r"progress_bar_refresh_rate=.*` is deprecated").__enter__()
trainer = Trainer(default_root_dir=tmpdir, **kwargs)
progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)]
assert not len(progress_bars)


def test_tqdm_progress_bar_misconfiguration():
"""Test that Trainer doesn't accept multiple progress bars."""
# Trainer supports only a single progress bar callback at the moment
callbacks = [TQDMProgressBar(), TQDMProgressBar(), ModelCheckpoint(dirpath="../trainer")]
with pytest.raises(MisconfigurationException, match=r"^You added multiple progress bar callbacks"):
Trainer(callbacks=callbacks)

with pytest.raises(MisconfigurationException, match=r"enable_progress_bar=False` but found `TQDMProgressBar"):
Trainer(callbacks=TQDMProgressBar(), enable_progress_bar=False)


def test_tqdm_progress_bar_totals(tmpdir):
"""Test that the progress finishes with the correct total steps processed."""

model = BoringModel()

trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=1, max_epochs=1)
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
bar = trainer.progress_bar_callback
assert float("inf") == bar.total_train_batches
assert 0 == bar.total_val_batches
Expand Down Expand Up @@ -209,14 +198,15 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
self.test_batches_seen += 1

progress_bar = CurrentProgressBar(refresh_rate=refresh_rate)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[progress_bar],
progress_bar_refresh_rate=101, # should not matter if custom callback provided
limit_train_batches=1.0,
num_sanity_val_steps=2,
max_epochs=3,
)
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=101\)` is deprecated"):
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[progress_bar],
progress_bar_refresh_rate=101, # should not matter if custom callback provided
limit_train_batches=1.0,
num_sanity_val_steps=2,
max_epochs=3,
)
assert trainer.progress_bar_callback.refresh_rate == refresh_rate

trainer.fit(model)
Expand Down Expand Up @@ -276,20 +266,21 @@ def test_tqdm_progress_bar_default_value(tmpdir):
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 1

trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None)
assert trainer.progress_bar_callback.refresh_rate == 1


@mock.patch.dict(os.environ, {"COLAB_GPU": "1"})
def test_tqdm_progress_bar_value_on_colab(tmpdir):
"""Test that Trainer will override the default in Google COLAB."""
trainer = Trainer(default_root_dir=tmpdir)
assert trainer.progress_bar_callback.refresh_rate == 20

trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None)
assert trainer.progress_bar_callback.refresh_rate == 20
trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar())
assert trainer.progress_bar_callback.refresh_rate == 1 # FIXME: should be 20

trainer = Trainer(default_root_dir=tmpdir, callbacks=TQDMProgressBar(refresh_rate=19))
assert trainer.progress_bar_callback.refresh_rate == 19

trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
with pytest.deprecated_call(match=r"progress_bar_refresh_rate=19\)` is deprecated"):
trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19)
assert trainer.progress_bar_callback.refresh_rate == 19


Expand Down
6 changes: 3 additions & 3 deletions tests/loops/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ def val_dataloader(self):
max_epochs=1,
val_check_interval=val_check_interval,
num_sanity_val_steps=0,
progress_bar_refresh_rate=0,
enable_progress_bar=False,
)
trainer.fit(model)

Expand Down Expand Up @@ -834,7 +834,7 @@ def val_dataloader(self):
max_epochs=1,
val_check_interval=val_check_interval,
num_sanity_val_steps=0,
progress_bar_refresh_rate=0,
enable_progress_bar=False,
)
with pytest.raises(CustomException):
# will stop during validation
Expand Down Expand Up @@ -885,7 +885,7 @@ def val_dataloader(self):
max_epochs=1,
val_check_interval=val_check_interval,
num_sanity_val_steps=0,
progress_bar_refresh_rate=0,
enable_progress_bar=False,
)
trainer.fit(model, ckpt_path=ckpt_path)

Expand Down
6 changes: 4 additions & 2 deletions tests/trainer/connectors/test_callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LearningRateMonitor,
ModelCheckpoint,
ModelSummary,
ProgressBarBase,
TQDMProgressBar,
)
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
Expand Down Expand Up @@ -143,10 +144,11 @@ def test_attach_model_callbacks():
def _attach_callbacks(trainer_callbacks, model_callbacks):
model = LightningModule()
model.configure_callbacks = lambda: model_callbacks
has_progress_bar = any(isinstance(cb, ProgressBarBase) for cb in trainer_callbacks + model_callbacks)
trainer = Trainer(
enable_checkpointing=False,
enable_progress_bar=False,
enable_model_summary=None,
enable_progress_bar=has_progress_bar,
enable_model_summary=False,
callbacks=trainer_callbacks,
)
trainer.model = model
Expand Down