This repository was archived by the owner on Oct 9, 2023. It is now read-only.

[bugfix] show_fix should set num_workers to 0 #226

Merged: 2 commits, Apr 19, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -148,3 +148,4 @@ imdb
xsum
coco128
wmt_en_ro
+kinetics
9 changes: 7 additions & 2 deletions flash/data/data_module.py
@@ -21,7 +21,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.dataset import Subset
+from torch.utils.data.dataset import IterableDataset, Subset

from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization
@@ -138,8 +138,12 @@ def data_fetcher(self, data_fetcher: BaseDataFetcher) -> None:

    def _reset_iterator(self, stage: RunningStage) -> Iterable[Any]:
        iter_name = f"_{stage}_iter"
+        # num_workers has to be set to 0 for the visualization iterator to work properly.
+        num_workers = self.num_workers
+        self.num_workers = 0
        dataloader_fn = getattr(self, f"{stage}_dataloader")
        iterator = iter(dataloader_fn())
+        self.num_workers = num_workers
        setattr(self, iter_name, iterator)
        return iterator
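As a side note on the save-and-restore pattern above: wrapping the override in a context manager would restore num_workers even if building the dataloader raised an exception. The sketch below is only an illustration of that alternative, not part of this PR, and the helper name _force_num_workers is made up:

from contextlib import contextmanager


@contextmanager
def _force_num_workers(datamodule, value=0):
    # Temporarily override num_workers, restoring the original value on exit,
    # even when an exception is raised inside the block.
    original = datamodule.num_workers
    datamodule.num_workers = value
    try:
        yield
    finally:
        datamodule.num_workers = original

With such a helper, only the iter(dataloader_fn()) call in _reset_iterator would need to sit inside the with block.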

@@ -191,7 +195,8 @@ def get_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, def
    def set_dataset_attribute(dataset: torch.utils.data.Dataset, attr_name: str, value: Any) -> None:
        if isinstance(dataset, Subset):
            dataset = dataset.dataset
-        setattr(dataset, attr_name, value)
+        if isinstance(dataset, (Dataset, IterableDataset)):
+            setattr(dataset, attr_name, value)

    def set_running_stages(self):
        if self._train_ds:
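For context, the new isinstance guard means attributes are only attached to objects that really are Dataset or IterableDataset instances; anything else, such as a plain range passed in as a dataset, is now silently skipped. A minimal sketch of that behaviour, assuming set_dataset_attribute is exposed as a staticmethod on DataModule as the surrounding class context suggests (MyDataset and some_attr are made-up names):

from torch.utils.data import Dataset

from flash.data.data_module import DataModule


class MyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


DataModule.set_dataset_attribute(MyDataset(), "some_attr", 42)  # attribute is set
DataModule.set_dataset_attribute(range(10), "some_attr", 42)    # skipped: not a Dataset

In recent PyTorch releases IterableDataset already subclasses Dataset, so the explicit tuple mainly documents intent.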
12 changes: 12 additions & 0 deletions tests/data/test_callbacks.py
@@ -181,3 +181,15 @@ def extract_data(data):
    assert dm.data_fetcher.show_collate_called
    assert dm.data_fetcher.per_batch_transform_called
    dm.data_fetcher.check_reset()
+
+
+def test_data_loaders_num_workers_to_0(tmpdir):
+    """
+    The num_workers should be set to `0` internally for visualization and not for training.
+    """
+
+    datamodule = DataModule(train_dataset=range(10), num_workers=3)
+    iterator = datamodule._reset_iterator(RunningStage.TRAINING)
+    assert isinstance(iterator, torch.utils.data.dataloader._SingleProcessDataLoaderIter)
+    iterator = iter(datamodule.train_dataloader())
+    assert isinstance(iterator, torch.utils.data.dataloader._MultiProcessingDataLoaderIter)
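The two private iterator classes the test checks come straight from PyTorch: DataLoader builds a _SingleProcessDataLoaderIter when num_workers == 0 and a _MultiProcessingDataLoaderIter otherwise, which is why the assertions above are enough to tell the two paths apart. A quick standalone sanity check of that behaviour, independent of Flash:

import torch
from torch.utils.data import DataLoader

data = list(range(10))

# num_workers=0 -> iteration happens in the main process
assert isinstance(
    iter(DataLoader(data, num_workers=0)),
    torch.utils.data.dataloader._SingleProcessDataLoaderIter,
)

# num_workers>0 -> worker processes are spawned
assert isinstance(
    iter(DataLoader(data, num_workers=2)),
    torch.utils.data.dataloader._MultiProcessingDataLoaderIter,
)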
Contributor
I'd add a small assert checking that internally the number of workers has been kept as expected.

Contributor Author
That's what _MultiProcessingDataLoaderIter does.
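For illustration, the extra assertion suggested above could look like the following sketch (it reuses the imports of the test file and is not part of the merged change):

datamodule = DataModule(train_dataset=range(10), num_workers=3)
datamodule._reset_iterator(RunningStage.TRAINING)
# the original value should be restored once the visualization iterator has been built
assert datamodule.num_workers == 3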