This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Update data docs (#318)
* Update data.rst

* Updates

* updates

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
ethanwharris and mergify[bot] authored May 19, 2021
1 parent cb098db commit f12f55c
Showing 1 changed file with 119 additions and 116 deletions.
235 changes: 119 additions & 116 deletions docs/source/general/data.rst
Here are common terms you need to be familiar with:
* - Term
- Definition
* - :class:`~flash.core.data.data_module.DataModule`
- The :class:`~flash.core.data.data_module.DataModule` contains the datasets, transforms and dataloaders.
* - :class:`~flash.core.data.data_pipeline.DataPipeline`
- The :class:`~flash.core.data.data_pipeline.DataPipeline` is a Flash-internal object that manages the :class:`~flash.core.data.data_source.DataSource`, :class:`~flash.core.data.process.Preprocess`, :class:`~flash.core.data.process.Postprocess`, and :class:`~flash.core.data.process.Serializer` objects.
* - :class:`~flash.core.data.data_source.DataSource`
- The :class:`~flash.core.data.data_source.DataSource` provides :meth:`~flash.core.data.data_source.DataSource.load_data` and :meth:`~flash.core.data.data_source.DataSource.load_sample` hooks for creating data sets from metadata (such as folder names).
* - :class:`~flash.core.data.process.Preprocess`
- The :class:`~flash.core.data.process.Preprocess` provides a simple hook-based API to encapsulate your pre-processing logic.
These hooks (such as :meth:`~flash.core.data.process.Preprocess.pre_tensor_transform`) enable transformations to be applied to your data at every point along the pipeline (including on the device).
The :class:`~flash.core.data.data_pipeline.DataPipeline` contains a system to call the right hooks when needed.
The :class:`~flash.core.data.process.Preprocess` hooks can be either overridden directly or provided as a dictionary of transforms (mapping hook name to callable transform).
* - :class:`~flash.core.data.process.Postprocess`
- The :class:`~flash.core.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.core.data.process.Postprocess` hooks cover everything from model outputs to prediction export.
How to use out-of-the-box flash datamodules
*******************************************

Flash provides several DataModules with helper functions.
Check out the :ref:`image_classification` section (or the sections for any of our other tasks) to learn more.

***************
Data Processing
***************

Currently, it is common practice to implement a :class:`torch.utils.data.Dataset`
and provide it to a :class:`torch.utils.data.DataLoader`.
However, after model training, running inference on raw data and deploying the model in a production environment require a lot of engineering overhead.
Usually, extra processing logic must be added to bridge the gap between the training data and the raw data seen at inference time.

The :class:`~flash.core.data.data_source.DataSource` class can be used to generate data sets from multiple sources (e.g. folders, numpy arrays, etc.) that can then all be transformed in the same way.
The :class:`~flash.core.data.process.Preprocess` and :class:`~flash.core.data.process.Postprocess` classes can be used to manage the preprocessing and postprocessing transforms.
The :class:`~flash.core.data.process.Serializer` class provides the logic for converting :class:`~flash.core.data.process.Postprocess` outputs to the desired predict format (e.g. classes, labels, probabilities, etc.).
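To make the :class:`~flash.core.data.process.Serializer` role concrete, here is a dependency-free sketch of a label serializer. The class name ``LabelsSerializer`` and its ``serialize`` method are illustrative, not Flash's implementation; only the idea (converting raw outputs to a readable prediction format) comes from the text above.

```python
# Toy stand-in for a Serializer: converts raw per-class probabilities
# (the kind of output a Postprocess might produce) into a readable label.
class LabelsSerializer:
    def __init__(self, labels):
        self.labels = labels  # e.g. ["ant", "bee"]

    def serialize(self, probabilities):
        # Pick the index of the highest probability and map it to its label.
        best = max(range(len(probabilities)), key=lambda i: probabilities[i])
        return self.labels[best]


serializer = LabelsSerializer(["ant", "bee"])
prediction = serializer.serialize([0.2, 0.8])  # "bee"
```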

By providing a series of hooks that can be overridden with custom data processing logic (or just targeted with transforms),
Flash gives the user much more granular control over their data processing flow.

Here are the primary advantages:

Expand All @@ -73,56 +70,62 @@ hooks by adding ``train``, ``val``, ``test`` or ``predict``.

Check out :class:`~flash.core.data.process.Preprocess` for some examples.
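As a rough illustration of the stage-prefix convention (not Flash's actual resolution code, which lives inside the :class:`~flash.core.data.data_pipeline.DataPipeline`), hook lookup can be sketched like this:

```python
# Sketch: a stage-specific hook such as ``train_to_tensor_transform`` takes
# precedence over the generic ``to_tensor_transform`` for that stage only.
class BasePreprocess:
    def to_tensor_transform(self, sample):
        return ("generic", sample)


class MyPreprocess(BasePreprocess):
    def train_to_tensor_transform(self, sample):
        return ("train-augmented", sample)


def resolve_hook(preprocess, stage, name):
    # Prefer the stage-prefixed hook, fall back to the generic one.
    return getattr(preprocess, f"{stage}_{name}", getattr(preprocess, name))


preprocess = MyPreprocess()
train_hook = resolve_hook(preprocess, "train", "to_tensor_transform")
val_hook = resolve_hook(preprocess, "val", "to_tensor_transform")
```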


*************************************
How to customize existing datamodules
*************************************

Any Flash :class:`~flash.core.data.data_module.DataModule` can be created directly from datasets using the :meth:`~flash.core.data.data_module.DataModule.from_datasets` classmethod like this:

.. code-block:: python

    from flash import Trainer
    from flash.core.data.data_module import DataModule

    data_module = DataModule.from_datasets(train_dataset=MyDataset())
    trainer = Trainer()
    trainer.fit(model, data_module=data_module)
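Under the hood, ``from_datasets`` essentially stores the given datasets and exposes a dataloader per stage. The following dependency-free sketch illustrates the pattern (``ToyDataModule`` and its list-based batching are invented for the demo; the real :class:`~flash.core.data.data_module.DataModule` wraps ``torch.utils.data.DataLoader`` objects):

```python
# Toy illustration of the from_datasets pattern: a classmethod constructor
# that stores datasets and yields simple fixed-size batches per stage.
class ToyDataModule:
    def __init__(self, train_dataset=None, val_dataset=None):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

    @classmethod
    def from_datasets(cls, train_dataset=None, val_dataset=None):
        return cls(train_dataset=train_dataset, val_dataset=val_dataset)

    def train_dataloader(self, batch_size=2):
        # Chunk the dataset into batches of ``batch_size`` items.
        data = list(self.train_dataset)
        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


dm = ToyDataModule.from_datasets(train_dataset=range(4))
```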
.. note::

    The :class:`~flash.core.data.data_module.DataModule` provides additional ``classmethod`` helpers (``from_*``) for loading data from various sources.
    In each ``from_*`` method, the :class:`~flash.core.data.data_module.DataModule` internally retrieves the correct :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`.
    Flash :class:`~flash.core.data.auto_dataset.AutoDataset` instances are created from the :class:`~flash.core.data.data_source.DataSource` for train, val, test, and predict.
    The :class:`~flash.core.data.data_module.DataModule` populates the ``DataLoader`` for each stage with the corresponding :class:`~flash.core.data.auto_dataset.AutoDataset`.

The :class:`~flash.core.data.process.Preprocess` contains the processing logic related to a given task.
Each :class:`~flash.core.data.process.Preprocess` provides some default transforms through the :meth:`~flash.core.data.process.Preprocess.default_transforms` method.
Users can easily override these by providing their own transforms to the :class:`~flash.core.data.data_module.DataModule`.
Here's an example:

.. code-block:: python

    from flash.core.data.transforms import ApplyToKeys
    from flash.image import ImageClassificationData, ImageClassifier

    transform = {
        "to_tensor_transform": ApplyToKeys("input", my_to_tensor_transform),
    }

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
        train_transform=transform,
        val_transform=transform,
        test_transform=transform,
    )
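``ApplyToKeys`` wraps a transform so that it only touches the selected key(s) of the sample dictionary, leaving the rest (e.g. the target) untouched. Here is a minimal stdlib sketch of that behaviour (the real class lives in ``flash.core.data.transforms``; this toy version is only an approximation of its semantics):

```python
# Sketch of ApplyToKeys semantics: apply ``transform`` to the given key(s)
# of a sample dictionary and return the updated sample.
class ApplyToKeys:
    def __init__(self, keys, transform):
        self.keys = [keys] if isinstance(keys, str) else list(keys)
        self.transform = transform

    def __call__(self, sample):
        out = dict(sample)  # do not mutate the caller's dict
        for key in self.keys:
            out[key] = self.transform(out[key])
        return out


double_input = ApplyToKeys("input", lambda x: x * 2)
sample = double_input({"input": 3, "target": 1})
```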
Alternatively, the user may directly override the hooks for their needs like this:

.. code-block:: python

    from typing import Any, Dict

    from flash.image import ImageClassificationData, ImageClassifier, ImageClassificationPreprocess


    class CustomImageClassificationPreprocess(ImageClassificationPreprocess):

        def to_tensor_transform(self, sample: Dict[str, Any]) -> Dict[str, Any]:
            sample["input"] = my_to_tensor_transform(sample["input"])
            return sample


    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
        preprocess=CustomImageClassificationPreprocess(),
    )

Custom Preprocess + Datamodule
******************************

The example below shows a very simple ``ImageClassificationPreprocess`` with a single ``ImageClassificationFoldersDataSource`` and an ``ImageClassificationDataModule``.

1. User-Facing API design
_________________________

Designing an easy-to-use API is key. This is the first and most important step.

We want the ``ImageClassificationDataModule`` to generate a dataset from folders of images arranged with one sub-folder per class.

Example::

    preprocess = ...

    dm = ImageClassificationDataModule.from_folders(
        train_folder="./data/train",
        val_folder="./data/val",
        test_folder="./data/test",
        predict_folder="./data/predict",
        preprocess=preprocess,
    )

    model = ImageClassifier(...)
    trainer = Trainer(...)

    trainer.fit(model, dm)

2. The DataSource
_________________

We start by implementing the ``ImageClassificationFoldersDataSource``.
The ``load_data`` method will produce a list of files and targets from the given directory.
The ``load_sample`` method will load the given file as a ``PIL.Image``.
Here's the full ``ImageClassificationFoldersDataSource``:

.. code-block:: python

    import os
    from typing import Any, Dict, Iterable

    import numpy as np
    from PIL import Image
    from torchvision.datasets.folder import make_dataset

    from flash.core.data.data_source import DataSource, DefaultDataKeys


    class ImageClassificationFoldersDataSource(DataSource):

        def load_data(self, folder: str, dataset: Any) -> Iterable:
            # The dataset argument is optional but can be useful to save some metadata.

            # metadata contains the image path and its corresponding label
            # with the following structure:
            # [(image_path_1, label_1), ... (image_path_n, label_n)].
            metadata = make_dataset(folder)

            # For the train ``AutoDataset``, we want to store the ``num_classes``.
            if self.training:
                dataset.num_classes = len(np.unique([m[1] for m in metadata]))

            return [{DefaultDataKeys.INPUT: file, DefaultDataKeys.TARGET: target} for file, target in metadata]

        def predict_load_data(self, predict_folder: str) -> Iterable:
            # This returns [image_path_1, ..., image_path_m].
            return [{DefaultDataKeys.INPUT: file} for file in os.listdir(predict_folder)]

        def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
            sample[DefaultDataKeys.INPUT] = Image.open(sample[DefaultDataKeys.INPUT])
            return sample
.. note::

    We return samples as dictionaries using the :class:`~flash.core.data.data_source.DefaultDataKeys` by convention. This is the recommended (although not required) way to represent data in Flash.

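The ``load_data`` / ``load_sample`` split above generalizes beyond images: ``load_data`` builds a cheap list of sample descriptions once, while ``load_sample`` does the per-item work lazily. The following dependency-free sketch demonstrates the same pattern on text files (``ToyFoldersDataSource`` and the plain string keys are invented for the demo, standing in for Flash's ``DefaultDataKeys``):

```python
# Dependency-free sketch of the DataSource pattern: load_data enumerates
# (path, target) descriptions; load_sample reads one file lazily.
import os
import tempfile


class ToyFoldersDataSource:
    def load_data(self, folder):
        samples = []
        # One sub-folder per class; the class index is the target.
        for target, class_name in enumerate(sorted(os.listdir(folder))):
            class_dir = os.path.join(folder, class_name)
            for file_name in sorted(os.listdir(class_dir)):
                samples.append({"input": os.path.join(class_dir, file_name), "target": target})
        return samples

    def load_sample(self, sample):
        # Replace the path with the loaded content (images would use PIL here).
        with open(sample["input"]) as f:
            sample["input"] = f.read()
        return sample


root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "ants"))
with open(os.path.join(root, "ants", "a.txt"), "w") as f:
    f.write("ant-pixels")

source = ToyFoldersDataSource()
records = source.load_data(root)
first = source.load_sample(records[0])
```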
3. The Preprocess
__________________

Next, implement your custom ``ImageClassificationPreprocess`` with some default transforms and a reference to the data source:

.. code-block:: python

    from typing import Any, Callable, Dict, Optional

    import torchvision.transforms.functional as T

    from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
    from flash.core.data.process import Preprocess
    from flash.core.data.transforms import ApplyToKeys


    # Subclass ``Preprocess``
    class ImageClassificationPreprocess(Preprocess):

        def __init__(
            self,
            train_transform: Optional[Dict[str, Callable]] = None,
            val_transform: Optional[Dict[str, Callable]] = None,
            test_transform: Optional[Dict[str, Callable]] = None,
            predict_transform: Optional[Dict[str, Callable]] = None,
        ):
            super().__init__(
                train_transform=train_transform,
                val_transform=val_transform,
                test_transform=test_transform,
                predict_transform=predict_transform,
                data_sources={
                    DefaultDataSources.FOLDERS: ImageClassificationFoldersDataSource(),
                },
                default_data_source=DefaultDataSources.FOLDERS,
            )

        def get_state_dict(self) -> Dict[str, Any]:
            ...

        @classmethod
        def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
            return cls(**state_dict)

        def default_transforms(self) -> Dict[str, Callable]:
            return {
                "to_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.to_tensor),
            }

4. The DataModule
_________________

Finally, let's implement the ``ImageClassificationDataModule``.
We get the ``from_folders`` classmethod for free as we've registered a ``DefaultDataSources.FOLDERS`` data source in our ``ImageClassificationPreprocess``.
All we need to do is attach our :class:`~flash.core.data.process.Preprocess` class like this:

.. code-block:: python

    from flash.core.data.data_module import DataModule


    class ImageClassificationDataModule(DataModule):

        # Set ``preprocess_cls`` with your custom ``Preprocess``.
        preprocess_cls = ImageClassificationPreprocess

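To see why setting ``preprocess_cls`` is enough, here is a toy sketch of the resolution path: the ``from_*`` helper instantiates the preprocess, asks it for the data source registered under the matching name, and uses it to build the dataset. All class, method, and string names below are illustrative, not Flash's API:

```python
# Toy resolution path: DataModule.from_folders -> preprocess_cls ->
# registered "folders" data source -> dataset.
class ToyPreprocess:
    def __init__(self):
        # Map data-source names to loaders (here a trivial lambda).
        self._data_sources = {
            "folders": lambda folder: [f"{folder}/sample_{i}" for i in range(2)],
        }

    def data_source_of_name(self, name):
        return self._data_sources[name]


class ToyDataModule:
    preprocess_cls = ToyPreprocess  # the only thing a subclass must set

    @classmethod
    def from_folders(cls, train_folder):
        preprocess = cls.preprocess_cls()
        load_data = preprocess.data_source_of_name("folders")
        dm = cls()
        dm.train_dataset = load_data(train_folder)
        return dm


dm = ToyDataModule.from_folders("data/train")
```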
*************
API reference
*************

.. autoclass:: flash.core.data.data_module.DataModule
    :members:
        train_dataset,
        val_dataset,
        test_dataset,
        predict_dataset,
        configure_data_fetcher,
        show_train_batch,
        show_val_batch,
        show_test_batch,
        show_predict_batch,
        available_data_sources,
    :exclude-members:
        autogenerate_dataset,


******************************
How it works behind the scenes
******************************

DataSource
__________

.. note:: The ``load_data`` and ``load_sample`` hooks are used to generate an ``AutoDataset`` object.

Here is the ``AutoDataset`` pseudo-code.

Example::

    from typing import Any, List

    from pytorch_lightning.trainer.states import RunningStage


    class AutoDataset:

        def __init__(
            self,
            data: List[Any],  # The result of a call to DataSource.load_data
            data_source: DataSource,
            running_stage: RunningStage,
        ) -> None:
            self.data = data
            self.data_source = data_source

        def __getitem__(self, index):
            return self.data_source.load_sample(self.data[index])

        def __len__(self):
            return len(self.data)
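The pseudo-code above can be turned into a runnable toy by stubbing out the data source (``StubDataSource`` and its fake "decoding" are invented for the demo; no Flash or Lightning imports are needed):

```python
# Runnable toy version of the AutoDataset pseudo-code: load_data is called
# once up front; load_sample runs lazily on each __getitem__.
class StubDataSource:
    def load_data(self, folder):
        return [(f"{folder}/img_{i}.png", i % 2) for i in range(4)]

    def load_sample(self, sample):
        path, target = sample
        return {"input": f"decoded({path})", "target": target}


class AutoDataset:
    def __init__(self, data, data_source):
        self.data = data              # the result of DataSource.load_data
        self.data_source = data_source

    def __getitem__(self, index):
        return self.data_source.load_sample(self.data[index])

    def __len__(self):
        return len(self.data)


source = StubDataSource()
dataset = AutoDataset(source.load_data("data/train"), source)
```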

Preprocess
__________

