Skip to content

Commit

Permalink
borda fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerard Bentley committed Mar 24, 2020
1 parent 96eb725 commit e292be2
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,8 +637,7 @@ def on_train_end(self):
trainer = Trainer(profiler=profiler)
progress_bar_refresh_rate
^^^^^^
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^
How often to refresh progress bar (in steps).
In notebooks, faster refresh rates (lower number) is known to crash them
because of their screen refresh rates, so raise it to 50 or more.
Expand Down
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,10 @@ def __init__(
log_gpu_memory: None, 'min_max', 'all'. Might slow performance
show_progress_bar:
.. warning:: .. deprecated:: 0.7.0
.. warning:: .. deprecated:: 0.7.2
Set `progress_bar_refresh_rate` to postive integer to enable. Will remove 0.9.0.
progress_bar_refresh_rate: How often to refresh progress bar (in steps). 0 to disable progress bar.
progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
overfit_pct: How much of training-, validation-, and test dataset to check.
Expand Down Expand Up @@ -422,7 +422,6 @@ def __init__(
" and this method will be removed in v0.8.0", DeprecationWarning)
# can't init progress bar here because starting a new process
# means the progress_bar won't survive pickling
self.show_progress_bar = self.progress_bar_refresh_rate >= 1

# logging
self.log_save_interval = log_save_interval
Expand Down Expand Up @@ -568,6 +567,10 @@ def from_argparse_args(cls, args):
params = vars(args)
return cls(**params)

@property
def show_progress_bar(self) -> bool:
return self.progress_bar_refresh_rate >= 1

@property
def num_gpus(self) -> int:
gpus = self.data_parallel_device_ids
Expand Down
9 changes: 0 additions & 9 deletions tests/test_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def test_tbd_remove_in_v0_8_0_trainer():
}
# skip 0 since it may be interested as False
kwargs = {k: (i + 1) for i, k in enumerate(mapping_old_new)}
kwargs['show_progress_bar'] = True

trainer = Trainer(**kwargs)

for attr_old in mapping_old_new:
Expand All @@ -47,13 +45,6 @@ def test_tbd_remove_in_v0_8_0_trainer():
'Missing deprecated attribute "%s"' % attr_old
assert kwargs[attr_old] == getattr(trainer, attr_new), \
'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)
attr_old = 'show_progress_bar'
attr_new = 'progress_bar_refresh_rate'
assert kwargs[attr_old] == getattr(trainer, attr_old), \
'Missing deprecated attribute "%s"' % attr_old
assert kwargs[attr_old] == bool(getattr(trainer, attr_new)), \
'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)


def test_tbd_remove_in_v0_9_0_module_imports():
from pytorch_lightning.core.decorators import data_loader # noqa: F811
Expand Down
12 changes: 12 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,15 @@ def test_epoch_end(self, outputs):

model = LightningTestModel(hparams)
Trainer().test(model)


def test_disable_progress_bar_arg(tmpdir):
"""Tests disabling progress bar by setting refresh to 0"""
trainer = Trainer(progress_bar_refresh_rate=0)
assert not getattr(trainer, 'show_progress_bar')

trainer = Trainer(progress_bar_refresh_rate=-1)
assert not getattr(trainer, 'show_progress_bar')

trainer = Trainer(progress_bar_refresh_rate=50)
assert getattr(trainer, 'show_progress_bar')

0 comments on commit e292be2

Please sign in to comment.