Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[doc] Add DataPipeline + Callbacks + Registry #207

Merged
merged 29 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 114 additions & 24 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,21 @@ along with a custom data module.

.. testcode:: python

import flash
from typing import Any, List, Tuple

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from pytorch_lightning import seed_everything
from sklearn import datasets
from sklearn.model_selection import train_test_split
from torch import nn

import flash
from flash.data.auto_dataset import AutoDataset
from flash.data.process import Postprocess, Preprocess

seed_everything(42)


The Task: Linear regression
---------------------------
Expand All @@ -24,6 +32,7 @@ override the ``__init__`` and ``forward`` methods.
.. testcode::

class LinearRegression(flash.Task):

def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)
Expand Down Expand Up @@ -61,28 +70,81 @@ Lightning’s
The Data
--------

For a task you will likely need a specific way of loading data. For this
example, lets say we want a ``flash.DataModule`` to be used explicitly
for the prediction of diabetes disease progression. We can create this
``DataModule`` below, wrapping the scikit-learn `Diabetes
For a task you will likely need a specific way of loading data.

Firstly, it is recommended to create a :class:`~flash.data.process.Preprocess` object.
The :class:`~flash.data.process.Preprocess` contains all the processing logic and are similar to ``Callback``.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
The user would to override hooks with their processing logic.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. note::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not start with notes since it doesn't look good (esp not 2 notes). I would take is out of the note.

3.b ....
A :class:~flash.data.process.Preprocess object provides a series of hooks that can be overridden with custom data processing logic.
It allows the user much more granular control over their data processing flow.

The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead needed for inference on raw data or
to deploy the model in a production environment, compared to traditional
`Dataset <https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset>`_.

and then add the note

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done !

As new concepts are being introduced, we strongly encourage the reader to click on :class:`~flash.data.process.Preprocess`
before going further in the tutorial.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Secondly, the user would have to implement a ``DataModule``.

For this task, we will be using ``scikit-learn`` `Diabetes
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.

Example::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain in a sentence what are the hooks you need to override in this case and what logic did you add.


import torch
from torch import Tensor
import numpy as np

ND = np.ndarray

class NumpyPreprocess(Preprocess):

def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:
if self.training:
dataset.num_inputs = data[0].shape[1]
return [(x, y) for x, y in zip(*data)]

def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:
x, y = sample
x = torch.from_numpy(x).float()
y = torch.tensor(y, dtype=torch.float)
return x, y

def predict_load_data(self, data: ND) -> ND:
return data

def predict_to_tensor_transform(self, sample: ND) -> ND:
return torch.from_numpy(sample).float()


class SklearnDataModule(flash.DataModule):

preprocess_cls = NumpyPreprocess

You’ll notice we added a ``DataPipeline``, which will be used when we
call ``.predict()`` on our model. In this case we want to nicely format
our ouput from the model with the string ``"disease progression"``, but
you could do any sort of post processing you want (see :ref:`datapipeline`).
@classmethod
def from_dataset(cls, x: ND, y: ND, batch_size: int = 64, num_workers: int = 0):

Fit
---
preprocess = cls.preprocess_cls()

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)

dm = cls.from_load_data_inputs(
train_load_data_input=(x_train, y_train),
test_load_data_input=(x_test, y_test),
preprocess=preprocess,
batch_size=batch_size,
num_workers=num_workers
)
dm.num_inputs = dm._train_ds.num_inputs
return dm


Fitting
-------

Like any Flash Task, we can fit our model using the ``flash.Trainer`` by
supplying the task itself, and the associated data:

.. code:: python

data = DiabetesData()
model = LinearRegression(num_inputs=data.num_inputs)
datamodule = SklearnDataModule.from_dataset(*datasets.load_diabetes(return_X_y=True))
model = LinearRegression(num_inputs=datamodule.num_inputs)

trainer = flash.Trainer(max_epochs=1000)
trainer.fit(model, data)
Expand All @@ -99,15 +161,43 @@ few examples from the test set of our data:
[-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
[-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094]])

model.predict(predict_data)
predictions = model.predict(predict_data)
print(predictions)
#out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]


To customize the postprocessing of this task, you can create a :class:`~flash.data.process.Postprocess` objects and assign it to your model as follow:
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. code:: python

class CustomPostprocess(Postprocess):

THRESHOLD = 14.72

def predict_per_sample_transform(self, pred: Any) -> Any:
if pred > self.THRESHOLD:

def send_slack_message(pred):
print(f"This prediction: {pred} is above the threshold: {self.THRESHOLD}")

send_slack_message(pred)
return pred


class LinearRegression(flash.Task):

postprocess_cls = CustomPostprocess

...

And when running predict one more time.

.. code:: python

Because of our custom data pipeline’s ``after_uncollate`` method, we
will get a nicely formatted output like the following:
predict_data = ...

