Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[doc] Add DataPipeline + Callbacks + Registry #207

Merged
merged 29 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 165 additions & 30 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,39 @@ Tutorial: Creating a Custom Task
In this tutorial we will go over the process of creating a custom task,
along with a custom data module.

1 . Imports
tchaton marked this conversation as resolved.
Show resolved Hide resolved
-----------


.. testcode:: python

import flash
from typing import Any, List, Tuple

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from torch import nn
from pytorch_lightning import seed_everything
from sklearn import datasets
from sklearn.model_selection import train_test_split
from torch import nn

import flash
from flash.data.auto_dataset import AutoDataset
from flash.data.process import Postprocess, Preprocess

seed_everything(42)

The Task: Linear regression
---------------------------

2 . The Task: Linear regression
tchaton marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think data should go before task.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Steps to create a task

  1. DataModule
  2. PreProcess
  3. Post Process
  4. Registry
  5. Task
  6. Example
  7. Test (This is important since people need to know how to test to merge)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done ! Feel free to re-organize this when you have time.

-------------------------------

Here we create a basic linear regression task by subclassing
tchaton marked this conversation as resolved.
Show resolved Hide resolved
``flash.Task``. For the majority of tasks, you will likely only need to
:class:`~flash.core.model.Task`. For the majority of tasks, you will likely only need to
override the ``__init__`` and ``forward`` methods.

.. testcode::

class LinearRegression(flash.Task):

def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)
Expand Down Expand Up @@ -58,35 +71,124 @@ testing) or override ``training_step``, ``validation_step``, and
Lightning’s
`methods <https://pytorch-lightning.readthedocs.io/en/latest/lightning_module.html#methods>`__.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

The Data
--------
3 . The Data
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved
-------------

For a task you will likely need a specific way of loading data.

It is recommended to create a :class:`~flash.data.process.Preprocess` object.
The :class:`~flash.data.process.Preprocess` contains all the processing logic and are similar to ``Callback``.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
The user has to override hooks with their processing logic.

.. note::
As new concepts are being introduced, we strongly encourage the reader to click on :class:`~flash.data.process.Preprocess`
before going further in the tutorial.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

The user would have to implement a :class:`~flash.data.data_module.DataModule` as a way to perform data checks and instantiate the preprocess.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. note::

Philosophically, the :class:`~flash.data.process.Preprocess` belongs with the :class:`~flash.data.data_module.DataModule`
tchaton marked this conversation as resolved.
Show resolved Hide resolved
and the :class:`~flash.data.process.Postprocess` with the :class:`~flash.core.model.Task`.


3.a The DataModule API
----------------------

First, let's design the user-facing API. The ``NumpyDataModule`` will provide a ``from_xy_dataset`` helper ``classmethod``.

Example::

x, y = ...
preprocess_cls = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess_cls)

Here are the `NumpyDataModule`` implementation:
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Example::

from flash import DataModule
from flash.data.process import Preprocess

class NumpyDataModule(DataModule):

@classmethod
def from_xy_dataset(cls, x: ND, y: ND, preprocess_cls: Preprocess = NumpyPreprocess, batch_size: int = 64, num_workers: int = 0):
tchaton marked this conversation as resolved.
Show resolved Hide resolved

preprocess = preprocess_cls()

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=.20, random_state=0)

dm = cls.from_load_data_inputs(
train_load_data_input=(x_train, y_train),
test_load_data_input=(x_test, y_test),
preprocess=preprocess, # DON'T FORGET TO PROVIDE THE PREPROCESS
batch_size=batch_size,
num_workers=num_workers
)
# Some metatada can be accessed from ``train_ds`` directly.
dm.num_inputs = dm.train_dataset.num_inputs
return dm

For a task you will likely need a specific way of loading data. For this
example, lets say we want a ``flash.DataModule`` to be used explicitly
for the prediction of diabetes disease progression. We can create this
``DataModule`` below, wrapping the scikit-learn `Diabetes
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.

.. note::

You’ll notice we added a ``DataPipeline``, which will be used when we
call ``.predict()`` on our model. In this case we want to nicely format
our ouput from the model with the string ``"disease progression"``, but
you could do any sort of post processing you want (see :ref:`datapipeline`).
The :class:`~flash.data.data_module.DataModule` provides a ``from_load_data_inputs`` helper function. This function will take care
of connecting the provided :class:`~flash.data.process.Preprocess` with the :class:`~flash.data.data_module.DataModule`.
Make sure to instantiate your :class:`~flash.data.data_module.DataModule` with this helper if you rely on :class:`~flash.data.process.Preprocess`
objects.

Fit
---
3.b The Preprocess API
----------------------

Example::
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Explain in a sentence what are the hooks you need to override in this case and what logic did you add.


