Follow up of #2892 (#3202)

Merged
merged 3 commits on Aug 27, 2020
Changes from 2 commits
45 changes: 5 additions & 40 deletions pytorch_lightning/trainer/data_loading.py
@@ -15,23 +15,17 @@
import multiprocessing
import platform
from abc import ABC, abstractmethod
from distutils.version import LooseVersion
from typing import Union, List, Tuple, Callable, Optional

import torch
import torch.distributed as torch_distrib
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from torch.utils.data import IterableDataset
ITERABLE_DATASET_EXISTS = True
except ImportError:
ITERABLE_DATASET_EXISTS = False

try:
from apex import amp
@@ -55,35 +49,6 @@
HOROVOD_AVAILABLE = True


def _has_iterable_dataset(dataloader: DataLoader):
return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
and isinstance(dataloader.dataset, IterableDataset)


def _has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
it is a finite dataloader or infinite dataloader. """

try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
has_len = True
except TypeError:
has_len = False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
has_len = False

if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
rank_zero_warn(
'Your `IterableDataset` has `__len__` defined.'
' In combination with multi-processing data loading (e.g. batch size > 1),'
' this can lead to unintended side effects since the samples will be duplicated.'
)
return has_len


class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
@@ -145,7 +110,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
# don't do anything if it's not a dataloader
is_dataloader = isinstance(dataloader, DataLoader)
# don't manipulate iterable datasets
is_iterable_ds = _has_iterable_dataset(dataloader)
is_iterable_ds = has_iterable_dataset(dataloader)

if not is_dataloader or is_iterable_ds:
return dataloader
@@ -209,7 +174,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
# automatically add samplers
self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)

self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
self._worker_check(self.train_dataloader, 'train dataloader')

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -233,7 +198,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
f'to the number of the training batches ({self.num_training_batches}). '
'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
else:
if not _has_len(self.train_dataloader):
if not has_len(self.train_dataloader):
if self.val_check_interval == 1.0:
self.val_check_batch = float('inf')
else:
@@ -296,7 +261,7 @@ def _reset_eval_dataloader(
# datasets could be none, 1 or 2+
if len(dataloaders) != 0:
for i, dataloader in enumerate(dataloaders):
num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
num_batches = len(dataloader) if has_len(dataloader) else float('inf')
self._worker_check(dataloader, f'{mode} dataloader {i}')

# percent or num_steps
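For orientation, here is a condensed, hypothetical sketch (not code from this PR) of how the renamed helpers are consumed by the trainer code above: has_iterable_dataset tells auto_add_sampler to leave the loader untouched, and has_len decides between a concrete batch count and float('inf'). The function summarize_dataloader below is invented purely for illustration.

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


def summarize_dataloader(dataloader: DataLoader):
    """Hypothetical helper mirroring the checks made in TrainerDataLoadingMixin."""
    if has_iterable_dataset(dataloader):
        # auto_add_sampler returns such loaders unchanged (no DistributedSampler injected)
        sampler_note = "iterable dataset: loader left as-is"
    else:
        sampler_note = "map-style dataset: sampler may be replaced"
    # mirrors `len(loader) if has_len(loader) else float('inf')` used for num_training_batches
    num_batches = len(dataloader) if has_len(dataloader) else float("inf")
    return sampler_note, num_batches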
55 changes: 55 additions & 0 deletions pytorch_lightning/utilities/data.py
@@ -0,0 +1,55 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distutils.version import LooseVersion
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import rank_zero_warn


try:
from torch.utils.data import IterableDataset
ITERABLE_DATASET_EXISTS = True
except ImportError:
ITERABLE_DATASET_EXISTS = False


def has_iterable_dataset(dataloader: DataLoader):
return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
and isinstance(dataloader.dataset, IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
""" Checks if a given Dataloader has __len__ method implemented i.e. if
it is a finite dataloader or infinite dataloader. """

try:
# try getting the length
if len(dataloader) == 0:
raise ValueError('`Dataloader` returned 0 length.'
' Please make sure that your Dataloader at least returns 1 batch')
has_len = True
except TypeError:
has_len = False
except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used
has_len = False

if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
rank_zero_warn(
'Your `IterableDataset` has `__len__` defined.'
' In combination with multi-processing data loading (e.g. batch size > 1),'
' this can lead to unintended side effects since the samples will be duplicated.'
)
return has_len
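As a quick usage illustration (a minimal sketch, not part of this PR), the relocated helpers can be exercised directly. InfiniteStream is an invented toy dataset; a recent PyTorch (>= 1.4) and this branch of pytorch_lightning are assumed to be installed.

from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


class InfiniteStream(IterableDataset):
    """Toy iterable dataset without __len__, i.e. an 'infinite' loader."""

    def __iter__(self):
        i = 0
        while True:
            yield i
            i += 1


finite_loader = DataLoader(list(range(10)), batch_size=2)
infinite_loader = DataLoader(InfiniteStream(), batch_size=2)

assert has_len(finite_loader)                 # map-style dataset exposes __len__
assert not has_len(infinite_loader)           # len() raises TypeError -> treated as infinite
assert has_iterable_dataset(infinite_loader)
assert not has_iterable_dataset(finite_loader)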
8 changes: 4 additions & 4 deletions tests/trainer/test_dataloaders.py
@@ -11,7 +11,7 @@

import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, Callback
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -624,7 +624,7 @@ def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
reason="IterableDataset with __len__ before 1.4 raises",
)
def test_warning_with_iterable_dataset_and_len(tmpdir):
""" Tests that a warning messages is shown when an IterableDataset defines `__len__`. """
""" Tests that a warning message is shown when an IterableDataset defines `__len__`. """
model = EvalModelTemplate()
original_dataset = model.train_dataloader().dataset

@@ -637,8 +637,8 @@ def __len__(self):
return len(original_dataset)

dataloader = DataLoader(IterableWithLen(), batch_size=16)
assert _has_len(dataloader)
assert _has_iterable_dataset(dataloader)
assert has_len(dataloader)
assert has_iterable_dataset(dataloader)
trainer = Trainer(
default_root_dir=tmpdir,
max_steps=3,