.. code::
predictions = model.predict(predict_data)
# out: This prediction: tensor([14.7288]) is above the threshold: 14.72

['disease progression: 155.90',
'disease progression: 156.59',
'disease progression: 152.69',
'disease progression: 149.05',
'disease progression: 150.90']
print(predictions)
# out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]
46 changes: 46 additions & 0 deletions docs/source/general/callback.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
########
Callback
########

.. _callback:

**************
Flash Callback
**************

:class:`~flash.data.callback.FlashCallback` are extensions of the PyTorch Lightning `Callback <https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html>`__.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved
A callback is a self-contained program that can be reused across projects.

Flash and Lightning have a callback system to execute callbacks when needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would skip this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am following Lightning docs there.


Callbacks should capture NON-ESSENTIAL logic that is NOT required for your lightning module to run.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved
*******************
Available Callbacks
*******************


BaseDataFetcher
_______________

.. autoclass:: flash.data.callback.BaseDataFetcher
:members: enable

BaseViz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know its long but any reason not to be explicit and call it visualization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done !

_______

.. autoclass:: flash.data.base_viz.BaseViz
:members:


*************
API reference
*************


FlashCallback
_____________

.. autoclass:: flash.data.callback.FlashCallback
:members:
129 changes: 125 additions & 4 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,132 @@
Data
####

.. _data:

*******************************
Using DataModule + DataPipeline
*******************************

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing context.

What are these objects? Why do they exist? when do I need to use them?

I think we can do something like

  1. How to use out-of-the-box flashdatamodules
  2. How to customize existing datamodules
  3. How to build a datamodule for a new task
  4. What are data pipelines and when are they required
  5. how to use them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Definitely. I will work on this next week :)

In this section, we will create a very simple ``ImageClassificationPreprocess`` with a ``ImageClassificationDataModule``.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Example::

import os
import numpy as np
from flash.data.data_module import DataModule
from flash.data.process import Preprocess
from PIL.Image import Image
import torchvision.transforms as T
from torch import Tensor

# Subclass ``Preprocess``

class ImageClassificationPreprocess(Preprocess):

to_tensor = T.ToTensor()

def load_data(self, folder: str, dataset: AutoDataset) -> Iterable:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
# The AutoDataset is optional but can be useful to save some metadata.

# metadata looks like this: [(image_path_1, label_1), ... (image_path_n, label_n)].
tchaton marked this conversation as resolved.
Show resolved Hide resolved
metadata = make_dataset_from_folder(folder)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

# for the train ``AutoDataset``, we want to store the ``num_classes``.
if self.training:
dataset.num_classes = len(np.unique([m[1] for m in metadata]))

return metadata

def predict_load_data(self, predict_folder: str) -> Iterable:
# This returns [image_path_1, ... image_path_m].
return os.listdir(folder)

def load_sample(self, sample: Union[str, Tuple[str, int]]) -> Tuple[Image, int]
if self.predicting:
return load_pil(image_path)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
else:
image_path, label = sample
return load_pil(image_path), label

def to_tensor_transform(
self,
sample: Union[Image, Tuple[Image, int]]
) -> Union[Tensor, Tuple[Tensor, int]]:

if self.predicting:
return self.to_tensor(sample)
else:
return self.to_tensor(sample[0]), sample[1]

class ImageClassificationDataModule(DataModule):

# Set ``preprocess_cls`` with your custom ``preprocess``.

preprocess_cls = ImageClassificationPreprocess

@classmethod
def from_folders(
cls,
train_folder: Optional[str],
val_folder: Optional[str],
test_folder: Optional[str],
predict_folder: Optional[str],
**kwargs
):

preprocess = cls.preprocess_cls()

# {stage}_load_data_input will be given to your
# ``Preprocess`` ``{stage}_load_data`` function.
return cls.from_load_data_inputs(
train_load_data_input=train_folder,
val_load_data_input=val_folder,
test_load_data_input=test_folder,
predict_load_data_input=predict_folder,
preprocess=preprocess, # DON'T FORGET TO PASS THE CREATED PREPROCESS
**kwargs,
)

dm = ImageClassificationDataModule.from_folders("./data/train", "./data/val", "./data/test", "./data/predict")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

model = ImageClassifier(...)
trainer = Trainer(...)

trainer.fit(model, dm)



*************
API reference
*************

.. _preprocess:

Preprocess
__________

.. autoclass:: flash.data.process.Preprocess
:members:


----------

.. _postprocess:

Postprocess
___________


.. autoclass:: flash.data.process.Postprocess
:members:


----------

.. _datapipeline:

DataPipeline
------------
____________

To make tasks work for inference, one must create a ``Preprocess`` and ``PostProcess``.
The ``flash.data.process.Preprocess`` exposes 9 hooks to override which can specifialzed for each stage using
``train``, ``val``, ``test``, ``predict`` prefixes.
.. autoclass:: flash.data.data_pipeline.DataPipeline
:members:
Loading