From b22e78604f3ed77a323c7572341774a6dc43bfc2 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Thu, 14 Oct 2021 16:05:58 +0100
Subject: [PATCH] Rename PreprocessTransform to InputTransform (#868)

---
 CHANGELOG.md                                  |   4 +
 flash/__init__.py                             |   6 +-
 flash/core/data/datasets.py                   |  32 ++---
 ...rocess_transform.py => input_transform.py} |  84 +++++++------
 flash/core/data/new_data_module.py            |  12 +-
 .../flash_components/custom_data_loading.py   |  82 +++++--------
 tests/core/data/test_data_pipeline.py         |  12 +-
 ...s_transform.py => test_input_transform.py} | 116 +++++++++---------
 tests/core/data/test_new_data_module.py       |   4 +-
 9 files changed, 172 insertions(+), 180 deletions(-)
 rename flash/core/data/{preprocess_transform.py => input_transform.py} (85%)
 rename tests/core/data/{test_preprocess_transform.py => test_input_transform.py} (53%)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8038acec31..45e8911035 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
 
+
+- Changed `PreprocessTransform` to `InputTransform` ([#868](https://github.com/PyTorchLightning/lightning-flash/pull/868))
+
+
 ### Fixed
 
 - Fixed a bug where additional kwargs (e.g. sampler) passed to tabular data would be ignored ([#792](https://github.com/PyTorchLightning/lightning-flash/pull/792))
 
diff --git a/flash/__init__.py b/flash/__init__.py
index 579a79b27c..4b39185dad 100644
--- a/flash/__init__.py
+++ b/flash/__init__.py
@@ -22,8 +22,8 @@
     from flash.core.data.callback import FlashCallback
     from flash.core.data.data_module import DataModule  # noqa: E402
     from flash.core.data.data_source import DataSource
-    from flash.core.data.datasets import FlashDataset, FlashIterableDataset  # noqa: E402
-    from flash.core.data.preprocess_transform import PreprocessTransform  # noqa: E402
+    from flash.core.data.datasets import FlashDataset, FlashIterableDataset
+    from flash.core.data.input_transform import InputTransform
     from flash.core.data.process import Postprocess, Preprocess, Serializer
     from flash.core.model import Task  # noqa: E402
     from flash.core.trainer import Trainer  # noqa: E402
@@ -45,7 +45,7 @@
     "FlashDataset",
     "FlashIterableDataset",
     "Preprocess",
-    "PreprocessTransform",
+    "InputTransform",
     "Postprocess",
     "Serializer",
     "Task",
diff --git a/flash/core/data/datasets.py b/flash/core/data/datasets.py
index 6fd37fbd60..fd9471da5d 100644
--- a/flash/core/data/datasets.py
+++ b/flash/core/data/datasets.py
@@ -20,7 +20,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import Dataset, IterableDataset
 
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.data.properties import Properties
 from flash.core.registry import FlashRegistry
 
@@ -35,8 +35,8 @@ class BaseDataset(Properties):
 
     DATASET_KEY = "dataset"
 
-    transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
-    transform: Optional[PreprocessTransform] = None
+    input_transforms_registry: Optional[FlashRegistry] = FlashRegistry("transforms")
+    transform: Optional[InputTransform] = None
 
     @abstractmethod
     def load_data(self, data: Any) -> Union[Iterable, Mapping]:
@@ -49,12 +49,12 @@ def load_data(self, data: Any) -> Union[Iterable, Mapping]:
     def load_sample(self, data: Any) -> Any:
         """The `load_sample` hook contains the logic to load a single sample."""
 
-    def __init__(self, running_stage: RunningStage, transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None) -> None:
+    def __init__(self, running_stage: RunningStage, transform: Optional[INPUT_TRANSFORM_TYPE] = None) -> None:
         super().__init__()
         self.running_stage = running_stage
         if transform:
-            self.transform = PreprocessTransform.from_transform(
-                transform, running_stage=running_stage, transforms_registry=self.transforms_registry
+            self.transform = InputTransform.from_transform(
+                transform, running_stage=running_stage, input_transforms_registry=self.input_transforms_registry
             )
 
     def pass_args_to_load_data(
@@ -110,7 +110,7 @@ def from_data(
         cls,
         *load_data_args,
         running_stage: Optional[RunningStage] = None,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         if not running_stage:
@@ -126,7 +126,7 @@ def from_data(
     def from_train_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
@@ -137,7 +137,7 @@ def from_val_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
@@ -148,7 +148,7 @@ def from_test_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(*load_data_args, running_stage=RunningStage.TESTING, transform=transform, **dataset_kwargs)
@@ -157,7 +157,7 @@ def from_predict_data(
         cls,
         *load_data_args,
-        transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **dataset_kwargs: Any,
     ) -> "BaseDataset":
         return cls.from_data(
             *load_data_args, running_stage=RunningStage.PREDICTING, transform=transform, **dataset_kwargs
         )
@@ -165,12 +165,14 @@
 
     @classmethod
-    def register_transform(cls, enum: Union[LightningEnum, str], fn: Union[Type[PreprocessTransform], partial]) -> None:
-        if cls.transforms_registry is None:
+    def register_input_transform(
+        cls, enum: Union[LightningEnum, str], fn: Union[Type[InputTransform], partial]
+    ) -> None:
+        if cls.input_transforms_registry is None:
             raise MisconfigurationException(
-                "The class attribute `transforms_registry` should be set as a class attribute. "
+                "The class attribute `input_transforms_registry` should be set as a class attribute. "
             )
-        cls.transforms_registry(fn=fn, name=enum)
+        cls.input_transforms_registry(fn=fn, name=enum)
 
     def resolve_functions(self):
         raise NotImplementedError
diff --git a/flash/core/data/preprocess_transform.py b/flash/core/data/input_transform.py
similarity index 85%
rename from flash/core/data/preprocess_transform.py
rename to flash/core/data/input_transform.py
index 4cc4f5b4ac..844ed44a0f 100644
--- a/flash/core/data/preprocess_transform.py
+++ b/flash/core/data/input_transform.py
@@ -27,12 +27,12 @@
 from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX
 from flash.core.registry import FlashRegistry
 
-PREPROCESS_TRANSFORM_TYPE = Optional[
-    Union["PreprocessTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
+INPUT_TRANSFORM_TYPE = Optional[
+    Union["InputTransform", Callable, Tuple[Union[LightningEnum, str], Dict[str, Any]], Union[LightningEnum, str]]
 ]
 
 
-class PreprocessTransformPlacement(LightningEnum):
+class InputTransformPlacement(LightningEnum):
 
     PER_SAMPLE_TRANSFORM = "per_sample_transform"
     PER_BATCH_TRANSFORM = "per_batch_transform"
@@ -52,8 +52,8 @@ def wrapper(self, *args, **kwargs) -> Any:
     return wrapper
 
 
-class PreprocessTransform(Properties):
-    def configure_transforms(self, *args, **kwargs) -> Dict[PreprocessTransformPlacement, Callable]:
+class InputTransform(Properties):
+    def configure_transforms(self, *args, **kwargs) -> Dict[InputTransformPlacement, Callable]:
         """The default transforms to use.
 
         Will be overridden by transforms passed to the ``__init__``.
@@ -184,21 +184,21 @@ def per_batch_transform_on_device(self, batch: Any) -> Any:
     @classmethod
     def from_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
+        transform: INPUT_TRANSFORM_TYPE,
         running_stage: RunningStage,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
 
-        if isinstance(transform, PreprocessTransform):
+        if isinstance(transform, InputTransform):
             transform.running_stage = running_stage
             return transform
 
         if isinstance(transform, Callable):
-            return cls(running_stage, {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform})
+            return cls(running_stage, {InputTransformPlacement.PER_SAMPLE_TRANSFORM: transform})
 
         if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)):
-            enum, transform_kwargs = cls._sanitize_registry_transform(transform, transforms_registry)
-            transform_cls = transforms_registry.get(enum)
+            enum, transform_kwargs = cls._sanitize_registry_transform(transform, input_transforms_registry)
+            transform_cls = input_transforms_registry.get(enum)
             return transform_cls(running_stage, transform=None, **transform_kwargs)
 
         if not transform:
             return transform
@@ -209,41 +211,47 @@ def from_transform(
     @classmethod
     def from_train_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.TRAINING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.TRAINING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     @classmethod
     def from_val_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.VALIDATING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.VALIDATING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     @classmethod
     def from_test_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.TESTING, transforms_registry=transforms_registry
+            transform=transform, running_stage=RunningStage.TESTING, input_transforms_registry=input_transforms_registry
         )
 
     @classmethod
     def from_predict_transform(
         cls,
-        transform: PREPROCESS_TRANSFORM_TYPE,
-        transforms_registry: Optional[FlashRegistry] = None,
-    ) -> Optional["PreprocessTransform"]:
+        transform: INPUT_TRANSFORM_TYPE,
+        input_transforms_registry: Optional[FlashRegistry] = None,
+    ) -> Optional["InputTransform"]:
         return cls.from_transform(
-            transform=transform, running_stage=RunningStage.PREDICTING, transforms_registry=transforms_registry
+            transform=transform,
+            running_stage=RunningStage.PREDICTING,
+            input_transforms_registry=input_transforms_registry,
         )
 
     def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str, Callable]]:
@@ -251,7 +257,7 @@ def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str,
 
         resolved_function = getattr(
             self,
-            DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, PreprocessTransform),
+            DataPipeline._resolve_function_hierarchy("configure_transforms", self, running_stage, InputTransform),
         )
         params = inspect.signature(resolved_function).parameters
         transforms_out: Optional[Dict[str, Callable]] = resolved_function(
@@ -259,10 +265,10 @@ def _resolve_transforms(self, running_stage: RunningStage) -> Optional[Dict[str,
         )
         transforms_out = transforms_out or {}
 
-        for placement in PreprocessTransformPlacement:
+        for placement in InputTransformPlacement:
             transform_name = f"configure_{placement.value}"
             resolved_function = getattr(
-                self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, PreprocessTransform)
+                self, DataPipeline._resolve_function_hierarchy(transform_name, self, running_stage, InputTransform)
             )
             params = inspect.signature(resolved_function).parameters
             transforms: Optional[Dict[str, Callable]] = resolved_function(
@@ -278,7 +284,7 @@ def _check_transforms(
         if transform is None:
             return transform
 
-        keys_diff = set(transform.keys()).difference([v for v in PreprocessTransformPlacement])
+        keys_diff = set(transform.keys()).difference([v for v in InputTransformPlacement])
 
         if len(keys_diff) > 0:
             raise MisconfigurationException(
@@ -315,11 +321,11 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable:
 
     @classmethod
     def _sanitize_registry_transform(
-        cls, transform: Tuple[Union[LightningEnum, str], Any], transforms_registry: Optional[FlashRegistry]
+        cls, transform: Tuple[Union[LightningEnum, str], Any], input_transforms_registry: Optional[FlashRegistry]
     ) -> Tuple[Union[LightningEnum, str], Dict]:
         msg = "The transform should be provided as a tuple with the following types (LightningEnum, Dict[str, Any]) "
         msg += "when requesting transform from the registry."
-        if not transforms_registry:
+        if not input_transforms_registry:
             raise MisconfigurationException("You requested a transform from the registry, but it is empty.")
         if isinstance(transform, tuple) and len(transform) > 2:
             raise MisconfigurationException(msg)
@@ -337,7 +343,7 @@
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(running_stage={self.running_stage}, transform={self.transform})"
 
-    def __getitem__(self, placement: PreprocessTransformPlacement) -> Callable:
+    def __getitem__(self, placement: InputTransformPlacement) -> Callable:
         return self.transform[placement]
 
     def _make_collates(self, on_device: bool, collate: Callable) -> Tuple[Callable, Callable]:
@@ -359,18 +365,18 @@ def _create_collate_preprocessors(self) -> Tuple[Any]:
         prefix: str = _STAGES_PREFIX[self.running_stage]
 
         func_names: Dict[str, str] = {
-            k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, PreprocessTransform)
-            for k in [v.value for v in PreprocessTransformPlacement]
+            k: DataPipeline._resolve_function_hierarchy(k, self, self.running_stage, InputTransform)
+            for k in [v.value for v in InputTransformPlacement]
         }
 
         collate_fn: Callable = getattr(self, func_names["collate"])
 
         per_batch_transform_overriden: bool = DataPipeline._is_overriden_recursive(
-            "per_batch_transform", self, PreprocessTransform, prefix=prefix
+            "per_batch_transform", self, InputTransform, prefix=prefix
         )
 
         per_sample_transform_on_device_overriden: bool = DataPipeline._is_overriden_recursive(
-            "per_sample_transform_on_device", self, PreprocessTransform, prefix=prefix
+            "per_sample_transform_on_device", self, InputTransform, prefix=prefix
         )
 
         is_per_overriden = per_batch_transform_overriden and per_sample_transform_on_device_overriden
diff --git a/flash/core/data/new_data_module.py b/flash/core/data/new_data_module.py
index c3d05a4089..7340ccc49c 100644
--- a/flash/core/data/new_data_module.py
+++ b/flash/core/data/new_data_module.py
@@ -29,7 +29,7 @@
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_pipeline import DefaultPreprocess, Postprocess
 from flash.core.data.datasets import BaseDataset
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE, PreprocessTransform
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.imports import _FIFTYONE_AVAILABLE
 
@@ -267,10 +267,10 @@ def create_flash_datasets(
         val_data: Optional[Any] = None,
         test_data: Optional[Any] = None,
         predict_data: Optional[Any] = None,
-        train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        train_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        val_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        test_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        predict_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **flash_dataset_kwargs,
     ) -> Tuple[Optional[BaseDataset]]:
         cls._verify_flash_dataset_enum(enum)
@@ -311,7 +311,7 @@ def _create_flash_dataset(
         flash_dataset_cls,
         *load_data_args,
         running_stage: RunningStage,
-        transform: Optional[PreprocessTransform],
+        transform: Optional[InputTransform],
         **kwargs,
     ) -> Optional[BaseDataset]:
         if load_data_args[0] is not None:
diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py
index 0ccbad904e..0c3edf23d3 100644
--- a/flash_examples/flash_components/custom_data_loading.py
+++ b/flash_examples/flash_components/custom_data_loading.py
@@ -19,13 +19,12 @@
 import torchvision.transforms as T
 from PIL import Image
 from pytorch_lightning import seed_everything
-from pytorch_lightning.utilities.enums import LightningEnum
 from torch.utils.data._utils.collate import default_collate
 
-from flash import _PACKAGE_ROOT, FlashDataset, PreprocessTransform
+from flash import _PACKAGE_ROOT, FlashDataset, InputTransform
 from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.input_transform import INPUT_TRANSFORM_TYPE
 from flash.core.data.new_data_module import DataModule
-from flash.core.data.preprocess_transform import PREPROCESS_TRANSFORM_TYPE
 from flash.core.data.transforms import ApplyToKeys
 from flash.core.data.utils import download_data
 
@@ -44,24 +43,7 @@
 
 
 #############################################################################################
-# Step 1 / 5: Create an enum to describe your new loading mechanism                         #
-#############################################################################################
-
-
-class DataTransform(LightningEnum):
-
-    BASE = "base"
-    RANDOM_ROTATION = "random_rotation"
-    RANDOM_90_DEG_ROTATION = "random_90_def_rotation"
-
-
-class CustomDataFormat(LightningEnum):
-
-    MULTIPLE_FOLDERS = "multiple_folders"
-
-
-#############################################################################################
-# Step 2 / 5: Implement a FlashDataset                                                      #
+# Step 1 / 2: Implement a FlashDataset                                                      #
 #                                                                                           #
 # A `FlashDataset` is a state-aware (cf. training, validating, testing and predicting)      #
 # dataset.                                                                                  #
@@ -108,15 +90,15 @@ def predict_load_data(self, predict_folder: str) -> List[Dict[DefaultDataKeys, A
 
 
 #############################################################################################
-# Step 3 / 5: [optional] Implement a PreprocessTransform                                    #
+# Step 2 / 2: [optional] Implement an InputTransform                                        #
 #                                                                                           #
-# A `PreprocessTransform` is a state-aware (c.f training, validating, testing and predicting) #
+# An `InputTransform` is a state-aware (cf. training, validating, testing and predicting)   #
 # transform. You would have to implement a `configure_transforms` hook with your transform  #
 #                                                                                           #
 #############################################################################################
 
 
-class ImageBaseTransform(PreprocessTransform):
+class BaseImageInputTransform(InputTransform):
     def configure_per_sample_transform(self, image_size: int = 224) -> Any:
         per_sample_transform = T.Compose([T.Resize((image_size, image_size)), T.ToTensor()])
         return ApplyToKeys(DefaultDataKeys.INPUT, per_sample_transform)
@@ -125,7 +107,7 @@ def configure_collate(self) -> Any:
         return default_collate
 
 
-class ImageRandomRotationTransform(ImageBaseTransform):
+class ImageRandomRotationInputTransform(BaseImageInputTransform):
     def configure_per_sample_transform(self, image_size: int = 224, rotation: float = 0) -> Any:
         transforms = [T.Resize((image_size, image_size)), T.ToTensor()]
         if self.training:
@@ -135,15 +117,15 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 
 # Register your transform within the Flash Dataset registry
 # Note: Registries can be shared by multiple datasets.
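 # When a transform is requested by name below, the registry resolves that name to the
 # registered InputTransform class; with the `(name, kwargs)` tuple form, the kwargs are
 # forwarded to the resolved class, e.g. transform=("random_rotation", {"rotation": 45}).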
-MultipleFoldersImageDataset.register_transform(DataTransform.BASE, ImageBaseTransform)
-MultipleFoldersImageDataset.register_transform(DataTransform.RANDOM_ROTATION, ImageRandomRotationTransform)
-MultipleFoldersImageDataset.register_transform(
-    DataTransform.RANDOM_90_DEG_ROTATION, partial(ImageRandomRotationTransform, rotation=90)
+MultipleFoldersImageDataset.register_input_transform("base", BaseImageInputTransform)
+MultipleFoldersImageDataset.register_input_transform("random_rotation", ImageRandomRotationInputTransform)
+MultipleFoldersImageDataset.register_input_transform(
+    "random_90_deg_rotation", partial(ImageRandomRotationInputTransform, rotation=90)
 )
 
 train_dataset = MultipleFoldersImageDataset.from_train_data(
     TRAIN_FOLDERS,
-    transform=(DataTransform.RANDOM_ROTATION, {"rotation": 45}),
+    transform=("random_rotation", {"rotation": 45}),
 )
 print(train_dataset.transform)
@@ -151,18 +133,18 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 # ImageRandomRotationInputTransform(
 #     running_stage=train,
 #     transform={
-#         PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(
+#         InputTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(
 #             keys="input",
 #             transform=Compose(
 #                 ToTensor()
 #                 RandomRotation(degrees=[-45.0, 45.0], interpolation=nearest, expand=False, fill=0)),
-#         PreprocessTransformPlacement.COLLATE: default_collate,
+#         InputTransformPlacement.COLLATE: default_collate,
 #     },
 # )
 
 train_dataset = MultipleFoldersImageDataset.from_train_data(
     TRAIN_FOLDERS,
-    transform=DataTransform.RANDOM_90_DEG_ROTATION,
+    transform="random_90_deg_rotation",
 )
 print(train_dataset.transform)
@@ -170,23 +152,23 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 # Out:
 # ImageRandomRotationInputTransform(
 #     running_stage=train,
 #     transform={
-#         PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(
+#         InputTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(
 #             keys="input",
 #             transform=Compose(
 #                 ToTensor()
 #                 RandomRotation(degrees=[-90.0, 90.0], interpolation=nearest, expand=False, fill=0)),
-#         PreprocessTransformPlacement.COLLATE: default_collate,
+#         InputTransformPlacement.COLLATE: default_collate,
 #     },
 # )
 
-val_dataset = MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform=DataTransform.BASE)
+val_dataset = MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform="base")
 print(val_dataset.transform)
 # Out:
 # BaseImageInputTransform(
 #     running_stage=validate,
 #     transform={
-#         PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(keys="input", transform=ToTensor()),
-#         PreprocessTransformPlacement.COLLATE: default_collate,
+#         InputTransformPlacement.PER_SAMPLE_TRANSFORM: ApplyToKeys(keys="input", transform=ToTensor()),
+#         InputTransformPlacement.COLLATE: default_collate,
 #     },
 # )
 
@@ -208,9 +190,9 @@ def configure_per_sample_transform(self, image_size: int = 224, rotation: float
 
 
 datamodule = DataModule(
-    train_dataset=MultipleFoldersImageDataset.from_train_data(TRAIN_FOLDERS, transform=DataTransform.RANDOM_ROTATION),
-    val_dataset=MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform=DataTransform.BASE),
-    predict_dataset=MultipleFoldersImageDataset.from_predict_data(PREDICT_FOLDER, transform=DataTransform.BASE),
+    train_dataset=MultipleFoldersImageDataset.from_train_data(TRAIN_FOLDERS, transform="random_rotation"),
+    val_dataset=MultipleFoldersImageDataset.from_val_data(VAL_FOLDERS, transform="base"),
+    predict_dataset=MultipleFoldersImageDataset.from_predict_data(PREDICT_FOLDER, transform="base"),
     batch_size=2,
 )
@@ -274,16 +256,16 @@ def from_multiple_folders(
         val_folders: Optional[List[str]] = None,
         test_folders: Optional[List[str]] = None,
         predict_folder: Optional[str] = None,
-        train_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        val_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        test_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
-        predict_transform: Optional[PREPROCESS_TRANSFORM_TYPE] = None,
+        train_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        val_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        test_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
+        predict_transform: Optional[INPUT_TRANSFORM_TYPE] = None,
         **data_module_kwargs: Any,
     ) -> "ImageClassificationDataModule":
 
         return cls(
             *cls.create_flash_datasets(
-                CustomDataFormat.MULTIPLE_FOLDERS,
+                "multiple_folders",
                 train_folders,
                 val_folders,
                 test_folders,
@@ -297,7 +279,7 @@ def from_multiple_folders(
     )
 
 
-ImageClassificationDataModule.register_flash_dataset(CustomDataFormat.MULTIPLE_FOLDERS, MultipleFoldersImageDataset)
+ImageClassificationDataModule.register_flash_dataset("multiple_folders", MultipleFoldersImageDataset)
 
 
 # Create the datamodule with your new constructor. This is purely equivalent to the previous datamodule creation.
@@ -305,9 +287,9 @@ def from_multiple_folders(
 datamodule = ImageClassificationDataModule.from_multiple_folders(
     train_folders=TRAIN_FOLDERS,
     val_folders=VAL_FOLDERS,
     predict_folder=PREDICT_FOLDER,
-    train_transform=DataTransform.RANDOM_ROTATION,
-    val_transform=DataTransform.BASE,
-    predict_transform=DataTransform.BASE,
+    train_transform="random_rotation",
+    val_transform="base",
+    predict_transform="base",
     batch_size=2,
 )
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index 7124675f30..6853cd4a63 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -525,7 +525,7 @@ def __len__(self) -> int:
         return 5
 
 
-class TestPreprocessTransformationsDataSource(DataSource):
+class TestInputTransformationsDataSource(DataSource):
     def __init__(self):
         super().__init__()
 
@@ -583,9 +583,9 @@ def predict_load_data(self, sample) -> LamdaDummyDataset:
         return LamdaDummyDataset(self.fn_predict_load_data)
 
 
-class TestPreprocessTransformations(DefaultPreprocess):
+class TestInputTransformations(DefaultPreprocess):
     def __init__(self):
-        super().__init__(data_sources={"default": TestPreprocessTransformationsDataSource()})
+        super().__init__(data_sources={"default": TestInputTransformationsDataSource()})
 
         self.train_pre_tensor_transform_called = False
         self.train_collate_called = False
@@ -651,7 +651,7 @@ def test_post_tensor_transform(self, sample: Tensor) -> Tensor:
         return sample
 
 
-class TestPreprocessTransformations2(TestPreprocessTransformations):
+class TestInputTransformations2(TestInputTransformations):
     def val_to_tensor_transform(self, sample: Any) -> Tensor:
         self.val_to_tensor_transform_called = True
         return {"a": tensor(sample["a"]), "b": tensor(sample["b"])}
@@ -684,7 +684,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
 
 def test_datapipeline_transformations(tmpdir):
     datamodule = DataModule.from_data_source(
-        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations()
+        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestInputTransformations()
     )
 
     assert datamodule.train_dataloader().dataset[0] == (0, 1, 2, 3)
@@ -697,7 +697,7 @@ def test_datapipeline_transformations(tmpdir):
     batch = next(iter(datamodule.val_dataloader()))
 
     datamodule = DataModule.from_data_source(
-        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestPreprocessTransformations2()
+        "default", 1, 1, 1, 1, batch_size=2, num_workers=0, preprocess=TestInputTransformations2()
     )
     batch = next(iter(datamodule.val_dataloader()))
     assert torch.equal(batch["a"], tensor([0, 1]))
diff --git a/tests/core/data/test_preprocess_transform.py b/tests/core/data/test_input_transform.py
similarity index 53%
rename from tests/core/data/test_preprocess_transform.py
rename to tests/core/data/test_input_transform.py
index 8bd5b3f9c6..c11be38df6 100644
--- a/tests/core/data/test_preprocess_transform.py
+++ b/tests/core/data/test_input_transform.py
@@ -20,56 +20,54 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data.dataloader import default_collate
 
-from flash.core.data.preprocess_transform import PreprocessTransform, PreprocessTransformPlacement
+from flash.core.data.input_transform import InputTransform, InputTransformPlacement
 from flash.core.registry import FlashRegistry
 
 
-def test_preprocess_transform():
+def test_input_transform():
 
-    transform = PreprocessTransform(running_stage=RunningStage.TRAINING)
+    transform = InputTransform(running_stage=RunningStage.TRAINING)
 
-    assert (
-        "PreprocessTransform(running_stage=train, transform={"
-        in str(transform)
+    assert "InputTransform(running_stage=train, transform={" in str(
+        transform
     )
 
     def fn(x):
         return x + 1
 
-    transform = PreprocessTransform.from_train_transform(transform=fn)
-    assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
+    transform = InputTransform.from_train_transform(transform=fn)
+    assert transform.transform == {InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
 
-    transform = PreprocessTransform.from_val_transform(transform=fn)
-    assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
+    transform = InputTransform.from_val_transform(transform=fn)
+    assert transform.transform == {InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
 
-    transform = PreprocessTransform.from_test_transform(transform=fn)
-    assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
+    transform = InputTransform.from_test_transform(transform=fn)
+    assert transform.transform == {InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
 
-    transform = PreprocessTransform.from_predict_transform(transform=fn)
-    assert transform.transform == {PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
+    transform = InputTransform.from_predict_transform(transform=fn)
+    assert transform.transform == {InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn}
 
-    class MyPreprocessTransform(PreprocessTransform):
+    class MyInputTransform(InputTransform):
         def configure_transforms(self) -> Optional[Dict[str, Callable]]:
             return None
 
-    transform = MyPreprocessTransform(running_stage=RunningStage.TRAINING)
+    transform = MyInputTransform(running_stage=RunningStage.TRAINING)
     assert not transform._current_fn
-    assert (
-        "PreprocessTransform(running_stage=train, transform={"
-        in str(transform)
+    assert "InputTransform(running_stage=train, transform={" in str(
+        transform
     )
 
-    class MyPreprocessTransform(PreprocessTransform):
+    class MyInputTransform(InputTransform):
         def fn(self, x):
             return x + 1
 
         def configure_per_sample_transform(self) -> Optional[Dict[str, Callable]]:
             return self.fn if self.training else fn
 
-    transform = MyPreprocessTransform(running_stage=RunningStage.TRAINING)
+    transform = MyInputTransform(running_stage=RunningStage.TRAINING)
     assert transform.transform == {
-        PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform.fn,
-        PreprocessTransformPlacement.COLLATE: default_collate,
+        InputTransformPlacement.PER_SAMPLE_TRANSFORM: transform.fn,
+        InputTransformPlacement.COLLATE: default_collate,
     }
 
     transform._current_fn = "per_sample_transform"
@@ -90,34 +88,34 @@ def configure_per_sample_transform(self) -> Optional[Dict[str, Callable]]:
     assert transform.current_transform == transform._identity
     assert transform.per_batch_transform(2) == 2
 
-    transform = MyPreprocessTransform(running_stage=RunningStage.TESTING)
+    transform = MyInputTransform(running_stage=RunningStage.TESTING)
     assert transform.transform == {
-        PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn,
-        PreprocessTransformPlacement.COLLATE: default_collate,
+        InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn,
+        InputTransformPlacement.COLLATE: default_collate,
     }
 
     assert transform.transforms == {
         "transform": {
-            PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: fn,
-            PreprocessTransformPlacement.COLLATE: default_collate,
+            InputTransformPlacement.PER_SAMPLE_TRANSFORM: fn,
+            InputTransformPlacement.COLLATE: default_collate,
         }
     }
 
-    transforms_registry = FlashRegistry("transforms")
-    transforms_registry(fn=MyPreprocessTransform, name="something")
+    input_transforms_registry = FlashRegistry("transforms")
+    input_transforms_registry(fn=MyInputTransform, name="something")
 
-    transform = PreprocessTransform.from_transform(
-        running_stage=RunningStage.TRAINING, transform="something", transforms_registry=transforms_registry
+    transform = InputTransform.from_transform(
+        running_stage=RunningStage.TRAINING, transform="something", input_transforms_registry=input_transforms_registry
     )
     transform = transform.from_transform(
-        running_stage=RunningStage.TRAINING, transform=transform, transforms_registry=transforms_registry
+        running_stage=RunningStage.TRAINING, transform=transform, input_transforms_registry=input_transforms_registry
     )
 
-    assert isinstance(transform, MyPreprocessTransform)
+    assert isinstance(transform, MyInputTransform)
     assert transform.transform == {
-        PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM: transform.fn,
-        PreprocessTransformPlacement.COLLATE: default_collate,
+        InputTransformPlacement.PER_SAMPLE_TRANSFORM: transform.fn,
+        InputTransformPlacement.COLLATE: default_collate,
     }
 
     collate_fn = transform.dataloader_collate_fn
@@ -132,30 +130,30 @@ def configure_per_sample_transform(self) -> Optional[Dict[str, Callable]]:
 
     assert transform._collate_in_worker_from_transform
 
-    class MyPreprocessTransform(PreprocessTransform):
+    class MyInputTransform(InputTransform):
         def configure_transforms(self) -> Optional[Dict[str, Callable]]:
             return {
-                PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
+                InputTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
             }
 
         def configure_per_batch_transform(self):
             return fn
 
     with pytest.raises(MisconfigurationException, match="`per_batch_transform` and `per_sample_transform_on_device`"):
-        transform = MyPreprocessTransform(running_stage=RunningStage.TESTING)
+        transform = MyInputTransform(running_stage=RunningStage.TESTING)
 
     with pytest.raises(MisconfigurationException, match="The format for the transform isn't correct"):
-        PreprocessTransform.from_transform(1, running_stage=RunningStage.TRAINING)
+        InputTransform.from_transform(1, running_stage=RunningStage.TRAINING)
 
-    class MyPreprocessTransform(PreprocessTransform):
+    class MyInputTransform(InputTransform):
         def configure_transforms(self) -> Optional[Dict[str, Callable]]:
             return {
-                PreprocessTransformPlacement.COLLATE: fn,
-                PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
-                PreprocessTransformPlacement.PER_BATCH_TRANSFORM_ON_DEVICE: fn,
+                InputTransformPlacement.COLLATE: fn,
+                InputTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
+                InputTransformPlacement.PER_BATCH_TRANSFORM_ON_DEVICE: fn,
             }
 
-    transform = MyPreprocessTransform(running_stage=RunningStage.TESTING)
+    transform = MyInputTransform(running_stage=RunningStage.TESTING)
     assert not transform._collate_in_worker_from_transform
 
     def compose(x, funcs):
         for f in funcs:
             x = f(x)
         return x
 
-    transform = PreprocessTransform.from_transform(
+    transform = InputTransform.from_transform(
         transform=partial(compose, funcs=[fn, fn]), running_stage=RunningStage.TRAINING
     )
-    assert transform[PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM](1) == 3
+    assert transform[InputTransformPlacement.PER_SAMPLE_TRANSFORM](1) == 3
 
 
 def test_transform_with_registry():
     def fn():
         pass
 
-    class MyPreprocessTransform(PreprocessTransform):
+    class MyInputTransform(InputTransform):
         def configure_transforms(self, name: str = "lightning") -> Optional[Dict[str, Callable]]:
             self.name = name
             return {
-                PreprocessTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
+                InputTransformPlacement.PER_SAMPLE_TRANSFORM_ON_DEVICE: fn,
             }
 
     registry = FlashRegistry("transforms")
-    registry(name="custom", fn=MyPreprocessTransform)
+    registry(name="custom", fn=MyInputTransform)
 
-    transform = PreprocessTransform.from_train_transform(transform="custom", transforms_registry=registry)
-    assert isinstance(transform, MyPreprocessTransform)
+    transform = InputTransform.from_train_transform(transform="custom", input_transforms_registry=registry)
+    assert isinstance(transform, MyInputTransform)
     assert transform.name == "lightning"
 
-    transform = PreprocessTransform.from_train_transform(
-        transform=("custom", {"name": "flash"}), transforms_registry=registry
+    transform = InputTransform.from_train_transform(
+        transform=("custom", {"name": "flash"}), input_transforms_registry=registry
     )
-    assert isinstance(transform, MyPreprocessTransform)
+    assert isinstance(transform, MyInputTransform)
     assert transform.name == "flash"
 
-    transform = PreprocessTransform.from_train_transform(transform=None, transforms_registry=registry)
+    transform = InputTransform.from_train_transform(transform=None, input_transforms_registry=registry)
     assert transform is None
 
-    transform = PreprocessTransform.from_train_transform(transform=None, transforms_registry=registry)
+    transform = InputTransform.from_train_transform(transform=None, input_transforms_registry=registry)
     assert transform is None
 
     with pytest.raises(
         MisconfigurationException, match="The transform should be provided as a tuple with the following types"
     ):
-        transform = PreprocessTransform.from_train_transform(transform=("custom", None), transforms_registry=registry)
+        transform = InputTransform.from_train_transform(transform=("custom", None), input_transforms_registry=registry)
 
     with pytest.raises(MisconfigurationException, match="The format for the transform isn't correct"):
-        transform = PreprocessTransform.from_train_transform(transform=1, transforms_registry=registry)
+        transform = InputTransform.from_train_transform(transform=1, input_transforms_registry=registry)
diff --git a/tests/core/data/test_new_data_module.py b/tests/core/data/test_new_data_module.py
index ad7d0e24fb..649d8990a2 100644
--- a/tests/core/data/test_new_data_module.py
+++ b/tests/core/data/test_new_data_module.py
@@ -20,8 +20,8 @@
 from flash import Task, Trainer
 from flash.core.data.datasets import FlashDataset
+from flash.core.data.input_transform import InputTransform
 from flash.core.data.new_data_module import DataModule
-from flash.core.data.preprocess_transform import PreprocessTransform
 
 
 def test_data_module():
@@ -42,7 +42,7 @@ def predict_fn(data):
     class TestDataset(FlashDataset):
         pass
 
-    class TestTransform(PreprocessTransform):
+    class TestTransform(InputTransform):
        def configure_collate(self, *args, **kwargs) -> Callable:
            return default_collate
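
A minimal usage sketch of the renamed API, pieced together from the hunks above. The dataset, transform, registry key, and sample data below are hypothetical; only `FlashDataset`, `InputTransform`, `register_input_transform`, `from_train_data`, and the `configure_*` hooks come from this patch:

    from typing import Any, List

    from torch.utils.data._utils.collate import default_collate

    from flash import FlashDataset, InputTransform


    class RangeDataset(FlashDataset):
        # Hypothetical dataset: `load_data` returns the raw samples for the current stage.
        def load_data(self, n: int) -> List[int]:
            return list(range(n))


    class ScaleInputTransform(InputTransform):
        # Hypothetical transform: scale each sample, then collate the batch.
        def configure_per_sample_transform(self, factor: int = 2) -> Any:
            return lambda x: x * factor

        def configure_collate(self) -> Any:
            return default_collate


    # Register under a string name, then request it by name (with kwargs) at construction time.
    RangeDataset.register_input_transform("scale", ScaleInputTransform)
    train_dataset = RangeDataset.from_train_data(10, transform=("scale", {"factor": 3}))
    print(train_dataset.transform)  # ScaleInputTransform(running_stage=train, transform={...})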