From 4b24b445ce9a65b21ab3fc0361d95d879f01cedf Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 13 Oct 2021 13:13:42 +0100 Subject: [PATCH] 2/n Add Custom Data Loading Tutorial + API improvement. (#855) --- flash/__init__.py | 4 +- flash/core/data/datasets.py | 10 +- flash/core/data/new_data_module.py | 343 ++++++++++++++++++ flash/core/data/preprocess_transform.py | 17 + .../flash_components/custom_data_loading.py | 151 +++++++- tests/core/data/test_new_data_module.py | 135 +++++++ tests/examples/test_flash_components.py | 2 +- 7 files changed, 643 insertions(+), 19 deletions(-) create mode 100644 flash/core/data/new_data_module.py create mode 100644 tests/core/data/test_new_data_module.py diff --git a/flash/__init__.py b/flash/__init__.py index 3c2d496a2c..579a79b27c 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -22,8 +22,8 @@ from flash.core.data.callback import FlashCallback from flash.core.data.data_module import DataModule # noqa: E402 from flash.core.data.data_source import DataSource - from flash.core.data.datasets import FlashDataset, FlashIterableDataset - from flash.core.data.preprocess_transform import PreprocessTransform + from flash.core.data.datasets import FlashDataset, FlashIterableDataset # noqa: E402 + from flash.core.data.preprocess_transform import PreprocessTransform # noqa: E402 from flash.core.data.process import Postprocess, Preprocess, Serializer from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py index c20d2a25ad..6fd37fbd60 100644 --- a/flash/core/data/datasets.py +++ b/flash/core/data/datasets.py @@ -24,12 +24,18 @@ from flash.core.data.properties import Properties from flash.core.registry import FlashRegistry +__all__ = [ + "BaseDataset", + "FlashDataset", + "FlashIterableDataset", +] + class BaseDataset(Properties): DATASET_KEY = "dataset" - transforms_registry: Optional[FlashRegistry] = None + transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms") transform: Optional[PreprocessTransform] = None @abstractmethod @@ -71,11 +77,13 @@ def running_stage(self, running_stage: RunningStage) -> None: @property def dataloader_collate_fn(self) -> Optional[Callable]: if self.transform: + self.transform.running_stage = self.running_stage return self.transform.dataloader_collate_fn @property def on_after_batch_transfer_fn(self) -> Optional[Callable]: if self.transform: + self.transform.running_stage = self.running_stage return self.transform.on_after_batch_transfer_fn def _resolve_functions(self, func_name: str, cls: Type["BaseDataset"]) -> None: diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py new file mode 100644 index 0000000000..c3d05a4089 --- /dev/null +++ b/flash/core/data/new_data_module.py @@ -0,0 +1,343 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
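+"""A DataModule that works directly with :class:`~flash.core.data.datasets.BaseDataset`: the dataset's collate
+function and on-device transform are injected into the DataLoaders and ``on_after_batch_transfer``."""
+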
+from typing import Any, Optional, Tuple, Type, TYPE_CHECKING, Union
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning import LightningDataModule
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.enums import LightningEnum
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import IterableDataset
+from torch.utils.data.sampler import Sampler
+
+import flash
+from flash.core.data.base_viz import BaseVisualization
+from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_pipeline import DefaultPreprocess, Postprocess
+from flash.core.data.datasets import BaseDataset
+from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
+
+if _FIFTYONE_AVAILABLE and TYPE_CHECKING:
+    from fiftyone.core.collections import SampleCollection
+else:
+    SampleCollection = None
+
+
+class DataModule(DataModule):
+    """A basic DataModule class for all Flash tasks. This class includes references to a
+    :class:`~flash.core.data.datasets.BaseDataset` and a :class:`~flash.core.data.callback.BaseDataFetcher`.
+
+    Args:
+        train_dataset: Dataset for training. Defaults to None.
+        val_dataset: Dataset for validating model performance during training. Defaults to None.
+        test_dataset: Dataset to test model performance. Defaults to None.
+        predict_dataset: Dataset for predicting. Defaults to None.
+        data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to attach to the
+            :class:`~flash.core.data.process.Preprocess`. If ``None``, the output from
+            :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` will be used.
+        val_split: An optional float which gives the relative amount of the training dataset to use for the validation
+            dataset.
+        batch_size: The batch size to be used by the DataLoader. This argument is required.
+        num_workers: The number of workers to use for parallelized loading. Defaults to 0 (``None`` is treated as 0).
+        sampler: A sampler following the :class:`~torch.utils.data.sampler.Sampler` type.
+            Will be passed to the DataLoader for the training dataset. Defaults to None.
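+        pin_memory: If ``True``, the DataLoaders will copy tensors into CUDA pinned memory before returning them.
+            Defaults to True.
+        persistent_workers: If ``True`` and ``num_workers > 0``, the DataLoader worker processes are kept alive
+            between epochs. Defaults to True.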
+ """ + + preprocess_cls = DefaultPreprocess + postprocess_cls = Postprocess + flash_datasets_registry = FlashRegistry("datasets") + + def __init__( + self, + train_dataset: Optional[BaseDataset] = None, + val_dataset: Optional[BaseDataset] = None, + test_dataset: Optional[BaseDataset] = None, + predict_dataset: Optional[BaseDataset] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + val_split: Optional[float] = None, + batch_size: Optional[int] = None, + num_workers: int = 0, + sampler: Optional[Type[Sampler]] = None, + pin_memory: bool = True, + persistent_workers: bool = True, + ) -> None: + + if not batch_size: + raise MisconfigurationException("The `batch_size` should be provided to the DataModule on instantiation.") + + if flash._IS_TESTING and torch.cuda.is_available(): + batch_size = 16 + + self._postprocess: Optional[Postprocess] = None + self._viz: Optional[BaseVisualization] = None + self._data_fetcher: Optional[BaseDataFetcher] = data_fetcher or self.configure_data_fetcher() + + self._train_ds = train_dataset + self._val_ds = val_dataset + self._test_ds = test_dataset + self._predict_ds = predict_dataset + + if self._train_ds and self._val_ds and isinstance(val_split, float) and val_split > 0: + raise MisconfigurationException( + "A `val_dataset` was provided with `val_split`. Please, choose one or the other." + ) + + if self._train_ds is not None and (val_split is not None and self._val_ds is None): + self._train_ds, self._val_ds = self._split_train_val(self._train_ds, val_split) + + if self._train_ds: + self.train_dataloader = self._train_dataloader + + if self._val_ds: + self.val_dataloader = self._val_dataloader + + if self._test_ds: + self.test_dataloader = self._test_dataloader + + if self._predict_ds: + self.predict_dataloader = self._predict_dataloader + + self.batch_size = batch_size + + if num_workers is None: + num_workers = 0 + self.num_workers = num_workers + self.persistent_workers = persistent_workers and num_workers > 0 + self.pin_memory = pin_memory + + self.sampler = sampler + + self.set_running_stages() + + LightningDataModule.__init__(self) + + def _train_dataloader(self) -> DataLoader: + train_ds: BaseDataset = self._train_ds + collate_fn = train_ds.dataloader_collate_fn + shuffle: bool = False + if isinstance(train_ds, IterableDataset): + drop_last = False + else: + drop_last = len(train_ds) > self.batch_size + + if self.sampler is None: + sampler = None + shuffle = not isinstance(train_ds, IterableDataset) + else: + sampler = self.sampler(train_ds) + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_train_dataset( + train_ds, + trainer=self.trainer, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn, + sampler=sampler, + ) + + return DataLoader( + train_ds, + batch_size=self.batch_size, + shuffle=shuffle, + sampler=sampler, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + drop_last=drop_last, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + ) + + def _val_dataloader(self) -> DataLoader: + val_ds: BaseDataset = self._val_ds + collate_fn = val_ds.dataloader_collate_fn + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_val_dataset( + val_ds, + trainer=self.trainer, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + ) + + 
return DataLoader( + val_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + ) + + def _test_dataloader(self) -> DataLoader: + test_ds: BaseDataset = self._test_ds + collate_fn = test_ds.dataloader_collate_fn + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_test_dataset( + test_ds, + trainer=self.trainer, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + ) + + return DataLoader( + test_ds, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + ) + + def _predict_dataloader(self) -> DataLoader: + predict_ds: BaseDataset = self._predict_ds + collate_fn = predict_ds.dataloader_collate_fn + + if isinstance(predict_ds, IterableDataset): + batch_size = self.batch_size + else: + batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1) + + if isinstance(getattr(self, "trainer", None), pl.Trainer): + return self.trainer.lightning_module.process_predict_dataset( + predict_ds, + batch_size=batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + ) + + return DataLoader( + predict_ds, + batch_size=batch_size, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + ) + + def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: + ds = None + if self.trainer.training: + ds = self._train_ds + elif self.trainer.validating: + ds = self._val_ds + elif self.trainer.testing: + ds = self._test_ds + elif self.trainer.predicting: + ds = self._predict_ds + + if ds: + transform = ds.on_after_batch_transfer_fn + batch = transform(batch) + + return batch + + @classmethod + def create_flash_datasets( + cls, + enum: Union[LightningEnum, str], + train_data: Optional[Any] = None, + val_data: Optional[Any] = None, + test_data: Optional[Any] = None, + predict_data: Optional[Any] = None, + train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None, + val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None, + test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None, + predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None, + **flash_dataset_kwargs, + ) -> Tuple[Optional[BaseDataset]]: + cls._verify_flash_dataset_enum(enum) + flash_dataset_cls: BaseDataset = cls.flash_datasets_registry.get(enum) + return ( + cls._create_flash_dataset( + flash_dataset_cls, + train_data, + running_stage=RunningStage.TRAINING, + transform=train_transform, + **flash_dataset_kwargs, + ), + cls._create_flash_dataset( + flash_dataset_cls, + val_data, + running_stage=RunningStage.VALIDATING, + transform=val_transform, + **flash_dataset_kwargs, + ), + cls._create_flash_dataset( + flash_dataset_cls, + test_data, + running_stage=RunningStage.TESTING, + transform=test_transform, + **flash_dataset_kwargs, + ), + cls._create_flash_dataset( + flash_dataset_cls, + predict_data, + running_stage=RunningStage.PREDICTING, + transform=predict_transform, + **flash_dataset_kwargs, + ), + ) + + @staticmethod + def _create_flash_dataset( + flash_dataset_cls, + *load_data_args, + running_stage: RunningStage, + transform: Optional[PreprocessTransform], + **kwargs, + ) -> Optional[BaseDataset]: + if load_data_args[0] is not None: + return 
flash_dataset_cls.from_data( + *load_data_args, running_stage=running_stage, transform=transform, **kwargs + ) + + @classmethod + def _verify_flash_dataset_enum(cls, enum: LightningEnum) -> None: + if not cls.flash_datasets_registry or not isinstance(cls.flash_datasets_registry, FlashRegistry): + raise MisconfigurationException( + "The ``AutoContainer`` should have ``flash_datasets_registry`` (FlashRegistry) populated " + "with datasource class and ``default_flash_dataset_enum`` (LightningEnum) class attributes. " + ) + + if enum not in cls.flash_datasets_registry.available_keys(): + available_constructors = [ + f"from_{key.name.lower()}" for key in cls.flash_datasets_registry.available_keys() + ] + raise MisconfigurationException( + f"The ``AutoContainer`` ``flash_datasets_registry`` doesn't contain the associated {enum} " + f"HINT: Here are the available constructors {available_constructors}" + ) + + @classmethod + def register_flash_dataset(cls, enum: Union[str, LightningEnum], flash_dataset_cls: Type[BaseDataset]) -> None: + if cls.flash_datasets_registry is None: + raise MisconfigurationException("The class attribute `flash_datasets_registry` should be set. ") + cls.flash_datasets_registry(fn=flash_dataset_cls, name=enum) diff --git a/flash/core/data/preprocess_transform.py b/flash/core/data/preprocess_transform.py index 669df1ac99..4cc4f5b4ac 100644 --- a/flash/core/data/preprocess_transform.py +++ b/flash/core/data/preprocess_transform.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +from functools import partial, wraps from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from pytorch_lightning.trainer.states import RunningStage @@ -40,6 +41,17 @@ class PreprocessTransformPlacement(LightningEnum): PER_BATCH_TRANSFORM_ON_DEVICE = "per_batch_transform_on_device" +def transform_context(func: Callable, current_fn: str) -> Callable: + @wraps(func) + def wrapper(self, *args, **kwargs) -> Any: + self.current_fn = current_fn + result = func(self, *args, **kwargs) + self.current_fn = None + return result + + return wrapper + + class PreprocessTransform(Properties): def configure_transforms(self, *args, **kwargs) -> Dict[PreprocessTransformPlacement, Callable]: """The default transforms to use. @@ -112,11 +124,13 @@ def transforms(self) -> Dict[str, Optional[Dict[str, Callable]]]: "transform": self.transform, } + @partial(transform_context, current_fn="per_sample_transform") def per_sample_transform(self, sample: Any) -> Any: if isinstance(sample, list): return [self.current_transform(s) for s in sample] return self.current_transform(sample) + @partial(transform_context, current_fn="per_batch_transform") def per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). 
@@ -125,6 +139,7 @@ def per_batch_transform(self, batch: Any) -> Any: """ return self.current_transform(batch) + @partial(transform_context, current_fn="collate") def collate(self, samples: Sequence, metadata=None) -> Any: """Transform to convert a sequence of samples to a collated batch.""" current_transform = self.current_transform @@ -144,6 +159,7 @@ def collate(self, samples: Sequence, metadata=None) -> Any: return collate_fn(samples, metadata) return collate_fn(samples) + @partial(transform_context, current_fn="per_sample_transform_on_device") def per_sample_transform_on_device(self, sample: Any) -> Any: """Transforms to apply to the data before the collation (per-sample basis). @@ -156,6 +172,7 @@ def per_sample_transform_on_device(self, sample: Any) -> Any: return [self.current_transform(s) for s in sample] return self.current_transform(sample) + @partial(transform_context, current_fn="per_batch_transform_on_device") def per_batch_transform_on_device(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py index 00bd3120ea..0ccbad904e 100644 --- a/flash_examples/flash_components/custom_data_loading.py +++ b/flash_examples/flash_components/custom_data_loading.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from contextlib import suppress from functools import partial -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torchvision.transforms as T from PIL import Image @@ -23,13 +24,13 @@ from flash import _PACKAGE_ROOT, FlashDataset, PreprocessTransform from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.new_data_module import DataModule +from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import download_data -from flash.core.registry import FlashRegistry seed_everything(42) -ROOT_DATA = f"{_PACKAGE_ROOT}/data" -download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", ROOT_DATA) +download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", f"{_PACKAGE_ROOT}/data") ############################################################################################# # Use Case: Load Data from multiple folders # @@ -47,7 +48,7 @@ ############################################################################################# -class CustomDataTransform(LightningEnum): +class DataTransform(LightningEnum): BASE = "base" RANDOM_ROTATION = "random_rotation" @@ -75,16 +76,13 @@ class CustomDataFormat(LightningEnum): # # ############################################################################################# -FOLDER_PATH = f"{ROOT_DATA}/hymenoptera_data/train" +FOLDER_PATH = f"{_PACKAGE_ROOT}/data/hymenoptera_data/train" TRAIN_FOLDERS = [os.path.join(FOLDER_PATH, "ants"), os.path.join(FOLDER_PATH, "bees")] VAL_FOLDERS = [os.path.join(FOLDER_PATH, "ants"), os.path.join(FOLDER_PATH, "bees")] PREDICT_FOLDER = os.path.join(FOLDER_PATH, "ants") class MultipleFoldersImageDataset(FlashDataset): - - transforms_registry = FlashRegistry("image_classification_transform") - def load_data(self, folders: List[str]) -> List[Dict[DefaultDataKeys, Any]]: if self.training: self.num_classes = len(folders) @@ -137,15 +135,15 @@ def configure_per_sample_transform(self, 
image_size: int = 224, rotation: float
 # Register your transform within the Flash Dataset registry
 # Note: Registries can be shared by multiple datasets.
-MultipleFoldersImageDataset.register_transform(CustomDataTransform.BASE, ImageBaseTransform)
-MultipleFoldersImageDataset.register_transform(CustomDataTransform.RANDOM_ROTATION, ImageRandomRotationTransform)
+MultipleFoldersImageDataset.register_transform(DataTransform.BASE, ImageBaseTransform)
+MultipleFoldersImageDataset.register_transform(DataTransform.RANDOM_ROTATION, ImageRandomRotationTransform)
 MultipleFoldersImageDataset.register_transform(
-    CustomDataTransform.RANDOM_90_DEG_ROTATION, partial(ImageRandomRotationTransform, rotation=90)
+    DataTransform.RANDOM_90_DEG_ROTATION, partial(ImageRandomRotationTransform, rotation=90)
 )
 
 train_dataset = MultipleFoldersImageDataset.from_train_data(
     TRAIN_FOLDERS,
-    transform=(CustomDataTransform.RANDOM_ROTATION, {"rotation": 45}),
+    transform=(DataTransform.RANDOM_ROTATION, {"rotation": 45}),
 )
 print(train_dataset.transform)
@@ -164,7 +162,7 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 
 train_dataset = MultipleFoldersImageDataset.from_train_data(
     TRAIN_FOLDERS,
-    transform=CustomDataTransform.RANDOM_90_DEG_ROTATION,
+    transform=DataTransform.RANDOM_90_DEG_ROTATION,
 )
 print(train_dataset.transform)
@@ -181,7 +179,7 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 #     },
 # )
 
-val_dataset = MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform=CustomDataTransform.BASE)
+val_dataset = MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform=DataTransform.BASE)
 print(val_dataset.transform)
 # Out:
 # ImageClassificationRandomRotationTransform(
@@ -199,3 +197,126 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 #     : 0,
 #     : (500, 375)
 # }
+
+#############################################################################################
+#            Step 4 / 5: Create a DataModule                                                #
+#                                                                                           #
+#  The `DataModule` class is a collection of FlashDatasets and you can pass them directly  #
+#  to its init function.                                                                    #
+#                                                                                           #
+#############################################################################################
+
+
+datamodule = DataModule(
+    train_dataset=MultipleFoldersImageDataset.from_train_data(TRAIN_FOLDERS, transform=DataTransform.RANDOM_ROTATION),
+    val_dataset=MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform=DataTransform.BASE),
+    predict_dataset=MultipleFoldersImageDataset.from_predict_data(PREDICT_FOLDER, transform=DataTransform.BASE),
+    batch_size=2,
+)
+
+
+assert isinstance(datamodule.train_dataset, FlashDataset)
+assert isinstance(datamodule.predict_dataset, FlashDataset)
+
+# The ``num_classes`` value was set on line 89 (inside ``load_data``).
+assert datamodule.train_dataset.num_classes == 2
+
+# The ``num_classes`` value was set only for training as ``self.training`` was checked,
+# so it doesn't exist on the validation or predict datasets.
+with suppress(AttributeError):
+    datamodule.val_dataset.num_classes
+
+# As no test data was provided, the test dataset is None.
+assert not datamodule.test_dataset
+
+
+print(datamodule.train_dataset[0])
+# Out:
+# {
+#     : ,
+#     : 0,
+#     : (500, 375)
+# }
+
+assert isinstance(datamodule.predict_dataset, FlashDataset)
+print(datamodule.predict_dataset[0])
+# Out:
+# {
+#     {: 'data/hymenoptera_data/train/ants/957233405_25c1d1187b.jpg'}
+# }
+
+
+# Access the dataloader; the collate_fn from the provided transform is injected directly into the DataLoader.
+batch = next(iter(datamodule.train_dataloader()))
+# Out:
+# {
+#     : tensor([...]),
+#     : tensor([...]),
+#     : [(...), (...), ...],
+# }
+print(batch)
+
+
+#############################################################################################
+#            Step 5 / 5: Provide a custom constructor for your DataModule                  #
+#                                                                                           #
+#  Register your FlashDataset within the DataModule registry and expose a                  #
+#  `from_multiple_folders` constructor built on top of `create_flash_datasets`.            #
+#                                                                                           #
+#############################################################################################
+
+
+class ImageClassificationDataModule(DataModule):
+    @classmethod
+    def from_multiple_folders(
+        cls,
+        train_folders: Optional[List[str]] = None,
+        val_folders: Optional[List[str]] = None,
+        test_folders: Optional[List[str]] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        **data_module_kwargs: Any,
+    ) -> "ImageClassificationDataModule":
+
+        return cls(
+            *cls.create_flash_datasets(
+                CustomDataFormat.MULTIPLE_FOLDERS,
+                train_folders,
+                val_folders,
+                test_folders,
+                predict_folder,
+                train_transform,
+                val_transform,
+                test_transform,
+                predict_transform,
+            ),
+            **data_module_kwargs,
+        )
+
+
+ImageClassificationDataModule.register_flash_dataset(CustomDataFormat.MULTIPLE_FOLDERS, MultipleFoldersImageDataset)
+
+
+# Create the datamodule with your new constructor. This is equivalent to the previous datamodule creation.
+datamodule = ImageClassificationDataModule.from_multiple_folders(
+    train_folders=TRAIN_FOLDERS,
+    val_folders=VAL_FOLDERS,
+    predict_folder=PREDICT_FOLDER,
+    train_transform=DataTransform.RANDOM_ROTATION,
+    val_transform=DataTransform.BASE,
+    predict_transform=DataTransform.BASE,
+    batch_size=2,
+)
+
+# Access the dataloader; the collate_fn from the provided transform is injected directly into the DataLoader.
+batch = next(iter(datamodule.train_dataloader()))
+# Out:
+# {
+#     : tensor([...]),
+#     : tensor([...]),
+#     : [(...), (...), ...],
+# }
+print(batch)
diff --git a/tests/core/data/test_new_data_module.py b/tests/core/data/test_new_data_module.py
new file mode 100644
index 0000000000..ad7d0e24fb
--- /dev/null
+++ b/tests/core/data/test_new_data_module.py
@@ -0,0 +1,135 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
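+"""Tests for the ``DataModule`` built from ``FlashDataset`` objects and ``PreprocessTransform`` transforms."""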
+from typing import Callable + +import torch +from pytorch_lightning import seed_everything +from pytorch_lightning.trainer.states import RunningStage +from torch.utils.data.dataloader import default_collate + +from flash import Task, Trainer +from flash.core.data.datasets import FlashDataset +from flash.core.data.new_data_module import DataModule +from flash.core.data.preprocess_transform import PreprocessTransform + + +def test_data_module(): + seed_everything(42) + + def train_fn(data): + return data - 100 + + def val_fn(data): + return data + 100 + + def test_fn(data): + return data - 1000 + + def predict_fn(data): + return data + 1000 + + class TestDataset(FlashDataset): + pass + + class TestTransform(PreprocessTransform): + def configure_collate(self, *args, **kwargs) -> Callable: + return default_collate + + def configure_per_batch_transform_on_device(self) -> Callable: + if self.training: + return train_fn + elif self.validating: + return val_fn + elif self.testing: + return test_fn + elif self.predicting: + return predict_fn + + transform = TestTransform(running_stage=RunningStage.TRAINING) + assert transform.running_stage == RunningStage.TRAINING + train_dataset = TestDataset.from_train_data(range(10), transform=transform) + assert train_dataset.running_stage == RunningStage.TRAINING + + transform = TestTransform(running_stage=RunningStage.VALIDATING) + assert transform.running_stage == RunningStage.VALIDATING + val_dataset = TestDataset.from_val_data(range(10), transform=transform) + assert val_dataset.running_stage == RunningStage.VALIDATING + + transform = TestTransform(running_stage=RunningStage.TESTING) + assert transform.running_stage == RunningStage.TESTING + test_dataset = TestDataset.from_test_data(range(10), transform=transform) + assert test_dataset.running_stage == RunningStage.TESTING + + transform = TestTransform(running_stage=RunningStage.PREDICTING) + assert transform.running_stage == RunningStage.PREDICTING + predict_dataset = TestDataset.from_predict_data(range(10), transform=transform) + assert predict_dataset.running_stage == RunningStage.PREDICTING + + dm = DataModule( + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + predict_dataset=predict_dataset, + batch_size=2, + ) + + batch = next(iter(dm.train_dataloader())) + assert batch.shape == torch.Size([2]) + assert batch.min() >= 0 and batch.max() < 10 + + class TestModel(Task): + def training_step(self, batch, batch_idx): + assert sum(batch < 0) == 2 + + def validation_step(self, batch, batch_idx): + assert sum(batch > 0) == 2 + + def test_step(self, batch, batch_idx): + assert sum(batch < 500) == 2 + + def predict_step(self, batch, batch_idx): + assert sum(batch > 500) == 2 + assert torch.equal(batch, torch.tensor([1000, 1001])) + + def on_train_dataloader(self) -> None: + pass + + def on_val_dataloader(self) -> None: + pass + + def on_test_dataloader(self, *_) -> None: + pass + + def on_predict_dataloader(self) -> None: + pass + + def on_predict_end(self) -> None: + pass + + def on_fit_end(self) -> None: + pass + + model = TestModel(torch.nn.Linear(1, 1)) + trainer = Trainer(fast_dev_run=True) + trainer.fit(model, dm) + trainer.validate(model, dm) + trainer.test(model, dm) + trainer.predict(model, dm) + + class CustomDataModule(DataModule): + pass + + CustomDataModule.register_flash_dataset("custom", TestDataset) + train_dataset, *_ = DataModule.create_flash_datasets("custom", range(10)) + assert train_dataset[0] == 0 diff --git a/tests/examples/test_flash_components.py 
b/tests/examples/test_flash_components.py index 71cfcf2f71..62933f7929 100644 --- a/tests/examples/test_flash_components.py +++ b/tests/examples/test_flash_components.py @@ -34,5 +34,5 @@ ), ], ) -def test_components(tmpdir, folder, file): +def test_components(folder, file): run_test(str(root / "flash_examples" / folder / file))