diff --git a/CHANGELOG.md b/CHANGELOG.md index 739b97e5f1..9fd2e3c5c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948)) +- Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010)) + + ## [0.5.2] - 2021-11-05 ### Added diff --git a/README.md b/README.md index 5056694562..4e8294cbe7 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,7 @@ def mixup(batch, alpha=1.0): train_transform = { # applied only on images as ApplyToKeys is used with `input` - "post_tensor_transform": ApplyToKeys( + "per_sample_transform": ApplyToKeys( "input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))), # applied to the entire dictionary as `ApplyToKeys` isn't used. diff --git a/docs/source/general/data.rst b/docs/source/general/data.rst index 69d2207710..5f5f5149d6 100644 --- a/docs/source/general/data.rst +++ b/docs/source/general/data.rst @@ -31,7 +31,7 @@ Here are common terms you need to be familiar with: - The :class:`~flash.core.data.io.input.Input` provides :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample` hooks for creating data sets from metadata (such as folder names). * - :class:`~flash.core.data.io.input_transform.InputTransform` - The :class:`~flash.core.data.io.input_transform.InputTransform` provides a simple hook-based API to encapsulate your pre-processing logic. - These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). + These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.per_sample_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device). The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed. The :class:`~flash.core.data.io.input_transform.InputTransform` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
* - :class:`~flash.core.data.io.output_transform.OutputTransform` @@ -112,7 +112,7 @@ Here's an example: from flash.core.data.transforms import ApplyToKeys from flash.image import ImageClassificationData, ImageClassifier - transform = {"to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform)} + transform = {"per_sample_transform": ApplyToKeys("input", my_per_sample_transform)} datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", @@ -132,8 +132,8 @@ Alternatively, the user may directly override the hooks for their needs like thi class CustomImageClassificationInputTransform(ImageClassificationInputTransform): - def to_tensor_transform(sample: Dict[str, Any]) -> Dict[str, Any]: - sample["input"] = my_to_tensor_transform(sample["input"]) + def per_sample_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]: + sample["input"] = my_per_sample_transform(sample["input"]) return sample @@ -267,7 +267,7 @@ Next, implement your custom ``ImageClassificationInputTransform`` with some defa return cls(**state_dict) def default_transforms(self) -> Dict[str, Callable]: - return {"to_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)} + return {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)} 4. The DataModule _________________ @@ -325,9 +325,7 @@ ______________ .. note:: - The :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`, - :meth:`~flash.core.data.io.input_transform.InputTransform.to_tensor_transform`, - :meth:`~flash.core.data.io.input_transform.InputTransform.post_tensor_transform`, + The :meth:`~flash.core.data.io.input_transform.InputTransform.per_sample_transform`, :meth:`~flash.core.data.io.input_transform.InputTransform.collate`, :meth:`~flash.core.data.io.input_transform.InputTransform.per_batch_transform` are injected as the :paramref:`torch.utils.data.DataLoader.collate_fn` function of the DataLoader. @@ -342,9 +340,7 @@ Example:: -# This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformSequential` +# This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformProcessor` for sample in samples: - sample = pre_tensor_transform(sample) - sample = to_tensor_transform(sample) - sample = post_tensor_transform(sample) + sample = per_sample_transform(sample) samples = type(samples)(samples) diff --git a/docs/source/integrations/icevision.rst b/docs/source/integrations/icevision.rst index ff21565a4e..bfb71356b2 100644 --- a/docs/source/integrations/icevision.rst +++ b/docs/source/integrations/icevision.rst @@ -35,7 +35,7 @@ Here's an example: from flash.image import ObjectDetectionData train_transform = { - "pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]), + "per_sample_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]), } datamodule = ObjectDetectionData.from_coco( diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index a32f74d54b..a9f364226c 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -89,7 +89,6 @@ Flash automatically applies some default image transformations and augmentations -The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline. +The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 5 hooks for different stages in the data loading pipeline.
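To make the rename concrete for reviewers, here is a minimal sketch of a user-facing transform dictionary after this change (illustrative only: ``my_per_sample_transform`` is a hypothetical callable, while ``ApplyToKeys`` and ``ImageClassificationData.from_folders`` are the APIs used elsewhere in this diff)::

    from flash.core.data.transforms import ApplyToKeys
    from flash.image import ImageClassificationData

    # Hypothetical per-sample callable: the work previously split across
    # `pre_tensor_transform`, `to_tensor_transform`, and `post_tensor_transform`
    # (augment, convert to tensor, post-process) now lives in one hook.
    def my_per_sample_transform(image):
        return image

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        train_transform={"per_sample_transform": ApplyToKeys("input", my_per_sample_transform)},
    )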
To apply image augmentations you can directly import the ``default_transforms`` from ``flash.image.classification.transforms`` and then merge your custom image transformations with them using the :func:`~flash.core.data.transforms.merge_transforms` helper function. Here's an example where we load the default transforms and merge with custom `torchvision` transformations. -We use the `post_tensor_transform` hook to apply the transformations after the image has been converted to a `torch.Tensor`. .. testsetup:: transformations @@ -108,12 +107,12 @@ We use the `post_tensor_transform` hook to apply the transformations after the i from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.transforms import default_transforms - post_tensor_transform = ApplyToKeys( + per_sample_transform = ApplyToKeys( DataKeys.INPUT, T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]), ) - new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform}) + new_transforms = merge_transforms(default_transforms((64, 64)), {"per_sample_transform": per_sample_transform}) datamodule = ImageClassificationData.from_folders( train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms diff --git a/docs/source/reference/object_detection.rst b/docs/source/reference/object_detection.rst index 1203b582c7..7a25a7742a 100644 --- a/docs/source/reference/object_detection.rst +++ b/docs/source/reference/object_detection.rst @@ -93,7 +93,7 @@ For object-detection tasks, you can leverage the transformations from `Albumenta from flash.image import ObjectDetectionData train_transform = { - "pre_tensor_transform": transforms.IceVisionTransformAdapter( + "per_sample_transform": transforms.IceVisionTransformAdapter( [*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()] ) } diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst index 3d12658781..543898b0c4 100644 --- a/docs/source/template/data.rst +++ b/docs/source/template/data.rst @@ -93,7 +93,7 @@ InputTransform The :class:`~flash.core.data.io.input_transform.InputTransform` object contains all the data transforms. Internally we inject the :class:`~flash.core.data.io.input_transform.InputTransform` transforms at several points along the pipeline. -Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.io.input_transform.InputTransform` is as simple as implementing the ``default_transforms`` method. +Defining the standard transforms (typically at least a ``per_sample_transform`` should be defined) for your :class:`~flash.core.data.io.input_transform.InputTransform` is as simple as implementing the ``default_transforms`` method. The :class:`~flash.core.data.io.input_transform.InputTransform` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``. These arguments can be provided by the user (when creating the :class:`~flash.core.data.data_module.DataModule`) to override the default transforms. Any additional arguments are up to you. @@ -115,7 +115,7 @@ Here's our ``TemplateInputTransform.__init__``: :dedent: 4 :pyobject: TemplateInputTransform.__init__ -For our ``TemplateInputTransform``, we'll just configure a default ``to_tensor_transform``. 
+For our ``TemplateInputTransform``, we'll just configure a default ``per_sample_transform``. Let's first define the transform as a ``staticmethod``: .. literalinclude:: ../../../flash/template/classification/data.py diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index a61a41012c..64882038c9 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -22,7 +22,6 @@ from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: - import torchvision from torchvision import transforms as T if _TORCHAUDIO_AVAILABLE: @@ -33,11 +32,10 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable] """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, and collate the batch.""" return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + "per_sample_transform": nn.Sequential( + ApplyToKeys(DataKeys.INPUT, T.Compose([T.ToTensor(), T.Resize(spectrogram_size)])), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), - "post_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.Resize(spectrogram_size)), "collate": default_collate, } @@ -55,5 +53,5 @@ def train_default_transforms( augs.append(ApplyToKeys(DataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))) if len(augs) > 0: - return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)}) + return merge_transforms(default_transforms(spectrogram_size), {"per_sample_transform": nn.Sequential(*augs)}) return default_transforms(spectrogram_size) diff --git a/flash/core/data/base_viz.py b/flash/core/data/base_viz.py index 513714a8db..211ef15493 100644 --- a/flash/core/data/base_viz.py +++ b/flash/core/data/base_viz.py @@ -37,13 +37,7 @@ class CustomBaseVisualization(BaseVisualization): def show_load_sample(self, samples: List[Any], running_stage): # plot samples - def show_pre_tensor_transform(self, samples: List[Any], running_stage): - # plot samples - - def show_to_tensor_transform(self, samples: List[Any], running_stage): - # plot samples - - def show_post_tensor_transform(self, samples: List[Any], running_stage): + def show_per_sample_transform(self, samples: List[Any], running_stage): # plot samples def show_collate(self, batch: List[Any], running_stage): @@ -93,9 +87,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage): # out { 'load_sample': [...], - 'pre_tensor_transform': [...], - 'to_tensor_transform': [...], - 'post_tensor_transform': [...], + 'per_sample_transform': [...], 'collate': [...], 'per_batch_transform': [...], } @@ -125,14 +117,8 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li def show_load_sample(self, samples: List[Any], running_stage: RunningStage): """Override to visualize ``load_sample`` output data.""" - def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize ``pre_tensor_transform`` output data.""" - - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize ``to_tensor_transform`` output data.""" - - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - """Override to visualize ``post_tensor_transform`` output data.""" + def show_per_sample_transform(self, samples: 
List[Any], running_stage: RunningStage): + """Override to visualize ``per_sample_transform`` output data.""" def show_collate(self, batch: List[Any], running_stage: RunningStage) -> None: """Override to visualize ``collate`` output data.""" diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index 2f77f2b56d..f6d29f9ec4 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -30,32 +30,24 @@ def __init__( self, deserializer: "Deserializer", input_transform: "InputTransform", - pre_tensor_transform: Callable, - to_tensor_transform: Callable, + per_sample_transform: Callable, ): super().__init__() self.input_transform = input_transform self.callback = ControlFlow(self.input_transform.callbacks) self.deserializer = convert_to_modules(deserializer) - self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) - self.to_tensor_transform = convert_to_modules(to_tensor_transform) + self.per_sample_transform = convert_to_modules(per_sample_transform) self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, input_transform, reset=False) - self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", input_transform) - self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", input_transform) + self._per_sample_transform_context = CurrentFuncContext("per_sample_transform", input_transform) def forward(self, sample: str): - sample = self.deserializer(sample) with self._current_stage_context: - with self._pre_tensor_transform_context: - sample = self.pre_tensor_transform(sample) - self.callback.on_pre_tensor_transform(sample, RunningStage.PREDICTING) - - with self._to_tensor_transform_context: - sample = self.to_tensor_transform(sample) - self.callback.on_to_tensor_transform(sample, RunningStage.PREDICTING) + with self._per_sample_transform_context: + sample = self.per_sample_transform(sample) + self.callback.on_per_sample_transform(sample, RunningStage.PREDICTING) return sample diff --git a/flash/core/data/callback.py b/flash/core/data/callback.py index 64bd06ad39..9ffbad6de5 100644 --- a/flash/core/data/callback.py +++ b/flash/core/data/callback.py @@ -21,17 +21,8 @@ class FlashCallback(Callback): trainer = Trainer(callbacks=[MyCustomCallback()]) """ - def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: - """Called once a sample has been loaded using ``load_sample``.""" - - def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - """Called once ``pre_tensor_transform`` has been applied to a sample.""" - - def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - """Called once ``to_tensor_transform`` has been applied to a sample.""" - - def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: - """Called once ``post_tensor_transform`` has been applied to a sample.""" + def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None: + """Called once ``per_sample_transform`` has been applied to a sample.""" def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: """Called once ``per_batch_transform`` has been applied to a batch.""" @@ -58,14 +49,8 @@ def run_for_all_callbacks(self, *args, method_name: str, **kwargs): def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample") - def on_pre_tensor_transform(self, sample: Any, running_stage: 
RunningStage) -> None: - self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform") - - def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform") - - def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: - self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform") + def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None: + self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform") def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform") @@ -147,9 +132,7 @@ def from_inputs( 'test': {}, 'val': { 'load_sample': [0, 1, 2, 3, 4], - 'pre_tensor_transform': [0, 1, 2, 3, 4], - 'to_tensor_transform': [0, 1, 2, 3, 4], - 'post_tensor_transform': [0, 1, 2, 3, 4], + 'per_sample_transform': [0, 1, 2, 3, 4], 'collate': [tensor([0, 1, 2, 3, 4])], 'per_batch_transform': [tensor([0, 1, 2, 3, 4])]}, 'predict': {} @@ -179,14 +162,8 @@ def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None: self._store(sample, "load_sample", running_stage) - def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - self._store(sample, "pre_tensor_transform", running_stage) - - def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None: - self._store(sample, "to_tensor_transform", running_stage) - - def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None: - self._store(sample, "post_tensor_transform", running_stage) + def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None: + self._store(sample, "per_sample_transform", running_stage) def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None: self._store(batch, "per_batch_transform", running_stage) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 59c14f477c..31f44a65d3 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -571,7 +571,7 @@ def from_input( InputFormat.FOLDERS, train_data="train_folder", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ @@ -819,7 +819,7 @@ def from_tensors( train_files=torch.rand(3, 128), train_targets=[1, 0, 1], train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ @@ -906,7 +906,7 @@ def from_numpy( train_files=np.random.rand(3, 128), train_targets=[1, 0, 1], train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ @@ -994,7 +994,7 @@ def from_json( "target", train_file="train_data.json", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) @@ -1015,7 +1015,7 @@ def from_json( "target", train_file="train_data.json", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, feild="data" ) @@ -1102,7 +1102,7 @@ def from_csv( "target", train_file="train_data.csv", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ @@ -1182,7 +1182,7 @@ def 
from_datasets( data_module = DataModule.from_datasets( train_dataset=train_dataset, train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ @@ -1266,7 +1266,7 @@ def from_fiftyone( data_module = DataModule.from_fiftyone( train_data = train_dataset, train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index 477b8a416e..baf0c60cd3 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -27,12 +27,7 @@ from flash.core.data.batch import _DeserializeProcessor from flash.core.data.io.input import Input from flash.core.data.io.input_base import InputBase -from flash.core.data.io.input_transform import ( - _InputTransformProcessor, - _InputTransformSequential, - DefaultInputTransform, - InputTransform, -) +from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform, InputTransform from flash.core.data.io.output import _OutputProcessor, Output from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform from flash.core.data.process import Deserializer @@ -270,27 +265,17 @@ def _create_collate_input_transform_processors( else worker_collate_fn ) - assert_contains_tensor = self._is_overriden_recursive( - "to_tensor_transform", input_transform, InputTransform, prefix=_STAGES_PREFIX[stage] - ) + per_sample_transform = getattr(input_transform, func_names["per_sample_transform"]) deserialize_processor = _DeserializeProcessor( self._deserializer, input_transform, - getattr(input_transform, func_names["pre_tensor_transform"]), - getattr(input_transform, func_names["to_tensor_transform"]), + per_sample_transform, ) worker_input_transform_processor = _InputTransformProcessor( input_transform, worker_collate_fn, - _InputTransformSequential( - input_transform, - None if is_serving else getattr(input_transform, func_names["pre_tensor_transform"]), - None if is_serving else getattr(input_transform, func_names["to_tensor_transform"]), - getattr(input_transform, func_names["post_tensor_transform"]), - stage, - assert_contains_tensor=assert_contains_tensor, - ), + self._identity if is_serving else per_sample_transform, getattr(input_transform, func_names["per_batch_transform"]), stage, ) diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 55af982c2e..363a213201 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -260,7 +260,7 @@ def load_sample(sample: Mapping[str, Any], dataset: Optional[Any] = None) -> Any Returns: The loaded sample as a mapping with string keys (e.g. "input", "target") that can be processed by the - :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`. + :meth:`~flash.core.data.io.input_transform.InputTransform.per_sample_transform`. 
Example:: diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index 1fcfecd8fc..56ab72ee81 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -17,7 +17,6 @@ import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch import Tensor from torch.utils.data._utils.collate import default_collate from flash.core.data.callback import ControlFlow, FlashCallback @@ -28,14 +27,11 @@ CollateFn, PerBatchTransform, PerBatchTransformOnDevice, + PerSampleTransform, PerSampleTransformOnDevice, - PostTensorTransform, - PreTensorTransform, - ToTensorTransform, ) from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import ( - _contains_any_tensor, _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX, convert_to_modules, @@ -64,30 +60,12 @@ class InputTransform(BaseInputTransform, Properties): The :class:`~flash.core.data.io.input_transform.InputTransform` supports the following hooks: - - ``pre_tensor_transform``: Performs transforms on a single data sample. + - ``per_sample_transform``: Performs transforms on a single data sample. Example:: * Input: Receive a PIL Image and its label. - * Action: Rotate the PIL Image. - - * Output: Return the rotated PIL image and its label. - - - ``to_tensor_transform``: Converts a single data sample to a tensor / data structure containing tensors. Example:: - - * Input: Receive the rotated PIL Image and its label. - - * Action: Convert the rotated PIL Image to a tensor. - - * Output: Return the tensored image and its label. - - - ``post_tensor_transform``: Performs transform on a single tensor sample. Example:: - - * Input: Receive the tensored image and its label. - - * Action: Flip the tensored image randomly. + * Action: Rotate the PIL Image and convert it to a tensor. * Output: Return the tensored image and its label. @@ -140,31 +118,26 @@ class CustomInputTransform(InputTransform): def default_transforms() -> Mapping[str, Callable]: return { - "to_tensor_transform": transforms.ToTensor(), + "per_sample_transform": transforms.ToTensor(), "collate": torch.utils.data._utils.collate.default_collate, } def train_default_transforms() -> Mapping[str, Callable]: return { - "pre_tensor_transform": transforms.RandomHorizontalFlip(), - "to_tensor_transform": transforms.ToTensor(), + "per_sample_transform": transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()]), "collate": torch.utils.data._utils.collate.default_collate, } When overriding hooks for particular stages, you can prefix with ``train``, ``val``, ``test`` or ``predict``. For - example, you can achieve the same as the above example by implementing ``train_pre_tensor_transform`` and - ``train_to_tensor_transform``. + example, you can achieve the same as the above example by implementing ``train_per_sample_transform``.
Example:: class CustomInputTransform(InputTransform): - def train_pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: - return transforms.RandomHorizontalFlip()(sample) + def train_per_sample_transform(self, sample: PIL.Image) -> torch.Tensor: + return transforms.ToTensor()(transforms.RandomHorizontalFlip()(sample)) - def to_tensor_transform(self, sample: PIL.Image) -> torch.Tensor: - return transforms.ToTensor()(sample) - def collate(self, samples: List[torch.Tensor]) -> torch.Tensor: return torch.utils.data._utils.collate.default_collate(samples) @@ -175,7 +148,7 @@ def collate(self, samples: List[torch.Tensor]) -> torch.Tensor: class CustomInputTransform(InputTransform): - def pre_tensor_transform(self, sample: PIL.Image) -> PIL.Image: + def per_sample_transform(self, sample: PIL.Image) -> PIL.Image: if self.training: # logic for training @@ -268,9 +241,9 @@ def _check_transforms( return transform if isinstance(transform, list): - transform = {"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, torch.nn.Sequential(*transform))} + transform = {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, torch.nn.Sequential(*transform))} elif callable(transform): - transform = {"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, transform)} + transform = {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, transform)} if not isinstance(transform, Dict): raise MisconfigurationException( @@ -407,17 +380,9 @@ def _apply_process_state_transform( else: return self._apply_batch_transform(batch) - def pre_tensor_transform(self, sample: Any) -> Any: + def per_sample_transform(self, sample: Any) -> Any: """Transforms to apply on a single object.""" - return self._apply_process_state_transform(PreTensorTransform, sample=sample) - - def to_tensor_transform(self, sample: Any) -> Tensor: - """Transforms to convert single object to a tensor.""" - return self._apply_process_state_transform(ToTensorTransform, sample=sample) - - def post_tensor_transform(self, sample: Tensor) -> Tensor: - """Transforms to apply on a tensor.""" - return self._apply_process_state_transform(PostTensorTransform, sample=sample) + return self._apply_process_state_transform(PerSampleTransform, sample=sample) def per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). @@ -534,102 +499,25 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): return cls(**state_dict) -class _InputTransformSequential(torch.nn.Module): - """This class is used to chain 3 functions together for the _InputTransformProcessor ``per_sample_transform`` - function. - - 1. ``pre_tensor_transform`` - 2. ``to_tensor_transform`` - 3.
``post_tensor_transform`` - """ - - def __init__( - self, - input_transform: InputTransform, - pre_tensor_transform: Optional[Callable], - to_tensor_transform: Optional[Callable], - post_tensor_transform: Callable, - stage: RunningStage, - assert_contains_tensor: bool = False, - ): - super().__init__() - self.input_transform = input_transform - self.callback = ControlFlow(self.input_transform.callbacks) - self.pre_tensor_transform = convert_to_modules(pre_tensor_transform) - self.to_tensor_transform = convert_to_modules(to_tensor_transform) - self.post_tensor_transform = convert_to_modules(post_tensor_transform) - self.stage = stage - self.assert_contains_tensor = assert_contains_tensor - - self._current_stage_context = CurrentRunningStageContext(stage, input_transform, reset=False) - self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", input_transform) - self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", input_transform) - self._post_tensor_transform_context = CurrentFuncContext("post_tensor_transform", input_transform) - - def forward(self, sample: Any) -> Any: - self.callback.on_load_sample(sample, self.stage) - - with self._current_stage_context: - if self.pre_tensor_transform is not None: - with self._pre_tensor_transform_context: - sample = self.pre_tensor_transform(sample) - self.callback.on_pre_tensor_transform(sample, self.stage) - - if self.to_tensor_transform is not None: - with self._to_tensor_transform_context: - sample = self.to_tensor_transform(sample) - self.callback.on_to_tensor_transform(sample, self.stage) - - if self.assert_contains_tensor: - if not _contains_any_tensor(sample): - raise MisconfigurationException( - "When ``to_tensor_transform`` is overriden, " - "``DataPipeline`` expects the outputs to be ``tensors``" - ) - - with self._post_tensor_transform_context: - sample = self.post_tensor_transform(sample) - self.callback.on_post_tensor_transform(sample, self.stage) - - return sample - - def __str__(self) -> str: - return ( - f"{self.__class__.__name__}:\n" - f"\t(pre_tensor_transform): {str(self.pre_tensor_transform)}\n" - f"\t(to_tensor_transform): {str(self.to_tensor_transform)}\n" - f"\t(post_tensor_transform): {str(self.post_tensor_transform)}\n" - f"\t(assert_contains_tensor): {str(self.assert_contains_tensor)}\n" - f"\t(stage): {str(self.stage)}" - ) - - class _InputTransformProcessor(torch.nn.Module): """ - This class is used to encapsultate the following functions of a InputTransformInputTransform Object: + This class is used to encapsulate the following functions of an ``InputTransform`` object: Inside a worker: per_sample_transform: Function to transform an individual sample - Inside a worker, it is actually make of 3 functions: - * pre_tensor_transform - * to_tensor_transform - * post_tensor_transform collate: Function to merge sample into a batch per_batch_transform: Function to transform an individual batch - * per_batch_transform Inside main process: - per_sample_transform: Function to transform an individual sample - * per_sample_transform_on_device + per_sample_transform_on_device: Function to transform an individual sample collate: Function to merge sample into a batch - per_batch_transform: Function to transform an individual batch - * per_batch_transform_on_device + per_batch_transform_on_device: Function to transform an individual batch """ def __init__( self, input_transform: InputTransform, collate_fn: Callable, - per_sample_transform: Union[Callable, _InputTransformSequential],
+ per_sample_transform: Callable, per_batch_transform: Callable, stage: RunningStage, apply_per_sample_transform: bool = True, @@ -659,6 +547,10 @@ def _extract_metadata( return samples, metadata if any(m is not None for m in metadata) else None def forward(self, samples: Sequence[Any]) -> Any: + if not self.on_device: + for sample in samples: + self.callback.on_load_sample(sample, self.stage) + # we create a new dict to prevent from potential memory leaks # assuming that the dictionary samples are stored in between and # potentially modified before the transforms are applied. @@ -678,6 +570,8 @@ def forward(self, samples: Sequence[Any]) -> Any: sample = self.per_sample_transform(sample) if self.on_device: self.callback.on_per_sample_transform_on_device(sample, self.stage) + else: + self.callback.on_per_sample_transform(sample, self.stage) _samples.append(sample) samples = type(_samples)(_samples) diff --git a/flash/core/data/states.py b/flash/core/data/states.py index 14f3b2714d..8dbb738b4a 100644 --- a/flash/core/data/states.py +++ b/flash/core/data/states.py @@ -5,19 +5,7 @@ @dataclass(unsafe_hash=True, frozen=True) -class PreTensorTransform(ProcessState): - - transform: Optional[Callable] = None - - -@dataclass(unsafe_hash=True, frozen=True) -class ToTensorTransform(ProcessState): - - transform: Optional[Callable] = None - - -@dataclass(unsafe_hash=True, frozen=True) -class PostTensorTransform(ProcessState): +class PerSampleTransform(ProcessState): transform: Optional[Callable] = None diff --git a/flash/core/data/utils.py b/flash/core/data/utils.py index 4810c9fca7..1de445d1f7 100644 --- a/flash/core/data/utils.py +++ b/flash/core/data/utils.py @@ -15,12 +15,11 @@ import os.path import tarfile import zipfile -from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set, Type +from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Set import requests import torch from pytorch_lightning.utilities.apply_func import apply_to_collection -from torch import Tensor from tqdm.auto import tqdm as tq from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE @@ -48,9 +47,7 @@ } _INPUT_TRANSFORM_FUNCS: Set[str] = { - "pre_tensor_transform", - "to_tensor_transform", - "post_tensor_transform", + "per_sample_transform", "per_batch_transform", "per_sample_transform_on_device", "per_batch_transform_on_device", @@ -176,19 +173,6 @@ def extract_tarfile(file_path: str, extract_path: str, mode: str): extract_tarfile(local_filename, path, "r:bz2") -def _contains_any_tensor(value: Any, dtype: Type = Tensor) -> bool: - # TODO: we should refactor FlashDatasetFolder to better integrate - # with DataPipeline. That way, we wouldn't need this check. - # This is because we are running transforms in both places. 
- if isinstance(value, dtype): - return True - if isinstance(value, (list, tuple)): - return any(_contains_any_tensor(v, dtype=dtype) for v in value) - if isinstance(value, dict): - return any(_contains_any_tensor(v, dtype=dtype) for v in value.values()) - return False - - class FuncModule(torch.nn.Module): """This class is used to wrap a callable within a nn.Module and apply the wrapped function in `__call__`""" diff --git a/flash/core/integrations/icevision/transforms.py b/flash/core/integrations/icevision/transforms.py index 8bc15745cc..9c98a737fc 100644 --- a/flash/core/integrations/icevision/transforms.py +++ b/flash/core/integrations/icevision/transforms.py @@ -248,7 +248,7 @@ def forward(self, x): def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms from IceVision.""" return { - "pre_tensor_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]), + "per_sample_transform": IceVisionTransformAdapter([*A.resize_and_pad(image_size), A.Normalize()]), } @@ -256,5 +256,5 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default augmentations from IceVision.""" return { - "pre_tensor_transform": IceVisionTransformAdapter([*A.aug_tfms(size=image_size), A.Normalize()]), + "per_sample_transform": IceVisionTransformAdapter([*A.aug_tfms(size=image_size), A.Normalize()]), } diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index 458c39997a..ed975b9ba6 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -53,7 +53,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): @staticmethod def default_transforms() -> Optional[Dict[str, Callable]]: - return {"pre_tensor_transform": NormalizeFeatures(), "collate": Batch.from_data_list} + return {"per_sample_transform": NormalizeFeatures(), "collate": Batch.from_data_list} class GraphClassificationData(DataModule): diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index e5b6fe1459..e8128f035c 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -532,16 +532,8 @@ def show_load_sample(self, samples: List[Any], running_stage: RunningStage): win_title: str = f"{running_stage} - show_load_sample" self._show_images_and_labels(samples, len(samples), win_title) - def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_pre_tensor_transform" - self._show_images_and_labels(samples, len(samples), win_title) - - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_to_tensor_transform" - self._show_images_and_labels(samples, len(samples), win_title) - - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_post_tensor_transform" + def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_per_sample_transform" self._show_images_and_labels(samples, len(samples), win_title) def show_per_batch_transform(self, batch: List[Any], running_stage): diff --git a/flash/image/classification/transforms.py b/flash/image/classification/transforms.py index 36e91b4c19..db91e2219a 100644 --- a/flash/image/classification/transforms.py 
+++ b/flash/image/classification/transforms.py @@ -15,7 +15,6 @@ from typing import Callable, Dict, Tuple import torch -from torch import nn from flash.core.data.io.input import DataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms @@ -25,7 +24,6 @@ import kornia as K if _TORCHVISION_AVAILABLE: - import torchvision from torchvision import transforms as T if _ALBUMENTATIONS_AVAILABLE: @@ -49,33 +47,38 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: collate the batch, and apply normalization.""" if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + per_sample_transform = T.Compose( + [ + ApplyToKeys( + DataKeys.INPUT, + T.Compose([T.ToTensor(), K.geometry.Resize(image_size)]), + ), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - DataKeys.INPUT, - K.geometry.Resize(image_size), - ), - "collate": kornia_collate, - "per_batch_transform_on_device": ApplyToKeys( - DataKeys.INPUT, - K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), - ), - } - return { - "pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.Resize(image_size)), - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DataKeys.TARGET, torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( + ] + ) + per_batch_transform_on_device = ApplyToKeys( DataKeys.INPUT, - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + K.augmentation.Normalize(torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])), + ) + return dict( + per_sample_transform=per_sample_transform, + collate=kornia_collate, + per_batch_transform_on_device=per_batch_transform_on_device, + ) + return dict( + per_sample_transform=T.Compose( + [ + ApplyToKeys( + DataKeys.INPUT, + T.Compose( + [T.ToTensor(), T.Resize(image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] + ), + ), + ApplyToKeys(DataKeys.TARGET, torch.as_tensor), + ] ), - "collate": kornia_collate, - } + collate=kornia_collate, + ) def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: @@ -83,8 +86,8 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": # Better approach as all transforms are applied on tensor directly transforms = { - "post_tensor_transform": ApplyToKeys(DataKeys.INPUT, K.augmentation.RandomHorizontalFlip()), + "per_sample_transform": ApplyToKeys(DataKeys.INPUT, K.augmentation.RandomHorizontalFlip()), } else: - transforms = {"pre_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.RandomHorizontalFlip())} + transforms = {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, T.RandomHorizontalFlip())} return merge_transforms(default_transforms(image_size), transforms) diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 1c470e3760..54c6510d9c 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -20,10 +20,8 @@ CollateFn, PerBatchTransform, PerBatchTransformOnDevice, + PerSampleTransform, PerSampleTransformOnDevice, - PostTensorTransform, - PreTensorTransform, - ToTensorTransform, ) from flash.core.data.transforms import ApplyToKeys from 
flash.core.registry import FlashRegistry @@ -120,15 +118,16 @@ def __init__( ) transform, collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) - to_tensor_transform = ApplyToKeys( - DataKeys.INPUT, - transform, - ) self.adapter.set_state(CollateFn(collate_fn)) - self.adapter.set_state(ToTensorTransform(to_tensor_transform)) - self.adapter.set_state(PostTensorTransform(None)) - self.adapter.set_state(PreTensorTransform(None)) + self.adapter.set_state( + PerSampleTransform( + ApplyToKeys( + DataKeys.INPUT, + transform, + ) + ) + ) self.adapter.set_state(PerSampleTransformOnDevice(None)) self.adapter.set_state(PerBatchTransform(None)) self.adapter.set_state(PerBatchTransformOnDevice(None)) diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index d685fc3c22..c15a489235 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -108,7 +108,7 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): def default_transforms(self) -> Dict[str, Callable]: return { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), ApplyToKeys( DataKeys.TARGET, diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index f4c89242d1..4221338223 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -506,6 +506,6 @@ def show_load_sample(self, samples: List[Any], running_stage: RunningStage): win_title: str = f"{running_stage} - show_load_sample" self._show_images_and_labels(samples, len(samples), win_title) - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_post_tensor_transform" + def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_per_sample_transform" self._show_images_and_labels(samples, len(samples), win_title) diff --git a/flash/image/segmentation/transforms.py b/flash/image/segmentation/transforms.py index 886f1e5c27..1b561ff439 100644 --- a/flash/image/segmentation/transforms.py +++ b/flash/image/segmentation/transforms.py @@ -37,7 +37,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms for semantic segmentation: resize the image and mask, collate the batch, and apply normalization.""" return { - "post_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys( [DataKeys.INPUT, DataKeys.TARGET], KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation="nearest")), @@ -53,7 +53,7 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] return merge_transforms( default_transforms(image_size), { - "post_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys( [DataKeys.INPUT, DataKeys.TARGET], KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5)), @@ -66,7 +66,7 @@ def train_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] def predict_default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: """During predict, we apply the default transforms only on DataKeys.INPUT.""" return { - "post_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys( DataKeys.INPUT, K.geometry.Resize(image_size, interpolation="nearest"), diff --git 
a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index cdb7524732..21fa5cdb61 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -86,17 +86,14 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): def default_transforms(self) -> Optional[Dict[str, Callable]]: if self.training: return dict( - to_tensor_transform=T.ToTensor(), + per_sample_transform=T.ToTensor(), per_sample_transform_on_device=nn.Sequential( T.Resize(self.image_size), T.CenterCrop(self.image_size), ), ) if self.predicting: - return dict( - pre_tensor_transform=T.Resize(self.image_size), - to_tensor_transform=T.ToTensor(), - ) + return dict(per_sample_transform=T.Compose([T.Resize(self.image_size), T.ToTensor()])) # Style transfer doesn't support a validation or test phase, so we return nothing here return None diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index e3786b5c64..265eaf7d1a 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -134,13 +134,13 @@ def input_to_tensor(input: np.ndarray): return torch.from_numpy(input).float() def default_transforms(self) -> Optional[Dict[str, Callable]]: - """Configures the default ``to_tensor_transform``. + """Configures the default ``per_sample_transform``. Returns: Our dictionary of transforms. """ return { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys(DataKeys.INPUT, self.input_to_tensor), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), @@ -255,14 +255,5 @@ class TemplateVisualization(BaseVisualization): def show_load_sample(self, samples: List[Any], running_stage: RunningStage): print(samples) - def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage): print(samples) - - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - print(samples) - - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - print(samples) - - def show_per_batch_transform(self, batch: List[Any], running_stage): - print(batch) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index daa31b8d69..740f948677 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -542,7 +542,7 @@ def from_parquet( "target", train_file="train_data.parquet", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, ) """ diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 344b935982..d5bdda536a 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -777,7 +777,7 @@ def from_json( data_module = QuestionAnsweringData.from_json( train_file="train_data.json", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, backbone="distilbert-base-uncased", max_source_length=384, @@ -880,7 +880,7 @@ def from_csv( "target", train_file="train_data.csv", train_transform={ - "to_tensor_transform": torch.as_tensor, + "per_sample_transform": torch.as_tensor, }, backbone="distilbert-base-uncased", max_source_length=384, diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index 204ad8efbc..ee9cdca7ea 100644 --- a/flash/video/classification/data.py +++ 
b/flash/video/classification/data.py @@ -313,21 +313,21 @@ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool) -> "VideoClas def default_transforms(self) -> Dict[str, Callable]: if self.training: - post_tensor_transform = [ + per_sample_transform = [ RandomCrop(244, pad_if_needed=True), RandomHorizontalFlip(p=0.5), ] else: - post_tensor_transform = [ + per_sample_transform = [ CenterCrop(244), ] return { - "post_tensor_transform": Compose( + "per_sample_transform": Compose( [ ApplyTransformToKey( key="video", - transform=Compose([UniformTemporalSubsample(8)] + post_tensor_transform), + transform=Compose([UniformTemporalSubsample(8)] + per_sample_transform), ), ] ), diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index 1459acca63..d335349aea 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -35,26 +35,28 @@ val_dataset = l2l.vision.datasets.MiniImagenet(root="data", mode="validation", download=True) train_transform = { - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), + "per_sample_transform": nn.Sequential( + ApplyToKeys( + DataKeys.INPUT, + nn.Sequential( + torchvision.transforms.ToTensor(), + Kg.Resize((196, 196)), + # SPATIAL + Ka.RandomHorizontalFlip(p=0.25), + Ka.RandomRotation(degrees=90.0, p=0.25), + Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), + Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), + # PIXEL-LEVEL + Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness + Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation + Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast + Ka.ColorJitter(hue=1 / 30, p=0.25), # hue + Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), + Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), + ), + ), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), - "post_tensor_transform": ApplyToKeys( - DataKeys.INPUT, - Kg.Resize((196, 196)), - # SPATIAL - Ka.RandomHorizontalFlip(p=0.25), - Ka.RandomRotation(degrees=90.0, p=0.25), - Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), - Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), - # PIXEL-LEVEL - Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness - Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation - Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast - Ka.ColorJitter(hue=1 / 30, p=0.25), # hue - Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), - Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), - ), "collate": kornia_collate, "per_batch_transform_on_device": ApplyToKeys( DataKeys.INPUT, diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index d43712cb28..fac7ad14fc 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -213,8 +213,8 @@ def test_from_filepaths_visualise(tmpdir): # call show functions # dm.show_train_batch() - dm.show_train_batch("pre_tensor_transform") - dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_train_batch("per_sample_transform") + dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries 
aren't installed.") @@ -248,9 +248,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): # call show functions dm.show_train_batch() - dm.show_train_batch("pre_tensor_transform") - dm.show_train_batch("to_tensor_transform") - dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_train_batch("per_sample_transform") dm.show_val_batch("per_batch_transform") @@ -274,7 +272,7 @@ def test_from_filepaths_splits(tmpdir): assert len(train_filepaths) == len(train_labels) _to_tensor = { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), diff --git a/tests/core/data/io/test_input_transform.py b/tests/core/data/io/test_input_transform.py index 08ccc76cf5..e5d2426a37 100644 --- a/tests/core/data/io/test_input_transform.py +++ b/tests/core/data/io/test_input_transform.py @@ -20,11 +20,7 @@ from flash import DataModule from flash.core.data.io.input import InputFormat -from flash.core.data.io.input_transform import ( - _InputTransformProcessor, - _InputTransformSequential, - DefaultInputTransform, -) +from flash.core.data.io.input_transform import _InputTransformProcessor, DefaultInputTransform from flash.core.utilities.stages import RunningStage @@ -60,25 +56,6 @@ def test_input_transform_processor_str(): ) -def test_sequential_str(): - sequential = _InputTransformSequential( - Mock(name="input_transform"), - torch.softmax, - torch.as_tensor, - torch.relu, - RunningStage.TRAINING, - True, - ) - assert str(sequential) == ( - "_InputTransformSequential:\n" - "\t(pre_tensor_transform): FuncModule(softmax)\n" - "\t(to_tensor_transform): FuncModule(as_tensor)\n" - "\t(post_tensor_transform): FuncModule(relu)\n" - "\t(assert_contains_tensor): True\n" - "\t(stage): RunningStage.TRAINING" - ) - - def test_input_of_name(): input_transform = CustomInputTransform() diff --git a/tests/core/data/test_base_viz.py b/tests/core/data/test_base_viz.py index 0a865a14de..a42d2ef7b5 100644 --- a/tests/core/data/test_base_viz.py +++ b/tests/core/data/test_base_viz.py @@ -39,25 +39,13 @@ def _rand_image(): class CustomBaseVisualization(BaseVisualization): def __init__(self): super().__init__() - - self.show_load_sample_called = False - self.show_pre_tensor_transform_called = False - self.show_to_tensor_transform_called = False - self.show_post_tensor_transform_called = False - self.show_collate_called = False - self.per_batch_transform_called = False + self.check_reset() def show_load_sample(self, samples: List[Any], running_stage: RunningStage): self.show_load_sample_called = True - def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - self.show_pre_tensor_transform_called = True - - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - self.show_to_tensor_transform_called = True - - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - self.show_post_tensor_transform_called = True + def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage): + self.show_per_sample_transform_called = True def show_collate(self, batch: Sequence, running_stage: RunningStage) -> None: self.show_collate_called = True @@ -67,9 +55,7 @@ def show_per_batch_transform(self, batch: Sequence, running_stage: RunningStage) def check_reset(self): self.show_load_sample_called = False - self.show_pre_tensor_transform_called = False - 
-        self.show_to_tensor_transform_called = False
-        self.show_post_tensor_transform_called = False
+        self.show_per_sample_transform_called = False
         self.show_collate_called = False
         self.per_batch_transform_called = False
@@ -112,6 +98,7 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
         for _ in range(num_tests):
             for fcn_name in _CALLBACK_FUNCS:
                 dm.data_fetcher.reset()
+                assert dm.data_fetcher.batches == {"predict": {}, "test": {}, "train": {}, "val": {}}
                 fcn = getattr(dm, f"show_{stage}_batch")
                 fcn(fcn_name, reset=False)

@@ -131,12 +118,12 @@ def _get_result(function_name: str):
         res = _get_result("load_sample")
         assert isinstance(res[0][DataKeys.TARGET], int)

-        res = _get_result("to_tensor_transform")
+        res = _get_result("per_sample_transform")
         assert len(res) == B
         assert isinstance(_extract_data(res), torch.Tensor)

         if not is_predict:
-            res = _get_result("to_tensor_transform")
+            res = _get_result("per_sample_transform")
             assert isinstance(res[0][DataKeys.TARGET], torch.Tensor)

         res = _get_result("collate")
@@ -154,19 +141,17 @@ def _get_result(function_name: str):
             assert res[0][DataKeys.TARGET].shape == (B,)

         assert dm.data_fetcher.show_load_sample_called
-        assert dm.data_fetcher.show_pre_tensor_transform_called
-        assert dm.data_fetcher.show_to_tensor_transform_called
-        assert dm.data_fetcher.show_post_tensor_transform_called
+        assert dm.data_fetcher.show_per_sample_transform_called
         assert dm.data_fetcher.show_collate_called
         assert dm.data_fetcher.per_batch_transform_called
-        dm.data_fetcher.check_reset()
+        dm.data_fetcher.reset()


 @pytest.mark.parametrize(
     "func_names, valid",
     [
         (["load_sample"], True),
         (["not_a_hook"], False),
-        (["load_sample", "pre_tensor_transform"], True),
+        (["load_sample", "per_sample_transform"], True),
         (["load_sample", "not_a_hook"], True),
     ],
 )
diff --git a/tests/core/data/test_callback.py b/tests/core/data/test_callback.py
index b9245d2760..e8885bc081 100644
--- a/tests/core/data/test_callback.py
+++ b/tests/core/data/test_callback.py
@@ -40,9 +40,7 @@ def test_flash_callback(_, __, tmpdir):

     assert callback_mock.method_calls == [
         call.on_load_sample(ANY, RunningStage.TRAINING),
-        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
+        call.on_per_sample_transform(ANY, RunningStage.TRAINING),
         call.on_collate(ANY, RunningStage.TRAINING),
         call.on_per_batch_transform(ANY, RunningStage.TRAINING),
     ]
@@ -66,29 +64,21 @@ def __init__(self):

     assert callback_mock.method_calls == [
         call.on_load_sample(ANY, RunningStage.TRAINING),
-        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
+        call.on_per_sample_transform(ANY, RunningStage.TRAINING),
         call.on_collate(ANY, RunningStage.TRAINING),
         call.on_per_batch_transform(ANY, RunningStage.TRAINING),
         call.on_load_sample(ANY, RunningStage.VALIDATING),
-        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
-        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
-        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
+        call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
         call.on_collate(ANY, RunningStage.VALIDATING),
         call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
         call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
         call.on_load_sample(ANY, RunningStage.TRAINING),
-        call.on_pre_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_to_tensor_transform(ANY, RunningStage.TRAINING),
-        call.on_post_tensor_transform(ANY, RunningStage.TRAINING),
+        call.on_per_sample_transform(ANY, RunningStage.TRAINING),
         call.on_collate(ANY, RunningStage.TRAINING),
         call.on_per_batch_transform(ANY, RunningStage.TRAINING),
         call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
         call.on_load_sample(ANY, RunningStage.VALIDATING),
-        call.on_pre_tensor_transform(ANY, RunningStage.VALIDATING),
-        call.on_to_tensor_transform(ANY, RunningStage.VALIDATING),
-        call.on_post_tensor_transform(ANY, RunningStage.VALIDATING),
+        call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
         call.on_collate(ANY, RunningStage.VALIDATING),
         call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
         call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
diff --git a/tests/core/data/test_callbacks.py b/tests/core/data/test_callbacks.py
index bac9e660f7..0dfedfc772 100644
--- a/tests/core/data/test_callbacks.py
+++ b/tests/core/data/test_callbacks.py
@@ -26,9 +26,7 @@ def test_base_data_fetcher(tmpdir):
     class CheckData(BaseDataFetcher):
         def check(self):
             assert self.batches["val"]["load_sample"] == [0, 1, 2, 3, 4]
-            assert self.batches["val"]["pre_tensor_transform"] == [0, 1, 2, 3, 4]
-            assert self.batches["val"]["to_tensor_transform"] == [0, 1, 2, 3, 4]
-            assert self.batches["val"]["post_tensor_transform"] == [0, 1, 2, 3, 4]
+            assert self.batches["val"]["per_sample_transform"] == [0, 1, 2, 3, 4]
             assert torch.equal(self.batches["val"]["collate"][0], tensor([0, 1, 2, 3, 4]))
             assert torch.equal(self.batches["val"]["per_batch_transform"][0], tensor([0, 1, 2, 3, 4]))
             assert self.batches["train"] == {}
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index e0c6a0d6b2..6c269db691 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -33,7 +33,7 @@
 from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform
 from flash.core.data.process import Deserializer
 from flash.core.data.properties import ProcessState
-from flash.core.data.states import PerBatchTransformOnDevice, ToTensorTransform
+from flash.core.data.states import PerBatchTransformOnDevice, PerSampleTransform
 from flash.core.model import Task
 from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -128,13 +128,7 @@ class SubOutputTransform(OutputTransform):

 def test_data_pipeline_is_overriden_and_resolve_function_hierarchy(tmpdir):
     class CustomInputTransform(DefaultInputTransform):
-        def val_pre_tensor_transform(self, *_, **__):
-            pass
-
-        def predict_to_tensor_transform(self, *_, **__):
-            pass
-
-        def train_post_tensor_transform(self, *_, **__):
+        def val_per_sample_transform(self, *_, **__):
             pass

         def test_collate(self, *_, **__):
@@ -177,23 +171,11 @@ def test_per_batch_transform_on_device(self, *_, **__):
         for k in data_pipeline.INPUT_TRANSFORM_FUNCS
     }

-    # pre_tensor_transform
-    assert train_func_names["pre_tensor_transform"] == "pre_tensor_transform"
-    assert val_func_names["pre_tensor_transform"] == "val_pre_tensor_transform"
-    assert test_func_names["pre_tensor_transform"] == "pre_tensor_transform"
-    assert predict_func_names["pre_tensor_transform"] == "pre_tensor_transform"
-
-    # to_tensor_transform
-    assert train_func_names["to_tensor_transform"] == "to_tensor_transform"
-    assert val_func_names["to_tensor_transform"] == "to_tensor_transform"
-    assert test_func_names["to_tensor_transform"] == "to_tensor_transform"
-    assert predict_func_names["to_tensor_transform"] == "predict_to_tensor_transform"
-
-    # post_tensor_transform
-    assert train_func_names["post_tensor_transform"] == "train_post_tensor_transform"
-    assert val_func_names["post_tensor_transform"] == "post_tensor_transform"
-    assert test_func_names["post_tensor_transform"] == "post_tensor_transform"
-    assert predict_func_names["post_tensor_transform"] == "post_tensor_transform"
+    # per_sample_transform
+    assert train_func_names["per_sample_transform"] == "per_sample_transform"
+    assert val_func_names["per_sample_transform"] == "val_per_sample_transform"
+    assert test_func_names["per_sample_transform"] == "per_sample_transform"
+    assert predict_func_names["per_sample_transform"] == "per_sample_transform"

     # collate
     assert train_func_names["collate"] == "collate"
@@ -218,31 +200,19 @@ def test_per_batch_transform_on_device(self, *_, **__):
     test_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.TESTING)
     predict_worker_input_transform_processor = data_pipeline.worker_input_transform_processor(RunningStage.PREDICTING)

-    _seq = train_worker_input_transform_processor.per_sample_transform
-    assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform
-    assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform
-    assert _seq.post_tensor_transform.func == input_transform.train_post_tensor_transform
+    assert train_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform
     assert train_worker_input_transform_processor.collate_fn.func == input_transform.collate
     assert train_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform

-    _seq = val_worker_input_transform_processor.per_sample_transform
-    assert _seq.pre_tensor_transform.func == input_transform.val_pre_tensor_transform
-    assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform
-    assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform
+    assert val_worker_input_transform_processor.per_sample_transform.func == input_transform.val_per_sample_transform
     assert val_worker_input_transform_processor.collate_fn.func == DataPipeline._identity
     assert val_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform

-    _seq = test_worker_input_transform_processor.per_sample_transform
-    assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform
-    assert _seq.to_tensor_transform.func == input_transform.to_tensor_transform
-    assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform
+    assert test_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform
     assert test_worker_input_transform_processor.collate_fn.func == input_transform.test_collate
     assert test_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform

-    _seq = predict_worker_input_transform_processor.per_sample_transform
-    assert _seq.pre_tensor_transform.func == input_transform.pre_tensor_transform
-    assert _seq.to_tensor_transform.func == input_transform.predict_to_tensor_transform
-    assert _seq.post_tensor_transform.func == input_transform.post_tensor_transform
+    assert predict_worker_input_transform_processor.per_sample_transform.func == input_transform.per_sample_transform
     assert predict_worker_input_transform_processor.collate_fn.func == input_transform.collate
     assert predict_worker_input_transform_processor.per_batch_transform.func == input_transform.per_batch_transform

@@ -395,11 +365,7 @@ def on_fit_start(self):

     @staticmethod
     def _compare_pre_processor(p1, p2):
-        p1_seq = p1.per_sample_transform
-        p2_seq = p2.per_sample_transform
-        assert p1_seq.pre_tensor_transform.func == p2_seq.pre_tensor_transform.func
-        assert p1_seq.to_tensor_transform.func == p2_seq.to_tensor_transform.func
-        assert p1_seq.post_tensor_transform.func == p2_seq.post_tensor_transform.func
+        assert p1.per_sample_transform.func == p2.per_sample_transform.func
         assert p1.collate_fn.func == p2.collate_fn.func
         assert p1.per_batch_transform.func == p2.per_batch_transform.func

@@ -595,19 +561,18 @@ class TestInputTransformations(DefaultInputTransform):
     def __init__(self):
         super().__init__(inputs={"default": TestInputTransformationsInput()})

-        self.train_pre_tensor_transform_called = False
+        self.train_per_sample_transform_called = False
         self.train_collate_called = False
         self.train_per_batch_transform_on_device_called = False
-        self.val_to_tensor_transform_called = False
+        self.val_per_sample_transform_called = False
         self.val_collate_called = False
         self.val_per_batch_transform_on_device_called = False
-        self.test_to_tensor_transform_called = False
-        self.test_post_tensor_transform_called = False
+        self.test_per_sample_transform_called = False

-    def train_pre_tensor_transform(self, sample: Any) -> Any:
+    def train_per_sample_transform(self, sample: Any) -> Any:
         assert self.training
-        assert self.current_fn == "pre_tensor_transform"
-        self.train_pre_tensor_transform_called = True
+        assert self.current_fn == "per_sample_transform"
+        self.train_per_sample_transform_called = True
         return sample + (5,)

     def train_collate(self, samples) -> Tensor:
@@ -622,10 +587,10 @@ def train_per_batch_transform_on_device(self, batch: Any) -> Any:
         self.train_per_batch_transform_on_device_called = True
         assert torch.equal(batch, tensor([[0, 1, 2, 3, 5], [0, 1, 2, 3, 5]]))

-    def val_to_tensor_transform(self, sample: Any) -> Tensor:
+    def val_per_sample_transform(self, sample: Any) -> Tensor:
         assert self.validating
-        assert self.current_fn == "to_tensor_transform"
-        self.val_to_tensor_transform_called = True
+        assert self.current_fn == "per_sample_transform"
+        self.val_per_sample_transform_called = True
         return sample

     def val_collate(self, samples) -> Dict[str, Tensor]:
@@ -646,22 +611,16 @@ def val_per_batch_transform_on_device(self, batch: Any) -> Any:
         assert torch.equal(batch["b"], tensor([1, 2]))
         return [False]

-    def test_to_tensor_transform(self, sample: Any) -> Tensor:
+    def test_per_sample_transform(self, sample: Any) -> Tensor:
         assert self.testing
-        assert self.current_fn == "to_tensor_transform"
-        self.test_to_tensor_transform_called = True
-        return sample
-
-    def test_post_tensor_transform(self, sample: Tensor) -> Tensor:
-        assert self.testing
-        assert self.current_fn == "post_tensor_transform"
-        self.test_post_tensor_transform_called = True
+        assert self.current_fn == "per_sample_transform"
+        self.test_per_sample_transform_called = True
         return sample


 class TestInputTransformations2(TestInputTransformations):
-    def val_to_tensor_transform(self, sample: Any) -> Tensor:
-        self.val_to_tensor_transform_called = True
+    def val_per_sample_transform(self, sample: Any) -> Tensor:
+        self.val_per_sample_transform_called = True
         return {"a": tensor(sample["a"]), "b": tensor(sample["b"])}


@@ -701,8 +660,7 @@ def test_datapipeline_transformations(tmpdir):
     assert datamodule.val_dataloader().dataset[0] == {"a": 0, "b": 1}
     assert datamodule.val_dataloader().dataset[1] == {"a": 1, "b": 2}
-    with pytest.raises(MisconfigurationException, match="When ``to_tensor_transform``"):
-        batch = next(iter(datamodule.val_dataloader()))
+    batch = next(iter(datamodule.val_dataloader()))

     datamodule = DataModule.from_input(
         "default", 1, 1, 1, 1, batch_size=2, num_workers=0, input_transform=TestInputTransformations2()
     )
@@ -727,17 +685,16 @@ def test_datapipeline_transformations(tmpdir):
     input_transform = model._input_transform
     input = input_transform.input_of_name("default")
     assert input.train_load_data_called
-    assert input_transform.train_pre_tensor_transform_called
+    assert input_transform.train_per_sample_transform_called
     assert input_transform.train_collate_called
     assert input_transform.train_per_batch_transform_on_device_called
     assert input.val_load_data_called
     assert input.val_load_sample_called
-    assert input_transform.val_to_tensor_transform_called
+    assert input_transform.val_per_sample_transform_called
     assert input_transform.val_collate_called
     assert input_transform.val_per_batch_transform_on_device_called
     assert input.test_load_data_called
-    assert input_transform.test_to_tensor_transform_called
-    assert input_transform.test_post_tensor_transform_called
+    assert input_transform.test_per_sample_transform_called
     assert input.predict_load_data_called

@@ -771,7 +728,7 @@ def __init__(

     def default_transforms(self):
         return {
-            "to_tensor_transform": T.Compose([T.ToTensor()]),
+            "per_sample_transform": T.Compose([T.ToTensor()]),
             "per_batch_transform_on_device": T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         }

@@ -781,7 +738,7 @@ def __init__(self):
         super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())

         # override default transform to resize images
-        self.set_state(ToTensorTransform(T.Compose([T.ToTensor(), T.Resize(128)])))
+        self.set_state(PerSampleTransform(T.Compose([T.ToTensor(), T.Resize(128)])))

         # remove normalization, => image still in [0, 1] range
         self.set_state(PerBatchTransformOnDevice(None))

@@ -858,7 +815,7 @@ def __init__(
         val_transform=None,
         test_transform=None,
         predict_transform=None,
-        to_tensor_transform=None,
+        per_sample_transform=None,
         train_per_sample_transform_on_device=None,
     ):
         super().__init__(
@@ -868,10 +825,10 @@
             predict_transform=predict_transform,
             inputs={"default": ImageInput()},
         )
-        self._to_tensor = to_tensor_transform
+        self._to_tensor = per_sample_transform
         self._train_per_sample_transform_on_device = train_per_sample_transform_on_device

-    def to_tensor_transform(self, pil_image: Image.Image) -> Tensor:
+    def per_sample_transform(self, pil_image: Image.Image) -> Tensor:
         # convert pil image into a tensor
         return self._to_tensor(pil_image)

@@ -903,7 +860,7 @@ class CustomDataModule(DataModule):
         "test_folder",
         None,
         batch_size=2,
-        to_tensor_transform=T.ToTensor(),
+        per_sample_transform=T.ToTensor(),
         train_per_sample_transform_on_device=T.RandomHorizontalFlip(),
     )

@@ -934,7 +891,7 @@ def test_input_transform_transforms(tmpdir):
     with pytest.raises(MisconfigurationException, match="train_transform contains {'choco'}. Only"):
Only"): DefaultInputTransform(train_transform={"choco": None}) - input_transform = DefaultInputTransform(train_transform={"to_tensor_transform": torch.nn.Linear(1, 1)}) + input_transform = DefaultInputTransform(train_transform={"per_sample_transform": torch.nn.Linear(1, 1)}) # keep is None assert input_transform._train_collate_in_worker_from_transform is True assert input_transform._val_collate_in_worker_from_transform is None @@ -1037,7 +994,7 @@ def local_fn(x): def test_save_hyperparemeters(tmpdir): - kwargs = {"train_transform": {"pre_tensor_transform": local_fn}} + kwargs = {"train_transform": {"per_sample_transform": local_fn}} input_transform = CustomInputTransformHyperparameters("token", **kwargs) state_dict = input_transform.state_dict() torch.save(state_dict, os.path.join(tmpdir, "state_dict.pt")) diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index 7b29e5a922..ffc48b37ce 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -117,37 +117,37 @@ def test_kornia_collate(): "base_transforms, additional_transforms, expected_result", [ ( - {"to_tensor_transform": _MOCK_TRANSFORM}, - {"post_tensor_transform": _MOCK_TRANSFORM}, - {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM}, + {"per_batch_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM, "per_batch_transform": _MOCK_TRANSFORM}, ), ( - {"to_tensor_transform": _MOCK_TRANSFORM}, - {"to_tensor_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM}, { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ) }, ), ( - {"to_tensor_transform": _MOCK_TRANSFORM}, - {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM, "per_batch_transform": _MOCK_TRANSFORM}, { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM, + "per_batch_transform": _MOCK_TRANSFORM, }, ), ( - {"to_tensor_transform": _MOCK_TRANSFORM, "post_tensor_transform": _MOCK_TRANSFORM}, - {"to_tensor_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM, "per_batch_transform": _MOCK_TRANSFORM}, + {"per_sample_transform": _MOCK_TRANSFORM}, { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( convert_to_modules(_MOCK_TRANSFORM), convert_to_modules(_MOCK_TRANSFORM) ), - "post_tensor_transform": _MOCK_TRANSFORM, + "per_batch_transform": _MOCK_TRANSFORM, }, ), ], diff --git a/tests/graph/classification/test_data.py b/tests/graph/classification/test_data.py index a588390d17..11eb0d5a52 100644 --- a/tests/graph/classification/test_data.py +++ b/tests/graph/classification/test_data.py @@ -95,19 +95,19 @@ def test_transforms(self, tmpdir): predict_dataset=predict_dataset, train_transform=merge_transforms( GraphClassificationInputTransform.default_transforms(), - {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + {"per_sample_transform": OneHotDegree(tudataset.num_features - 1)}, ), val_transform=merge_transforms( GraphClassificationInputTransform.default_transforms(), - {"pre_tensor_transform": OneHotDegree(tudataset.num_features 
- 1)}, + {"per_sample_transform": OneHotDegree(tudataset.num_features - 1)}, ), test_transform=merge_transforms( GraphClassificationInputTransform.default_transforms(), - {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + {"per_sample_transform": OneHotDegree(tudataset.num_features - 1)}, ), predict_transform=merge_transforms( GraphClassificationInputTransform.default_transforms(), - {"pre_tensor_transform": OneHotDegree(tudataset.num_features - 1)}, + {"per_sample_transform": OneHotDegree(tudataset.num_features - 1)}, ), batch_size=2, ) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index dbcc009cf4..6a9271174f 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -167,8 +167,8 @@ def test_from_filepaths_visualise(tmpdir): # call show functions # dm.show_train_batch() - dm.show_train_batch("pre_tensor_transform") - dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_train_batch("per_sample_transform") + dm.show_train_batch(["per_sample_transform", "per_batch_transform"]) @pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @@ -202,9 +202,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): # call show functions dm.show_train_batch() - dm.show_train_batch("pre_tensor_transform") - dm.show_train_batch("to_tensor_transform") - dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_train_batch("per_sample_transform") dm.show_val_batch("per_batch_transform") @@ -228,7 +226,7 @@ def test_from_filepaths_splits(tmpdir): assert len(train_filepaths) == len(train_labels) _to_tensor = { - "to_tensor_transform": nn.Sequential( + "per_sample_transform": nn.Sequential( ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()), ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ), @@ -589,7 +587,7 @@ def mixup(batch, alpha=1.0): train_transform = { # applied only on images as ApplyToKeys is used with `input` - "post_tensor_transform": ApplyToKeys("input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))), + "per_sample_transform": ApplyToKeys("input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))), "per_batch_transform": mixup, } # merge the default transform for this task with new one. 
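Every test update above follows the same migration: transform dictionaries that previously used the "pre_tensor_transform", "to_tensor_transform", or "post_tensor_transform" keys now pass a single "per_sample_transform" entry, and anything that must see the assembled batch moves to "per_batch_transform". A minimal before/after sketch of user code, for orientation only: the transforms chosen here are placeholders and the DataKeys import path is assumed from how the tests use it, so this is not part of the patch.

    import torchvision.transforms as T

    from flash.core.data.io.input import DataKeys  # import path assumed
    from flash.core.data.transforms import ApplyToKeys

    # Old style (removed by this patch): separate per-sample hooks.
    # train_transform = {
    #     "to_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.ToTensor()),
    #     "post_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.RandomHorizontalFlip()),
    # }

    # New style: one hook covers the whole per-sample stage, so the pieces
    # that used to be split across three hooks are simply composed.
    train_transform = {
        "per_sample_transform": ApplyToKeys(
            DataKeys.INPUT,
            T.Compose([T.ToTensor(), T.RandomHorizontalFlip()]),
        ),
    }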
diff --git a/tests/image/embedding/utils.py b/tests/image/embedding/utils.py
index a6dc181b22..8cc77bdccc 100644
--- a/tests/image/embedding/utils.py
+++ b/tests/image/embedding/utils.py
@@ -37,13 +37,13 @@ def ssl_datamodule(
         total_num_crops, num_crops, size_crops, crop_scales
     )

-    to_tensor_transform = ApplyToKeys(
+    per_sample_transform = ApplyToKeys(
        DataKeys.INPUT,
         multi_crop_transform,
     )
     input_transform = DefaultInputTransform(
         train_transform={
-            "to_tensor_transform": to_tensor_transform,
+            "per_sample_transform": per_sample_transform,
             "collate": collate_fn,
         }
     )
diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py
index 475c1d6b49..cd898a2312 100644
--- a/tests/image/segmentation/test_data.py
+++ b/tests/image/segmentation/test_data.py
@@ -379,7 +379,7 @@ def test_map_labels(tmpdir):
     assert dm.data_fetcher.block_viz_window is False

     dm.show_train_batch("load_sample")
-    dm.show_train_batch("to_tensor_transform")
+    dm.show_train_batch("per_sample_transform")

     # check training data
     data = next(iter(dm.train_dataloader()))
diff --git a/tests/video/classification/test_model.py b/tests/video/classification/test_model.py
index 4eb6775a58..001e8f64ea 100644
--- a/tests/video/classification/test_model.py
+++ b/tests/video/classification/test_model.py
@@ -153,7 +153,7 @@ def test_video_classifier_finetune_from_folders(tmpdir):
         assert len(VideoClassifier.available_backbones()) > 5

         train_transform = {
-            "post_tensor_transform": Compose(
+            "per_sample_transform": Compose(
                 [
                     ApplyTransformToKey(
                         key="video",
@@ -239,7 +239,7 @@ def test_video_classifier_finetune_from_files(tmpdir):
         assert len(VideoClassifier.available_backbones()) > 5

         train_transform = {
-            "post_tensor_transform": Compose(
+            "per_sample_transform": Compose(
                 [
                     ApplyTransformToKey(
                         key="video",
@@ -316,7 +316,7 @@ def test_video_classifier_finetune_fiftyone(tmpdir):
         assert len(VideoClassifier.available_backbones()) > 5

         train_transform = {
-            "post_tensor_transform": Compose(
+            "per_sample_transform": Compose(
                 [
                     ApplyTransformToKey(
                         key="video",
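The resolution assertions updated in tests/core/data/test_data_pipeline.py also confirm that stage-prefixed overrides keep working after the consolidation: defining val_per_sample_transform on an input transform makes it the per-sample hook for validation only, while the plain per_sample_transform still serves the other stages. A short sketch of that pattern, mirroring the test classes above (the subclass name is hypothetical and the DefaultInputTransform import path is assumed, since the tests reference the class without showing its import):

    from typing import Any

    from flash.core.data.io.input_transform import DefaultInputTransform  # import path assumed

    class MyInputTransform(DefaultInputTransform):
        def per_sample_transform(self, sample: Any) -> Any:
            # default: resolved for the train, test, and predict stages
            return sample

        def val_per_sample_transform(self, sample: Any) -> Any:
            # resolved in place of per_sample_transform while validating,
            # matching the val_func_names assertions in the patch above
            return sample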