Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[doc] Add DataPipeline + Callbacks + Registry #207

Merged
merged 29 commits into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,20 @@ override the ``__init__`` and ``forward`` methods.

.. testcode::

from flash.data.process import Postprocess
from flash.data.process import Postprocess

class LinearRegressionPostprocess(Postprocess):
    """Postprocess that prints every prediction with a ``disease progression`` label."""

    def per_sample_transform(self, samples):
        """Print each entry of ``samples`` and return ``samples`` unchanged."""
        for sample in samples:
            print(f'disease progression: {sample}')
        return samples

class LinearRegression(flash.Task):

postprocess_cls = LinearRegressionPostprocess

def __init__(self, num_inputs, learning_rate=0.001, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)
Expand Down Expand Up @@ -70,7 +83,7 @@ dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-data

You’ll notice we added a ``DataPipeline``, which will be used when we
call ``.predict()`` on our model. In this case we want to nicely format
our ouput from the model with the string ``"disease progression"``, but
our output from the model with the string ``"disease progression"``, but
you could do any sort of post processing you want (see :ref:`datapipeline`).

Fit
Expand Down
41 changes: 41 additions & 0 deletions docs/source/general/callback.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
########
Callback
########

.. _callback:

**************
Flash Callback
**************

``FlashCallback`` is an extension of the PyTorch Lightning :class:`~pytorch_lightning.callbacks.Callback`.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved

tchaton marked this conversation as resolved.
Show resolved Hide resolved
*******************
Available Callbacks
*******************


BaseDataFetcher
_______________

.. autoclass:: flash.data.callback.BaseDataFetcher
:members: enable

BaseViz
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know its long but any reason not to be explicit and call it visualization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done !

_______

.. autoclass:: flash.data.base_viz.BaseViz
:members:


*************
API reference
*************


FlashCallback
_____________

.. autoclass:: flash.data.callback.FlashCallback
:members:
129 changes: 125 additions & 4 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,132 @@
Data
####

.. _data:

*******************************
Using DataModule + DataPipeline
*******************************

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing context.

What are these objects? Why do they exist? when do I need to use them?

I think we can do something like

  1. How to use out-of-the-box flashdatamodules
  2. How to customize existing datamodules
  3. How to build a datamodule for a new task
  4. What are data pipelines and when are they required
  5. how to use them

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Definitely. I will work on this next week :)

In this section, we will create a very simple ``ImageClassificationPreprocess`` with an ``ImageClassificationDataModule``.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

Example::

import os
import numpy as np
from flash.data.data_module import DataModule
from flash.data.process import Preprocess
from PIL.Image import Image
import torchvision.transforms as T
from torch import Tensor

# Subclass ``Preprocess``

class ImageClassificationPreprocess(Preprocess):
    """Example ``Preprocess`` that loads images from folders and converts them to tensors."""

    to_tensor = T.ToTensor()

    def load_data(self, folder: str, dataset: AutoDataset) -> Iterable:
        """Build the list of ``(image_path, label)`` pairs found under ``folder``."""
        # The AutoDataset is optional but can be useful to save some metadata.

        # metadata looks like this: [(image_path_1, label_1), ... (image_path_n, label_n)].
        metadata = make_dataset_from_folder(folder)

        # for the train ``AutoDataset``, we want to store the ``num_classes``.
        if self.training:
            dataset.num_classes = len(np.unique([m[1] for m in metadata]))

        return metadata

    def predict_load_data(self, predict_folder: str) -> Iterable:
        """Return the paths of the entries found in ``predict_folder``."""
        # This returns [image_path_1, ... image_path_m].
        # ``os.listdir`` only yields entry names, so join each one with the folder.
        return [
            os.path.join(predict_folder, name)
            for name in os.listdir(predict_folder)
        ]

    def load_sample(
        self,
        sample: Union[str, Tuple[str, int]]
    ) -> Union[Image, Tuple[Image, int]]:
        """Load the image of one sample; outside of prediction the label is kept."""
        if self.predicting:
            # When predicting, the sample is the image path itself (no label).
            return load_pil(sample)
        else:
            image_path, label = sample
            return load_pil(image_path), label

    def to_tensor_transform(
        self,
        sample: Union[Image, Tuple[Image, int]]
    ) -> Union[Tensor, Tuple[Tensor, int]]:
        """Convert the image of one sample to a tensor, passing the label through."""
        if self.predicting:
            return self.to_tensor(sample)
        else:
            return self.to_tensor(sample[0]), sample[1]

class ImageClassificationDataModule(DataModule):
    """Example ``DataModule`` that builds its datasets from one folder per stage."""

    # Set ``preprocess_cls`` with your custom ``preprocess``.
    preprocess_cls = ImageClassificationPreprocess

    @classmethod
    def from_folders(
        cls,
        train_folder: Optional[str] = None,
        val_folder: Optional[str] = None,
        test_folder: Optional[str] = None,
        predict_folder: Optional[str] = None,
        **kwargs
    ):
        """Create the data module from per-stage folders.

        Each folder is forwarded unchanged as the matching
        ``{stage}_load_data_input``; extra keyword arguments are passed on to
        ``from_load_data_inputs``.
        """
        # NOTE(review): the ``None`` defaults assume ``from_load_data_inputs``
        # accepts ``None`` for stages that are not used - confirm.
        preprocess = cls.preprocess_cls()

        # {stage}_load_data_input will be given to your
        # ``Preprocess`` ``{stage}_load_data`` function.
        return cls.from_load_data_inputs(
            train_load_data_input=train_folder,
            val_load_data_input=val_folder,
            test_load_data_input=test_folder,
            predict_load_data_input=predict_folder,
            preprocess=preprocess,  # DON'T FORGET TO PASS THE CREATED PREPROCESS
            **kwargs,
        )

