diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index 1af2ced94a674..5f8d9bf9dc6d3 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -15,23 +15,17 @@
 import multiprocessing
 import platform
 from abc import ABC, abstractmethod
-from distutils.version import LooseVersion
 from typing import Union, List, Tuple, Callable, Optional
 
-import torch
 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-try:
-    from torch.utils.data import IterableDataset
-    ITERABLE_DATASET_EXISTS = True
-except ImportError:
-    ITERABLE_DATASET_EXISTS = False
 
 try:
     from apex import amp
@@ -55,35 +49,6 @@
     HOROVOD_AVAILABLE = True
 
 
-def _has_iterable_dataset(dataloader: DataLoader):
-    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
-        and isinstance(dataloader.dataset, IterableDataset)
-
-
-def _has_len(dataloader: DataLoader) -> bool:
-    """ Checks if a given Dataloader has __len__ method implemented i.e. if
-    it is a finite dataloader or infinite dataloader. """
-
-    try:
-        # try getting the length
-        if len(dataloader) == 0:
-            raise ValueError('`Dataloader` returned 0 length.'
-                             ' Please make sure that your Dataloader at least returns 1 batch')
-        has_len = True
-    except TypeError:
-        has_len = False
-    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
-        has_len = False
-
-    if has_len and _has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
-        rank_zero_warn(
-            'Your `IterableDataset` has `__len__` defined.'
-            ' In combination with multi-processing data loading (e.g. batch size > 1),'
-            ' this can lead to unintended side effects since the samples will be duplicated.'
-        )
-    return has_len
-
-
 class TrainerDataLoadingMixin(ABC):
 
     # this is just a summary on variables used in this abstract class,
@@ -145,7 +110,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
-        is_iterable_ds = _has_iterable_dataset(dataloader)
+        is_iterable_ds = has_iterable_dataset(dataloader)
         if not is_dataloader or is_iterable_ds:
             return dataloader
 
@@ -209,7 +174,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
-        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
+        self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
 
         if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
@@ -233,7 +198,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                     f'to the number of the training batches ({self.num_training_batches}). '
                     'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
         else:
-            if not _has_len(self.train_dataloader):
+            if not has_len(self.train_dataloader):
                 if self.val_check_interval == 1.0:
                     self.val_check_batch = float('inf')
                 else:
@@ -296,7 +261,7 @@ def _reset_eval_dataloader(
         # datasets could be none, 1 or 2+
         if len(dataloaders) != 0:
             for i, dataloader in enumerate(dataloaders):
-                num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
+                num_batches = len(dataloader) if has_len(dataloader) else float('inf')
                 self._worker_check(dataloader, f'{mode} dataloader {i}')
 
                 # percent or num_steps
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
new file mode 100644
index 0000000000000..996fe3082f755
--- /dev/null
+++ b/pytorch_lightning/utilities/data.py
@@ -0,0 +1,55 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distutils.version import LooseVersion
+import torch
+from torch.utils.data import DataLoader
+
+from pytorch_lightning.utilities import rank_zero_warn
+
+
+try:
+    from torch.utils.data import IterableDataset
+    ITERABLE_DATASET_EXISTS = True
+except ImportError:
+    ITERABLE_DATASET_EXISTS = False
+
+
+def has_iterable_dataset(dataloader: DataLoader):
+    return ITERABLE_DATASET_EXISTS and hasattr(dataloader, 'dataset') \
+        and isinstance(dataloader.dataset, IterableDataset)
+
+
+def has_len(dataloader: DataLoader) -> bool:
+    """ Checks if a given Dataloader has __len__ method implemented i.e. if
+    it is a finite dataloader or infinite dataloader. """
+
+    try:
+        # try getting the length
+        if len(dataloader) == 0:
+            raise ValueError('`Dataloader` returned 0 length.'
+                             ' Please make sure that your Dataloader at least returns 1 batch')
+        has_len = True
+    except TypeError:
+        has_len = False
+    except NotImplementedError:  # e.g. raised by torchtext if a batch_size_fn is used
+        has_len = False
+
+    if has_len and has_iterable_dataset(dataloader) and LooseVersion(torch.__version__) >= LooseVersion("1.4.0"):
+        rank_zero_warn(
+            'Your `IterableDataset` has `__len__` defined.'
+            ' In combination with multi-processing data loading (e.g. batch size > 1),'
+            ' this can lead to unintended side effects since the samples will be duplicated.'
+        )
+    return has_len
diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py
index 6bff4418c57f1..cdb0e84240aaa 100644
--- a/tests/trainer/test_dataloaders.py
+++ b/tests/trainer/test_dataloaders.py
@@ -11,7 +11,7 @@
 
 import tests.base.develop_pipelines as tpipes
 from pytorch_lightning import Trainer, Callback
-from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
+from pytorch_lightning.utilities.data import has_iterable_dataset, has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 
@@ -637,8 +637,8 @@ def __len__(self):
             return len(original_dataset)
 
     dataloader = DataLoader(IterableWithLen(), batch_size=16)
-    assert _has_len(dataloader)
-    assert _has_iterable_dataset(dataloader)
+    assert has_len(dataloader)
+    assert has_iterable_dataset(dataloader)
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=3,
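
Not part of the patch: a minimal sketch of how the relocated helpers are expected to behave, mirroring the assertions in test_dataloaders.py above. It assumes only the import path this diff introduces (pytorch_lightning.utilities.data); the Stream class and the *_loader names are illustrative only.

import torch
from torch.utils.data import DataLoader, IterableDataset, TensorDataset

from pytorch_lightning.utilities.data import has_iterable_dataset, has_len


class Stream(IterableDataset):
    # toy iterable-style dataset that deliberately defines no __len__
    def __iter__(self):
        return iter(range(16))


finite_loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
stream_loader = DataLoader(Stream(), batch_size=4)

# a map-style dataset has a length and is not an IterableDataset
assert has_len(finite_loader) and not has_iterable_dataset(finite_loader)
# an IterableDataset without __len__ is treated as an infinite dataloader
assert has_iterable_dataset(stream_loader) and not has_len(stream_loader)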