import torch
from torch import Tensor
import numpy as np

ND = np.ndarray

class NumpyPreprocess(Preprocess):

def load_data(self, data: Tuple[ND, ND], dataset: AutoDataset) -> List[Tuple[ND, float]]:
if self.training:
dataset.num_inputs = data[0].shape[1]
return [(x, y) for x, y in zip(*data)]

def to_tensor_transform(self, sample: Any) -> Tuple[Tensor, Tensor]:
x, y = sample
x = torch.from_numpy(x).float()
y = torch.tensor(y, dtype=torch.float)
return x, y

def predict_load_data(self, data: ND) -> ND:
return data

def predict_to_tensor_transform(self, sample: ND) -> ND:
return torch.from_numpy(sample).float()

4. Fitting
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me this is no longer a part of the guide, since this is true for out-of-the-box tasks as well.

I think this should be:

  1. Imports
  2. Create a custom task (explain what needs to be overridden)
  3. Create a custom datamodule
  4. create a preprocess class
  5. Optional (anything else you may need to customize?)

and then add something like:
You now have a new customized Flash Task! As for any flash task, you can now use it for fitting on a dataset and creating predictions.

add code snippet...

----------

For this task, we will be using ``scikit-learn`` `Diabetes
tchaton marked this conversation as resolved.
Show resolved Hide resolved
dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.

Like any Flash Task, we can fit our model using the ``flash.Trainer`` by
supplying the task itself, and the associated data:

.. code:: python

data = DiabetesData()
model = LinearRegression(num_inputs=data.num_inputs)
x, y = datasets.load_diabetes(return_X_y=True)
datamodule = NumpyDataModule.from_xy_dataset(x, y)
model = LinearRegression(num_inputs=datamodule.num_inputs)

trainer = flash.Trainer(max_epochs=1000)
trainer.fit(model, data)

5. Predicting
-------------

With a trained model we can now perform inference. Here we will use a
few examples from the test set of our data:

Expand All @@ -99,15 +201,48 @@ few examples from the test set of our data:
[-0.0128, -0.0446, -0.0235, -0.0401, -0.0167, 0.0046, -0.0176, -0.0026, -0.0385, -0.0384],
[-0.0237, -0.0446, 0.0455, 0.0907, -0.0181, -0.0354, 0.0707, -0.0395, -0.0345, -0.0094]])

model.predict(predict_data)
predictions = model.predict(predict_data)
print(predictions)
#out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]


6. Customize PostProcess
------------------------

To customize the postprocessing of this task, you can create a :class:`~flash.data.process.Postprocess` objects and assign it to your model as follows:
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. code:: python

class CustomPostprocess(Postprocess):

THRESHOLD = 14.72

def predict_per_sample_transform(self, pred: Any) -> Any:
if pred > self.THRESHOLD:

def send_slack_message(pred):
print(f"This prediction: {pred} is above the threshold: {self.THRESHOLD}")

send_slack_message(pred)
return pred


class LinearRegression(flash.Task):

# ``postprocess_cls`` is a special attribute name used internally
# to instantiate your Postprocess.
postprocess_cls = CustomPostprocess

...

And when running predict one more time.

.. code:: python

Because of our custom data pipeline’s ``after_uncollate`` method, we
will get a nicely formatted output like the following:
predict_data = ...

.. code::
predictions = model.predict(predict_data)
# out: This prediction: tensor([14.7288]) is above the threshold: 14.72

['disease progression: 155.90',
'disease progression: 156.59',
'disease progression: 152.69',
'disease progression: 149.05',
'disease progression: 150.90']
print(predictions)
# out: [tensor([14.7190]), tensor([14.7100]), tensor([14.7288]), tensor([14.6685]), tensor([14.6687])]
46 changes: 46 additions & 0 deletions docs/source/general/callback.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
########
Callback
########

.. _callback:

**************
Flash Callback
**************

:class:`~flash.data.callback.FlashCallback` are extensions of the PyTorch Lightning `Callback <https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html>`__.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved
A callback is a self-contained program that can be reused across projects.

Flash and Lightning have a callback system to execute callbacks when needed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would skip this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am following Lightning docs there.


Callbacks should capture NON-ESSENTIAL logic that is NOT required for your lightning module to run.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved
*******************
Available Callbacks
*******************


BaseDataFetcher
_______________

.. autoclass:: flash.data.callback.BaseDataFetcher
:members: enable

BaseViz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know its long but any reason not to be explicit and call it visualization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done !

_______

.. autoclass:: flash.data.base_viz.BaseViz
:members:


*************
API reference
*************


FlashCallback
_____________

.. autoclass:: flash.data.callback.FlashCallback
:members:
Loading