From e292be2abf1f5c78ea44bfd4ed92fd6c8330361d Mon Sep 17 00:00:00 2001 From: Gerard Bentley Date: Fri, 13 Mar 2020 19:07:38 -0700 Subject: [PATCH] borda fixes --- pytorch_lightning/trainer/__init__.py | 3 +-- pytorch_lightning/trainer/trainer.py | 9 ++++++--- tests/test_deprecated.py | 9 --------- tests/trainer/test_trainer.py | 12 ++++++++++++ 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index 96bb90b1726001..2d651c18cc1771 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -637,8 +637,7 @@ def on_train_end(self): trainer = Trainer(profiler=profiler) progress_bar_refresh_rate -^^^^^^ -^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^^^^^^^^^^ How often to refresh progress bar (in steps). In notebooks, faster refresh rates (lower number) is known to crash them because of their screen refresh rates, so raise it to 50 or more. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bfcc658e42c9e9..c79209f43ceb62 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -162,10 +162,10 @@ def __init__( log_gpu_memory: None, 'min_max', 'all'. Might slow performance show_progress_bar: - .. warning:: .. deprecated:: 0.7.0 + .. warning:: .. deprecated:: 0.7.2 Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0. - progress_bar_refresh_rate: How often to refresh progress bar (in steps). 0 to disable progress bar. + progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar. overfit_pct: How much of training-, validation-, and test dataset to check. @@ -422,7 +422,6 @@ def __init__( " and this method will be removed in v0.8.0", DeprecationWarning) # can't init progress bar here because starting a new process # means the progress_bar won't survive pickling - self.show_progress_bar = self.progress_bar_refresh_rate >= 1 # logging self.log_save_interval = log_save_interval @@ -568,6 +567,10 @@ def from_argparse_args(cls, args): params = vars(args) return cls(**params) + @property + def show_progress_bar(self) -> bool: + return self.progress_bar_refresh_rate >= 1 + @property def num_gpus(self) -> int: gpus = self.data_parallel_device_ids diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index fbb6fd5b354951..2c4a6b9c75ec59 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -37,8 +37,6 @@ def test_tbd_remove_in_v0_8_0_trainer(): } # skip 0 since it may be interested as False kwargs = {k: (i + 1) for i, k in enumerate(mapping_old_new)} - kwargs['show_progress_bar'] = True - trainer = Trainer(**kwargs) for attr_old in mapping_old_new: @@ -47,13 +45,6 @@ def test_tbd_remove_in_v0_8_0_trainer(): 'Missing deprecated attribute "%s"' % attr_old assert kwargs[attr_old] == getattr(trainer, attr_new), \ 'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new) - attr_old = 'show_progress_bar' - attr_new = 'progress_bar_refresh_rate' - assert kwargs[attr_old] == getattr(trainer, attr_old), \ - 'Missing deprecated attribute "%s"' % attr_old - assert kwargs[attr_old] == bool(getattr(trainer, attr_new)), \ - 'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new) - def test_tbd_remove_in_v0_9_0_module_imports(): from pytorch_lightning.core.decorators import data_loader # noqa: F811 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1faadc67524db7..711dc3e9f494a6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -624,3 +624,15 @@ def test_epoch_end(self, outputs): model = LightningTestModel(hparams) Trainer().test(model) + + +def test_disable_progress_bar_arg(tmpdir): + """Tests disabling progress bar by setting refresh to 0""" + trainer = Trainer(progress_bar_refresh_rate=0) + assert not getattr(trainer, 'show_progress_bar') + + trainer = Trainer(progress_bar_refresh_rate=-1) + assert not getattr(trainer, 'show_progress_bar') + + trainer = Trainer(progress_bar_refresh_rate=50) + assert getattr(trainer, 'show_progress_bar')