diff --git a/.gitignore b/.gitignore
index 6717726144..f682da63dd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -148,3 +148,4 @@ imdb
 xsum
 coco128
 wmt_en_ro
+kinetics
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index 890c0a6661..dfcf213662 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -21,7 +21,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.nn import Module
 from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import IterableDataset, Subset
 
 from flash.data.auto_dataset import AutoDataset
 from flash.data.base_viz import BaseVisualization
@@ -138,8 +138,12 @@ def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None:
 
     def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
         iter_name = f"_{stage}_iter"
+        # num_workers has to be set to 0 so the visualization iterator is created in the main process
+        num_workers = self.num_workers
+        self.num_workers = 0
         dataloader_fn = getattr(self, f"{stage}_dataloader")
         iterator = iter(dataloader_fn())
+        self.num_workers = num_workers
         setattr(self, iter_name, iterator)
         return iterator
 
@@ -191,7 +195,8 @@ def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, def
     def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, value: Any) -> None:
         if isinstance(dataset, Subset):
             dataset = dataset.dataset
-        setattr(dataset, attr_name, value)
+        if isinstance(dataset, (Dataset, IterableDataset)):
+            setattr(dataset, attr_name, value)
 
     def set_running_stages(self):
         if self._train_ds:
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 8e6ad5c8c7..5627c1b620 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -182,9 +182,6 @@ def post_tensor_transform(self, sample: Any) -> Any:
 
     def per_batch_transform(self, sample: Any) -> Any:
         return self.common_step(sample)
 
-    def per_sample_transform_on_device(self, sample: Any) -> Any:
-        return self.common_step(sample)
-
     def per_batch_transform_on_device(self, sample: Any) -> Any:
         return self.common_step(sample)
diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py
index b0d96a5252..43988ff2d6 100644
--- a/tests/data/test_callbacks.py
+++ b/tests/data/test_callbacks.py
@@ -181,3 +181,15 @@ def extract_data(data):
     assert dm.data_fetcher.show_collate_called
     assert dm.data_fetcher.per_batch_transform_called
     dm.data_fetcher.check_reset()
+
+
+def test_data_loaders_num_workers_to_0(tmpdir):
+    """
+    num_workers should be set to `0` internally for visualization, but not for training.
+    """
+
+    datamodule = DataModule(train_dataset=range(10), num_workers=3)
+    iterator = datamodule._reset_iterator(RunningStage.TRAINING)
+    assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
+    iterator = iter(datamodule.train_dataloader())
+    assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
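
For context on the `_reset_iterator` change and the new test: in PyTorch, the iterator class a `DataLoader` produces depends on `num_workers`, which is exactly what the test asserts. A minimal standalone sketch of that behavior (not part of the patch):

```python
# Sketch: a DataLoader picks its iterator class based on num_workers.
# num_workers=0 loads in the main process; num_workers>0 spawns workers.
import torch
from torch.utils.data import DataLoader

if __name__ == "__main__":
    single = iter(DataLoader(range(10), num_workers=0))
    assert isinstance(single, torch.utils.data.dataloader._SingleProcessDataLoaderIter)

    multi = iter(DataLoader(range(10), num_workers=2))
    assert isinstance(multi, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
```

Temporarily zeroing `self.num_workers` around the `dataloader_fn()` call is what forces the single-process path for visualization while leaving training dataloaders multi-process; a `try`/`finally` around the call would additionally guarantee the restore if the dataloader factory raises.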
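As for the `set_dataset_attribute` guard: plain sequences such as the `range(10)` used in the test are valid map-style datasets for a `DataLoader`, but they are immutable built-ins, so an unconditional `setattr` on them raises `AttributeError`. A hedged sketch mirroring the patched method:

```python
# Sketch: unwrap Subset, then only set attributes on real Dataset /
# IterableDataset instances; built-ins like range() are silently skipped.
from torch.utils.data import Dataset, IterableDataset, Subset

def set_dataset_attribute(dataset, attr_name, value):
    if isinstance(dataset, Subset):
        dataset = dataset.dataset  # write through to the wrapped dataset
    if isinstance(dataset, (Dataset, IterableDataset)):
        setattr(dataset, attr_name, value)

set_dataset_attribute(range(10), "some_attr", 1)  # no-op instead of AttributeError
```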