[WIP] Fix the progress bar for the sanity check #2892

2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -138,6 +138,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed checkpointing to remote file paths ([#2925](https://github.com/PyTorchLightning/pytorch-lightning/pull/2925))

- Fixed the total steps of the progress bar for the validation sanity check ([#2892](https://github.com/PyTorchLightning/pytorch-lightning/pull/2892))

## [0.8.5] - 2020-07-09

### Added
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -17,6 +17,7 @@
from tqdm import tqdm

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.data import has_len


class ProgressBarBase(Callback):
@@ -293,7 +294,9 @@ def init_test_tqdm(self) -> tqdm:
def on_sanity_check_start(self, trainer, pl_module):
super().on_sanity_check_start(trainer, pl_module)
self.val_progress_bar = self.init_sanity_tqdm()
self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
self.val_progress_bar.total = sum(
min(trainer.num_sanity_val_steps, len(d) if has_len(d) else float('inf')) for d in trainer.val_dataloaders
Member

This is quite a common case; can't we add a function for it, like:

def len_or_default(to_be_checked: Any, default_length: float = float('inf')):
    if has_len(to_be_checked):
        return len(to_be_checked)
    return default_length

This may be overhead now, but we need similar things quite often.
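
For illustration, a minimal sketch of how such a helper would simplify the total computed above (the helper name and its placement are hypothetical, not part of this diff):

# hypothetical use inside on_sanity_check_start, assuming len_or_default
# sits next to has_len in pytorch_lightning.utilities.data
self.val_progress_bar.total = sum(
    min(trainer.num_sanity_val_steps, len_or_default(d)) for d in trainer.val_dataloaders
)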

Contributor

I believe this is repeated code. This is already done in reset_val_dataloader. All we need here is to sum num_sanity_val_steps once #2917 is fixed.
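
A minimal sketch of that simplification, assuming #2917 turns num_sanity_val_steps into a per-dataloader list (an assumption, not the current API):

# hypothetical: num_sanity_val_steps as a per-dataloader list, e.g. [2, 2]
self.val_progress_bar.total = sum(trainer.num_sanity_val_steps)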

Contributor

Agree with both of you. Should we block this PR on #2917 or the other way around? Does it matter which one goes first?

Contributor

I suggest blocking this one. Once I get answers to the questions I asked there, I'll fix that one tonight, and then we can complete this one :)

)
Comment on lines +297 to +299
Contributor

@awaelchli should num_sanity_val_steps be independent of limit_val_batches (when it is a float)?

Contributor

If num_sanity_val_steps=2, len(val_dataloader)=10, and limit_val_batches=0.1, should it run for 2 val steps or 1?
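
(Reference arithmetic only, not a position on the design: limit_val_batches=0.1 of 10 batches caps the validation loop at 0.1 * 10 = 1 batch, so honoring that cap would give min(2, 1) = 1 sanity step, while ignoring it gives 2.)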

Contributor
@awaelchli, Aug 11, 2020

Is this relevant here? I thought this PR is just about displaying the number of sanity steps that the trainer returns.
If limit_val_batches is used, it should just truncate the sanity steps if needed, no? I think this should happen in the trainer.

Contributor
@rohitgr7, Aug 11, 2020

Yeah, it still has some issues with limit_val_batches. I think a better fix would be to set up num_sanity_val_steps as a list in the Trainer itself rather than doing it here; then we can simply sum it to get the total sanity val steps.
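
A hedged sketch of the Trainer-side computation this suggests (attribute names and placement are assumptions, not part of this diff):

# hypothetical: computed once in the Trainer when the val dataloaders are reset
trainer.num_sanity_val_batches = [
    min(trainer.num_sanity_val_steps, batches) for batches in trainer.num_val_batches
]
# the progress bar would then only need:
self.val_progress_bar.total = sum(trainer.num_sanity_val_batches)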

Contributor Author

Does that mean ...

Contributor

I suggest that in the case of num_sanity_val_steps == -1, it should be affected by limit_val_batches too.
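
For illustration, one way that interaction could be expressed, building on the list idea above (a sketch only, not the implemented behavior):

# hypothetical: -1 means "run the full validation set", which limit_val_batches already caps
if trainer.num_sanity_val_steps == -1:
    sanity_batches = trainer.num_val_batches  # already truncated by limit_val_batches
else:
    sanity_batches = [min(trainer.num_sanity_val_steps, b) for b in trainer.num_val_batches]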

Contributor

@rohitgr7 I like your suggestions. It is true: the trainer should compute these properties, and the progress bars should only read them (and maybe sum them).

Contributor Author

Should I open another PR or keep this PR going? Should we use the same num_sanity_val_steps to save these values? (#2891 (comment))

Contributor

I am already working on it :)

Contributor Author

Thank you.

self.main_progress_bar = tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, trainer, pl_module):
46 changes: 5 additions & 41 deletions pytorch_lightning/trainer/data_loading.py
@@ -1,24 +1,17 @@
import multiprocessing
import platform
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Tuple, Callable, Optional

import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from torch.utils.data import IterableDataset
ITERABLE_DATASET_EXISTS = True
except ImportError:
ITERABLE_DATASET_EXISTS = False

try:
from apex import amp
except ImportError:
@@ -41,35 +34,6 @@
HOROVOD_AVAILABLE = True


def _has_iterable_dataset(dataloader: DataLoader):
return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
and isinstance(dataloader.dataset, IterableDataset)


def _has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
it is a finite dataloader or infinite dataloader. """

try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
has_len = True
except TypeError:
has_len = False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
has_len = False

if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
rank_zero_warn(
'Your `IterableDataset` has `__len__` defined.'
' In combination with multi-processing data loading (e.g. batch size > 1),'
' this can lead to unintended side effects since the samples will be duplicated.'
)
return has_len


class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
@@ -131,7 +95,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
is_iterable_ds = _has_iterable_dataset(dataloader)
is_iterable_ds = has_iterable_dataset(dataloader)

if not is_dataloader or is_iterable_ds:
return dataloader
@@ -195,7 +159,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -219,7 +183,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
else:
if not _has_len(self.train_dataloader):
if not has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
self.val_check_batch = float('inf')
else:
@@ -282,7 +246,7 @@ def _reset_eval_dataloader(
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
num_batches = len(dataloader) if has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')

# percent or num_steps
41 changes: 41 additions & 0 deletions pytorch_lightning/utilities/data.py
@@ -0,0 +1,41 @@
from distutils.version import LooseVersion

import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_warn

try:
from torch.utils.data import IterableDataset
ITERABLE_DATASET_EXISTS = True
except ImportError:
ITERABLE_DATASET_EXISTS = False


def has_iterable_dataset(dataloader: DataLoader):
return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
it is a finite dataloader or infinite dataloader. """

try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
has_len = True
except TypeError:
has_len = False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
has_len = False

if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
rank_zero_warn(
'Your `IterableDataset` has `__len__` defined.'
' In combination with multi-processing data loading (e.g. batch size > 1),'
' this can lead to unintended side effects since the samples will be duplicated.'
)
return has_len
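
For context, a minimal usage example of these helpers (illustrative only, not part of the diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len

loader = DataLoader(TensorDataset(torch.zeros(8, 1)), batch_size=4)
assert has_len(loader)                    # map-style dataset, so __len__ is defined
assert not has_iterable_dataset(loader)   # not an IterableDataset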
42 changes: 42 additions & 0 deletions tests/callbacks/test_progress_bar.py
@@ -193,3 +193,45 @@ def on_test_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx

trainer.test(model)
assert progress_bar.test_batches_seen == progress_bar.total_test_batches


@pytest.mark.parametrize('num_sanity_val_steps,num_val_dataloaders_batches,expected_num_steps', [
(-1, [10], 10),
(0, [10], 0),
(2, [10], 2),
(10, [2], 2),
(10, [2, 3], 5),
(10, [20, 3], 13),
(10, [20, 30], 20),
(10, [float('inf')], 10),
(10, [1, float('inf')], 11),
])
def test_sanity_check_progress_bar_total(
tmpdir, num_sanity_val_steps, num_val_dataloaders_batches, expected_num_steps
):
"""Test that the sanity_check progress finishes with the correct total steps processed."""

tmp_model = EvalModelTemplate(batch_size=1)
batch_size = len(tmp_model.dataloader(train=False, num_samples=1).dataset)
model = EvalModelTemplate(batch_size=batch_size)

num_dataloaders = len(num_val_dataloaders_batches)
trainer = Trainer(
default_root_dir=tmpdir,
progress_bar_refresh_rate=1,
limit_val_batches=len(model.dataloader(train=False)) * num_dataloaders,
max_epochs=0,
num_sanity_val_steps=num_sanity_val_steps,
)

val_dataloaders = []
for num_samples in num_val_dataloaders_batches:
if num_samples == float('inf'):
val_dataloaders.append(model.val_dataloader__infinite())
else:
val_dataloaders.append(
model.dataloader(train=False, num_samples=num_samples))
trainer.fit(model, val_dataloaders=val_dataloaders)

val_progress_bar = trainer.progress_bar_callback.val_progress_bar
assert getattr(val_progress_bar, 'total', 0) == expected_num_steps
34 changes: 1 addition & 33 deletions tests/trainer/test_dataloaders.py
@@ -4,14 +4,12 @@

import pytest
import torch
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset, Subset
from torch.utils.data.dataset import Subset
from torch.utils.data.distributed import DistributedSampler

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -619,36 +617,6 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
trainer.test(**test_options)


@pytest.mark.xfail(
parse(torch.__version__) < parse("1.4.0"),
reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
original_dataset = model.train_dataloader().dataset

class IterableWithLen(IterableDataset):

def __iter__(self):
return iter(original_dataset)

def __len__(self):
return len(original_dataset)

dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert _has_len(dataloader)
assert _has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,
)
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
def test_dataloader_reinit_for_subclass(tmpdir):

39 changes: 39 additions & 0 deletions tests/utilities/test_data.py
@@ -0,0 +1,39 @@
import pytest
import torch
from packaging.version import parse
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import IterableDataset

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from tests.base import EvalModelTemplate


@pytest.mark.xfail(
parse(torch.__version__) < parse("1.4.0"),
reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
original_dataset = model.train_dataloader().dataset

class IterableWithLen(IterableDataset):

def __iter__(self):
return iter(original_dataset)

def __len__(self):
return len(original_dataset)

dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,
)
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
trainer.test(model, test_dataloaders=[dataloader])