Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

DataPipeline Refactor V2: Simplify Input Transform #1010

Merged
merged 15 commits into from
Dec 1, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `OutputTransform.save_sample` and `save_data` hooks ([#948](https://github.com/PyTorchLightning/lightning-flash/pull/948))

- (Removed InputTransform `pre_tensor_transform`, `to_tensor_transform`, `post_tensor_transform` hooks in favour of `per_sample_transform` ([#1010](https://github.com/PyTorchLightning/lightning-flash/pull/1010))


## [0.5.2] - 2021-11-05

### Added
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def mixup(batch, alpha=1.0):

train_transform = {
# applied only on images as ApplyToKeys is used with `input`
"post_tensor_transform": ApplyToKeys(
"per_sample_transform": ApplyToKeys(
"input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))),

# applied to the entire dictionary as `ApplyToKeys` isn't used.
Expand Down
18 changes: 7 additions & 11 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Here are common terms you need to be familiar with:
- The :class:`~flash.core.data.io.input.Input` provides :meth:`~flash.core.data.io.input.Input.load_data` and :meth:`~flash.core.data.io.input.Input.load_sample` hooks for creating data sets from metadata (such as folder names).
* - :class:`~flash.core.data.io.input_transform.InputTransform`
- The :class:`~flash.core.data.io.input_transform.InputTransform` provides a simple hook-based API to encapsulate your pre-processing logic.
These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device).
These hooks (such as :meth:`~flash.core.data.io.input_transform.InputTransform.per_sample_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device).
The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed.
The :class:`~flash.core.data.io.input_transform.InputTransform` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
* - :class:`~flash.core.data.io.output_transform.OutputTransform`
Expand Down Expand Up @@ -112,7 +112,7 @@ Here's an example:
from flash.core.data.transforms import ApplyToKeys
from flash.image import ImageClassificationData, ImageClassifier

transform = {"to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform)}
transform = {"per_sample_transform": ApplyToKeys("input", my_per_sample_transform)}

datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
Expand All @@ -132,8 +132,8 @@ Alternatively, the user may directly override the hooks for their needs like thi


class CustomImageClassificationInputTransform(ImageClassificationInputTransform):
def to_tensor_transform(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["input"] = my_to_tensor_transform(sample["input"])
def per_sample_transform(sample: Dict[str, Any]) -> Dict[str, Any]:
sample["input"] = my_per_sample_transform(sample["input"])
return sample


Expand Down Expand Up @@ -267,7 +267,7 @@ Next, implement your custom ``ImageClassificationInputTransform`` with some defa
return cls(**state_dict)

def default_transforms(self) -> Dict[str, Callable]:
return {"to_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)}
return {"per_sample_transform": ApplyToKeys(DataKeys.INPUT, T.to_tensor)}

4. The DataModule
_________________
Expand Down Expand Up @@ -325,9 +325,7 @@ ______________

.. note::

The :meth:`~flash.core.data.io.input_transform.InputTransform.pre_tensor_transform`,
:meth:`~flash.core.data.io.input_transform.InputTransform.to_tensor_transform`,
:meth:`~flash.core.data.io.input_transform.InputTransform.post_tensor_transform`,
The :meth:`~flash.core.data.io.input_transform.InputTransform.per_sample_transform`,
:meth:`~flash.core.data.io.input_transform.InputTransform.collate`,
:meth:`~flash.core.data.io.input_transform.InputTransform.per_batch_transform` are injected as the
:paramref:`torch.utils.data.DataLoader.collate_fn` function of the DataLoader.
Expand All @@ -342,9 +340,7 @@ Example::

# This will be wrapped into a :class:`~flash.core.data.io.input_transform._InputTransformSequential`
for sample in samples:
sample = pre_tensor_transform(sample)
sample = to_tensor_transform(sample)
sample = post_tensor_transform(sample)
sample = per_sample_transform(sample)

samples = type(samples)(samples)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/integrations/icevision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ Here's an example:
from flash.image import ObjectDetectionData

train_transform = {
"pre_tensor_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
"per_sample_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
}

datamodule = ObjectDetectionData.from_coco(
Expand Down
5 changes: 2 additions & 3 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ Flash automatically applies some default image transformations and augmentations
The base :class:`~flash.core.data.io.input_transform.InputTransform` defines 7 hooks for different stages in the data loading pipeline.
To apply image augmentations you can directly import the ``default_transforms`` from ``flash.image.classification.transforms`` and then merge your custom image transformations with them using the :func:`~flash.core.data.transforms.merge_transforms` helper function.
Here's an example where we load the default transforms and merge with custom `torchvision` transformations.
We use the `post_tensor_transform` hook to apply the transformations after the image has been converted to a `torch.Tensor`.


.. testsetup:: transformations
Expand All @@ -108,12 +107,12 @@ We use the `post_tensor_transform` hook to apply the transformations after the i
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.transforms import default_transforms

post_tensor_transform = ApplyToKeys(
per_sample_transform = ApplyToKeys(
DataKeys.INPUT,
T.Compose([T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), T.RandomPerspective()]),
)

new_transforms = merge_transforms(default_transforms((64, 64)), {"post_tensor_transform": post_tensor_transform})
new_transforms = merge_transforms(default_transforms((64, 64)), {"per_sample_transform": per_sample_transform})

datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", train_transform=new_transforms
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ For object-detection tasks, you can leverage the transformations from `Albumenta
from flash.image import ObjectDetectionData

train_transform = {
"pre_tensor_transform": transforms.IceVisionTransformAdapter(
"per_sample_transform": transforms.IceVisionTransformAdapter(
[*A.resize_and_pad(128), A.Normalize(), A.Flip(0.4), alb.RandomBrightnessContrast()]
)
}
Expand Down
4 changes: 2 additions & 2 deletions docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ InputTransform
The :class:`~flash.core.data.io.input_transform.InputTransform` object contains all the data transforms.
Internally we inject the :class:`~flash.core.data.io.input_transform.InputTransform` transforms at several points along the pipeline.

Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.io.input_transform.InputTransform` is as simple as implementing the ``default_transforms`` method.
Defining the standard transforms (typically at least a ``per_sample_transform`` should be defined) for your :class:`~flash.core.data.io.input_transform.InputTransform` is as simple as implementing the ``default_transforms`` method.
The :class:`~flash.core.data.io.input_transform.InputTransform` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``.
These arguments can be provided by the user (when creating the :class:`~flash.core.data.data_module.DataModule`) to override the default transforms.
Any additional arguments are up to you.
Expand All @@ -115,7 +115,7 @@ Here's our ``TemplateInputTransform.__init__``:
:dedent: 4
:pyobject: TemplateInputTransform.__init__

For our ``TemplateInputTransform``, we'll just configure a default ``to_tensor_transform``.
For our ``TemplateInputTransform``, we'll just configure a default ``per_sample_transform``.
Let's first define the transform as a ``staticmethod``:

.. literalinclude:: ../../../flash/template/classification/data.py
Expand Down
8 changes: 3 additions & 5 deletions flash/audio/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision
from torchvision import transforms as T

if _TORCHAUDIO_AVAILABLE:
Expand All @@ -33,11 +32,10 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]
"""The default transforms for audio classification for spectrograms: resize the spectrogram, convert the
spectrogram and target to a tensor, and collate the batch."""
return {
"to_tensor_transform": nn.Sequential(
ApplyToKeys(DataKeys.INPUT, torchvision.transforms.ToTensor()),
"per_sample_transform": nn.Sequential(
ApplyToKeys(DataKeys.INPUT, T.Compose([T.ToTensor(), T.Resize(spectrogram_size)])),
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
),
"post_tensor_transform": ApplyToKeys(DataKeys.INPUT, T.Resize(spectrogram_size)),
"collate": default_collate,
}

Expand All @@ -55,5 +53,5 @@ def train_default_transforms(
augs.append(ApplyToKeys(DataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)))

if len(augs) > 0:
return merge_transforms(default_transforms(spectrogram_size), {"post_tensor_transform": nn.Sequential(*augs)})
return merge_transforms(default_transforms(spectrogram_size), {"per_sample_transform": nn.Sequential(*augs)})
return default_transforms(spectrogram_size)
22 changes: 4 additions & 18 deletions flash/core/data/base_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,7 @@ class CustomBaseVisualization(BaseVisualization):
def show_load_sample(self, samples: List[Any], running_stage):
# plot samples

def show_pre_tensor_transform(self, samples: List[Any], running_stage):
# plot samples

def show_to_tensor_transform(self, samples: List[Any], running_stage):
# plot samples

def show_post_tensor_transform(self, samples: List[Any], running_stage):
def show_per_sample_transform(self, samples: List[Any], running_stage):
# plot samples

def show_collate(self, batch: List[Any], running_stage):
Expand Down Expand Up @@ -93,9 +87,7 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage):
# out
{
'load_sample': [...],
'pre_tensor_transform': [...],
'to_tensor_transform': [...],
'post_tensor_transform': [...],
'per_sample_transform': [...],
'collate': [...],
'per_batch_transform': [...],
}
Expand Down Expand Up @@ -125,14 +117,8 @@ def show(self, batch: Dict[str, Any], running_stage: RunningStage, func_names_li
def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
"""Override to visualize ``load_sample`` output data."""

def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
"""Override to visualize ``pre_tensor_transform`` output data."""

def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
"""Override to visualize ``to_tensor_transform`` output data."""

def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
"""Override to visualize ``post_tensor_transform`` output data."""
def show_per_sample_transform(self, samples: List[Any], running_stage: RunningStage):
"""Override to visualize ``per_sample_transform`` output data."""

def show_collate(self, batch: List[Any], running_stage: RunningStage) -> None:
"""Override to visualize ``collate`` output data."""
Expand Down
20 changes: 6 additions & 14 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,24 @@ def __init__(
self,
deserializer: "Deserializer",
input_transform: "InputTransform",
pre_tensor_transform: Callable,
to_tensor_transform: Callable,
per_sample_transform: Callable,
):
super().__init__()
self.input_transform = input_transform
self.callback = ControlFlow(self.input_transform.callbacks)
self.deserializer = convert_to_modules(deserializer)
self.pre_tensor_transform = convert_to_modules(pre_tensor_transform)
self.to_tensor_transform = convert_to_modules(to_tensor_transform)
self.per_sample_transform = convert_to_modules(per_sample_transform)

self._current_stage_context = CurrentRunningStageContext(RunningStage.PREDICTING, input_transform, reset=False)
self._pre_tensor_transform_context = CurrentFuncContext("pre_tensor_transform", input_transform)
self._to_tensor_transform_context = CurrentFuncContext("to_tensor_transform", input_transform)
self._per_sample_transform_context = CurrentFuncContext("per_sample_transform", input_transform)

def forward(self, sample: str):

sample = self.deserializer(sample)

with self._current_stage_context:
with self._pre_tensor_transform_context:
sample = self.pre_tensor_transform(sample)
self.callback.on_pre_tensor_transform(sample, RunningStage.PREDICTING)

with self._to_tensor_transform_context:
sample = self.to_tensor_transform(sample)
self.callback.on_to_tensor_transform(sample, RunningStage.PREDICTING)
with self._per_sample_transform_context:
sample = self.per_sample_transform(sample)
self.callback.on_per_sample_transform(sample, RunningStage.PREDICTING)

return sample

Expand Down
37 changes: 7 additions & 30 deletions flash/core/data/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,8 @@ class FlashCallback(Callback):
trainer = Trainer(callbacks=[MyCustomCallback()])
"""

def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once a sample has been loaded using ``load_sample``."""

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``pre_tensor_transform`` has been applied to a sample."""

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
"""Called once ``to_tensor_transform`` has been applied to a sample."""

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
"""Called once ``post_tensor_transform`` has been applied to a sample."""
def on_per_sample_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
"""Called once ``per_sample_transform`` has been applied to a sample."""

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
"""Called once ``per_batch_transform`` has been applied to a batch."""
Expand All @@ -58,14 +49,8 @@ def run_for_all_callbacks(self, *args, method_name: str, **kwargs):
def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_load_sample")

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_pre_tensor_transform")

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_to_tensor_transform")

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_post_tensor_transform")
def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(sample, running_stage, method_name="on_per_sample_transform")

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
self.run_for_all_callbacks(batch, running_stage, method_name="on_per_batch_transform")
Expand Down Expand Up @@ -147,9 +132,7 @@ def from_inputs(
'test': {},
'val': {
'load_sample': [0, 1, 2, 3, 4],
'pre_tensor_transform': [0, 1, 2, 3, 4],
'to_tensor_transform': [0, 1, 2, 3, 4],
'post_tensor_transform': [0, 1, 2, 3, 4],
'per_sample_transform': [0, 1, 2, 3, 4],
'collate': [tensor([0, 1, 2, 3, 4])],
'per_batch_transform': [tensor([0, 1, 2, 3, 4])]},
'predict': {}
Expand Down Expand Up @@ -179,14 +162,8 @@ def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None:
def on_load_sample(self, sample: Any, running_stage: RunningStage) -> None:
self._store(sample, "load_sample", running_stage)

def on_pre_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self._store(sample, "pre_tensor_transform", running_stage)

def on_to_tensor_transform(self, sample: Any, running_stage: RunningStage) -> None:
self._store(sample, "to_tensor_transform", running_stage)

def on_post_tensor_transform(self, sample: Tensor, running_stage: RunningStage) -> None:
self._store(sample, "post_tensor_transform", running_stage)
def on_per_sample_transform(self, sample: Any, running_stage: RunningStage) -> None:
self._store(sample, "per_sample_transform", running_stage)

def on_per_batch_transform(self, batch: Any, running_stage: RunningStage) -> None:
self._store(batch, "per_batch_transform", running_stage)
Expand Down
Loading