Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[bugfix] show_fix should set num_workers to 0 (#226)
Browse files Browse the repository at this point in the history
* resolve bug

* update
  • Loading branch information
tchaton authored Apr 19, 2021
1 parent a756e8b commit 42cc20a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,4 @@ imdb
xsum
coco128
wmt_en_ro
kinetics
9 changes: 7 additions & 2 deletions flash/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from torch.utils.data.dataset import IterableDataset, Subset

from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization
Expand Down Expand Up @@ -138,8 +138,12 @@ def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None:

def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
    """Build a fresh iterator over the dataloader of ``stage`` and cache it.

    The dataloader is obtained by calling ``self.<stage>_dataloader()`` while
    ``self.num_workers`` is temporarily forced to ``0``. The previous value is
    restored afterwards, including when creating the dataloader or its
    iterator raises. The iterator is stored on ``self._<stage>_iter``.

    Args:
        stage: Running stage whose dataloader should be iterated.

    Returns:
        The newly created iterator.
    """
    iter_name = f"_{stage}_iter"
    # num_workers has to be set to 0 to work properly
    num_workers = self.num_workers
    self.num_workers = 0
    try:
        dataloader_fn = getattr(self, f"{stage}_dataloader")
        iterator = iter(dataloader_fn())
    finally:
        # Always put the user's setting back, even if the dataloader could
        # not be built; otherwise later dataloaders would silently use 0.
        self.num_workers = num_workers
    setattr(self, iter_name, iterator)
    return iterator

Expand Down Expand Up @@ -191,7 +195,8 @@ def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, def
def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, value: Any) -> None:
    """Set ``attr_name`` to ``value`` on ``dataset`` when it is a torch dataset.

    A ``Subset`` is unwrapped first, so the attribute lands on the wrapped
    dataset rather than on the view. Objects that are neither ``Dataset`` nor
    ``IterableDataset`` instances (for example a plain ``range``) are left
    untouched instead of raising on an unsupported ``setattr``.

    Args:
        dataset: Dataset (or ``Subset`` of one) to annotate.
        attr_name: Name of the attribute to set.
        value: Value to assign.
    """
    if isinstance(dataset, Subset):
        dataset = dataset.dataset
    # Only torch datasets are annotated; arbitrary containers may not accept
    # new attributes at all.
    if isinstance(dataset, (Dataset, IterableDataset)):
        setattr(dataset, attr_name, value)

def set_running_stages(self):
if self._train_ds:
Expand Down
3 changes: 0 additions & 3 deletions flash/vision/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,6 @@ def post_tensor_transform(self, sample: Any) -> Any:
def per_batch_transform(self, sample: Any) -> Any:
    """Run ``self.common_step`` on ``sample`` and return its result."""
    transformed = self.common_step(sample)
    return transformed

def per_sample_transform_on_device(self, sample: Any) -> Any:
    """Run ``self.common_step`` on ``sample`` and return its result."""
    transformed = self.common_step(sample)
    return transformed

def per_batch_transform_on_device(self, sample: Any) -> Any:
    """Run ``self.common_step`` on ``sample`` and return its result."""
    transformed = self.common_step(sample)
    return transformed

Expand Down
12 changes: 12 additions & 0 deletions tests/data/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,15 @@ def extract_data(data):
assert dm.data_fetcher.show_collate_called
assert dm.data_fetcher.per_batch_transform_called
dm.data_fetcher.check_reset()


def test_data_loaders_num_workers_to_0(tmpdir):
    """
    The iterator built by `_reset_iterator` must be single-process, while a
    dataloader requested afterwards keeps the configured `num_workers`.
    """
    dm = DataModule(train_dataset=range(10), num_workers=3)

    # Iterator created through the internal reset path: expected single-process.
    viz_iter = dm._reset_iterator(RunningStage.TRAINING)
    assert isinstance(viz_iter, torch.utils.data.dataloader._SingleProcessDataLoaderIter)

    # A regular training dataloader must still use the 3 configured workers.
    train_iter = iter(dm.train_dataloader())
    assert isinstance(train_iter, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)

0 comments on commit 42cc20a

Please sign in to comment.