dm = ImageClassificationDataModule.from_folders("./data/train", "./data/val", "./data/test", "./data/predict")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

model = ImageClassifier(...)
trainer = Trainer(...)

trainer.fit(model, dm)



*************
API reference
*************

.. _preprocess:

Preprocess
__________

.. autoclass:: flash.data.process.Preprocess
:members:


----------

.. _postprocess:

Postprocess
___________


.. autoclass:: flash.data.process.Postprocess
:members:


----------

.. _datapipeline:

DataPipeline
------------
____________

To make tasks work for inference, one must create a ``Preprocess`` and ``Postprocess``.
The ``flash.data.process.Preprocess`` exposes 9 hooks to override, which can be specialized for each stage using
``train``, ``val``, ``test``, ``predict`` prefixes.
.. autoclass:: flash.data.data_pipeline.DataPipeline
:members:
69 changes: 69 additions & 0 deletions docs/source/general/registry.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
########
Registry
########

.. _registry:

********************
Available Registries
********************

Registries are mappings from a name and metadata to a function.
They help organize code and make the functions accessible all across the ``Flash`` codebase.

Example::

from flash.vision.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES

@IMAGE_CLASSIFIER_BACKBONES(name="username/my_backbone"):
def fn(args_1, ... args_n):
return backbone, num_features

_fn = IMAGE_CLASSIFIER_BACKBONES.get("username/my_backbone")
backbone, num_features = _fn(args_1, ..., args_n)


Each Flash ``Task`` can have several registries as static attributes.

Example::

from flash.vision import ImageClassifier
from flash.core.registry import FlashRegistry

class MyImageClassifier(ImageClassifier):

# set the registry as a static attribute
backbones = FlashRegistry("backbones")

# Option 1: Used with partial.
def fn():
# Create backbone and backbone output dimension (`num_features`)
return backbone, num_features

# HINT 1: Use `from functools import partial` if you want to store some arguments.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
MyImageClassifier.backbones(fn=fn, name="username/my_backbone")


# Option 2: Using decorator.
@MyImageClassifier.backbones(name="username/my_backbone")
def fn():
# Create backbone and backbone output dimension (`num_features`)
return backbone, num_features

# The new key should be listed in available backbones
print(MyImageClassifier.available_backbones())
# out: ["username/my_backbone"]

# Create a model with your backbone !
model = MyImageClassifier(backbone="username/my_backbone")

**************
Flash Registry
**************


FlashRegistry
_____________

.. autoclass:: flash.core.registry.FlashRegistry
:members:
3 changes: 3 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Lightning Flash

general/model
general/data
general/callback
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about creating an "Advanced" submenu here that will have all the API + creating a custom task tutorial?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we will do so in another PR. Definitely need some re-organizations.

general/registry


.. toctree::
:maxdepth: 1
Expand Down
9 changes: 1 addition & 8 deletions flash/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from flash import tabular, text, vision # noqa: E402
from flash.core.classification import ClassificationTask # noqa: E402
from flash.core.model import Task # noqa: E402
from flash.core.trainer import Trainer # noqa: E402
from flash.data.data_module import DataModule # noqa: E402
from flash.data.utils import download_data # noqa: E402

__all__ = [
"Task",
"ClassificationTask",
"Trainer",
"DataModule",
"vision",
"text",
"tabular",
"download_data",
]
21 changes: 9 additions & 12 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def predict(
x = self.transfer_batch_to_device(x, self.device)
x = data_pipeline.device_preprocessor(running_stage)(x)
predictions = self.predict_step(x, 0) # batch_idx is always 0 when running with `model.predict`
predictions = data_pipeline.postprocessor(predictions)
predictions = data_pipeline.postprocessor(running_stage)(predictions)
return predictions

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
Expand Down Expand Up @@ -176,7 +176,11 @@ def preprocess(self, preprocess: Preprocess) -> None:

@property
def postprocess(self) -> Postprocess:
return getattr(self._data_pipeline, '_postprocess_pipeline', None) or self._postprocess
_postprocess_pipeline = getattr(self._data_pipeline, '_postprocess_pipeline', None)
if type(_postprocess_pipeline) != Postprocess:
return _postprocess_pipeline
else:
return self._postprocess or self.postprocess_cls()

@postprocess.setter
def postprocess(self, postprocess: Postprocess) -> None:
Expand All @@ -185,10 +189,7 @@ def postprocess(self, postprocess: Postprocess) -> None:

@property
def data_pipeline(self) -> Optional[DataPipeline]:
if self._data_pipeline is not None:
return self._data_pipeline

elif self.preprocess is not None or self.postprocess is not None:
if self.preprocess is not None or self.postprocess is not None:
# use direct attributes here to avoid recursion with properties that also check the data_pipeline property
return DataPipeline(self.preprocess, self.postprocess)

Expand All @@ -205,12 +206,8 @@ def data_pipeline(self) -> Optional[DataPipeline]:
@data_pipeline.setter
def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
self._data_pipeline = data_pipeline
if data_pipeline is not None and getattr(data_pipeline, '_preprocess_pipeline', None) is not None:
self._preprocess = data_pipeline._preprocess_pipeline

if data_pipeline is not None and getattr(data_pipeline, '_postprocess_pipeline', None) is not None:
if type(data_pipeline._postprocess_pipeline) != Postprocess:
self._postprocess = data_pipeline._postprocess_pipeline
self._preprocess = self.preprocess
self._postprocess = self.postprocess

def on_train_dataloader(self) -> None:
if self.data_pipeline is not None:
Expand Down
Loading