This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Adds a template task and docs (#306)
* Initial commit

* Updates

* Updates

* Updates

* Remove template README

* Fixes

* Updates

* Add examples

* Updates

* Updates

* Updates

* Updates

* Add tests

* Updates

* Fixes

* A fix

* Fixes

* More tests

* Updates

* Fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update docs/source/reference/template.rst

Co-authored-by: Edgar Riba <[email protected]>

* Respond to comments

* updates

* Update docs/source/template/data.rst

Co-authored-by: edenlightning <[email protected]>

* Update docs/source/template/data.rst

Co-authored-by: edenlightning <[email protected]>

* Update docs/source/template/data.rst

Co-authored-by: edenlightning <[email protected]>

* Update docs/source/template/model.rst

Co-authored-by: edenlightning <[email protected]>

* Updates

* Updates

* Fixes

* Updates

* Updates

* Updates

* Fixes

* Fixes

* Fix

* Add backbones

* Add backbones

* Updates

* Updates

* Updates

* Fixes

* Add links

* Fixes

* Simplify

* Update CHANGELOG.md

* Update docs/source/template/optional.rst

Co-authored-by: edenlightning <[email protected]>

* Update docs/source/template/optional.rst

Co-authored-by: edenlightning <[email protected]>

* Update docs/source/template/task.rst

Co-authored-by: edenlightning <[email protected]>

* Updates

* Updates

* Updates

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edgar Riba <[email protected]>
Co-authored-by: edenlightning <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
5 people authored May 19, 2021
1 parent 65c658b commit cb7f906
Showing 25 changed files with 1,468 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Semantic Segmentation task ([#239](https://github.com/PyTorchLightning/lightning-flash/pull/239) [#287](https://github.com/PyTorchLightning/lightning-flash/pull/287) [#290](https://github.com/PyTorchLightning/lightning-flash/pull/290))
- Added Object detection prediction example ([#283](https://github.com/PyTorchLightning/lightning-flash/pull/283))
- Added Style Transfer task and accompanying finetuning and prediction examples ([#262](https://github.com/PyTorchLightning/lightning-flash/pull/262))
- Added a Template task and tutorials showing how to contribute a task to flash ([#306](https://github.com/PyTorchLightning/lightning-flash/pull/306))

### Changed

1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -90,6 +90,7 @@ def _load_py_module(fname, pkg="flash"):
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"pytorchvideo": ("https://pytorchvideo.readthedocs.io/en/latest/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}

1 change: 1 addition & 0 deletions docs/source/custom_task.rst
@@ -4,6 +4,7 @@ Tutorial: Creating a Custom Task
In this tutorial we will go over the process of creating a custom :class:`~flash.core.model.Task`,
along with a custom :class:`~flash.core.data.data_module.DataModule`.

.. note:: This tutorial is only intended to help you create a small custom task for a personal project. If you want a more detailed guide, have a look at our :ref:`guide on contributing a task to flash. <contributing>`

The tutorial objective is to create a ``RegressionTask`` to learn to predict if someone has ``diabetes`` or not.
We will use ``scikit-learn`` `Diabetes dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-dataset>`__.
19 changes: 19 additions & 0 deletions docs/source/index.rst
@@ -50,6 +50,25 @@ Lightning Flash
    general/finetuning
    general/predictions


.. toctree::
    :maxdepth: 1
    :caption: Contributing a Task

    template/intro
    template/data
    template/backbones
    template/task
    template/optional
    template/examples
    template/tests
    template/docs

.. toctree::
    :hidden:

    reference/template

Indices and tables
==================

75 changes: 75 additions & 0 deletions docs/source/reference/template.rst
@@ -0,0 +1,75 @@

.. _template:

########
Template
########

********
The task
********

Here you should add a description of your task. For example:
Classification is the task of assigning one of a number of classes to each data point.
The :class:`~flash.template.TemplateSKLearnClassifier` is a :class:`~flash.core.model.Task` for classifying the datasets included with scikit-learn.

------

*********
Inference
*********

Here, you should add a short intro to your predict example, and then use ``literalinclude`` to add it.

.. note:: We skip the first 13 lines (``:lines: 14-`` starts the include at line 14) as they are just the copyright notice.

Our predict example uses a model pre-trained on the Iris data.

.. literalinclude:: ../../../flash_examples/predict/template.py
    :language: python
    :lines: 14-

For more advanced inference options, see :ref:`predictions`.

------

********
Training
********

In this section, we briefly describe the data, and then ``literalinclude`` our finetuning example.

Now we'll train on Fisher's classic Iris data.
It contains 150 records with four features (sepal length, sepal width, petal length, and petal width) in three classes (species of Iris: setosa, virginica and versicolor).

Now all we need is to train our task!

.. literalinclude:: ../../../flash_examples/finetuning/template.py
    :language: python
    :lines: 14-

------

*************
API reference
*************

We usually include the API reference for the :class:`~flash.core.model.Task` and :class:`~flash.core.data.data_module.DataModule`.
You can optionally add the other classes you've implemented.
To add the API reference, use the ``autoclass`` directive.

.. _template_classifier:

TemplateSKLearnClassifier
-------------------------

.. autoclass:: flash.template.TemplateSKLearnClassifier
    :members:
    :exclude-members: forward

.. _template_data:

TemplateData
------------

.. autoclass:: flash.template.TemplateData
43 changes: 43 additions & 0 deletions docs/source/template/backbones.rst
@@ -0,0 +1,43 @@
.. _contributing_backbones:

*************
The Backbones
*************

Now that you've got a way of loading data, you should implement some backbones to use with your :class:`~flash.core.model.Task`.
Create a :any:`FlashRegistry <registry>` to use with your :class:`~flash.core.model.Task` in `backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/backbones.py>`_.

The registry allows you to register backbones for your task that can be selected by the user.
The backbones can come from anywhere as long as you can register a function that loads the backbone.
Furthermore, the user can add their own models to the existing backbones, without having to write their own :class:`~flash.core.model.Task`!

You can create a registry like this:

.. code-block:: python

    TEMPLATE_BACKBONES = FlashRegistry("backbones")

Let's add a simple MLP backbone to our registry.
We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our :class:`~flash.core.model.Task`).
You can use any name for the function, although we use ``load_{model name}`` by convention.
Here's the code:

.. literalinclude:: ../../../flash/template/classification/backbones.py
    :language: python
    :pyobject: load_mlp_128

Here's another example with a slightly more complex model:

.. literalinclude:: ../../../flash/template/classification/backbones.py
    :language: python
    :pyobject: load_mlp_128_256

Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:

.. literalinclude:: ../../../flash/image/backbones.py
    :language: python
    :pyobject: load_simclr_imagenet
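
As a rough, Flash-free illustration of the registry idea described above, the sketch below maps names to loader functions that return a backbone together with its output size.
``SketchRegistry``, its ``register`` decorator, and the placeholder ``load_mlp_128`` are hypothetical stand-ins written for this example; they are not the ``FlashRegistry`` API.

```python
from typing import Any, Callable, Dict, List, Tuple


class SketchRegistry:
    """Hypothetical name-to-loader registry; not Flash's ``FlashRegistry``."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._loaders: Dict[str, Callable[..., Any]] = {}

    def register(self, name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        # Used as a decorator: stores the loader under ``name`` and returns it unchanged.
        def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
            self._loaders[name] = fn
            return fn

        return decorator

    def get(self, name: str) -> Callable[..., Any]:
        return self._loaders[name]

    def available_keys(self) -> List[str]:
        return sorted(self._loaders)


BACKBONES = SketchRegistry("backbones")


@BACKBONES.register("mlp-128")
def load_mlp_128(num_features: int) -> Tuple[Callable[[List[float]], List[float]], int]:
    # A real loader would build a network; this placeholder just emits 128 values.
    def backbone(x: List[float]) -> List[float]:
        assert len(x) == num_features
        return [sum(x)] * 128

    # Return the backbone along with its output size, so a head can be built on top.
    return backbone, 128


backbone, out_features = BACKBONES.get("mlp-128")(num_features=4)
```

Because the task only ever asks the registry for a name, a user could register an additional loader under a new name without touching the task itself.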

------

Once you've got some data and some backbones, :ref:`implement your task! <contributing_task>`
212 changes: 212 additions & 0 deletions docs/source/template/data.rst
@@ -0,0 +1,212 @@
.. _contributing_data:

********
The Data
********

The first step to contributing a task is to implement the classes we need to load some data.
Inside `data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/data.py>`_ you should implement:

#. some :class:`~flash.core.data.data_source.DataSource` classes *(optional)*
#. a :class:`~flash.core.data.process.Preprocess`
#. a :class:`~flash.core.data.data_module.DataModule`
#. a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)*
#. a :class:`~flash.core.data.process.Postprocess` *(optional)*

DataSource
^^^^^^^^^^

The :class:`~flash.core.data.data_source.DataSource` class contains the logic for data loading from different sources such as folders, files, tensors, etc.
Every Flash :class:`~flash.core.data.data_module.DataModule` can be instantiated with :meth:`~flash.core.data.data_module.DataModule.from_datasets`.
For each additional way you want the user to be able to instantiate your :class:`~flash.core.data.data_module.DataModule`, you'll need to create a :class:`~flash.core.data.data_source.DataSource`.
Each :class:`~flash.core.data.data_source.DataSource` has 2 methods:

- :meth:`~flash.core.data.data_source.DataSource.load_data` takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.
- :meth:`~flash.core.data.data_source.DataSource.load_sample` then takes as input a single element from the output of ``load_data`` and returns a sample.

By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.data_source.DataSource.load_data` and a :meth:`~flash.core.data.data_source.DataSource.load_sample` to create a :class:`~flash.core.data.data_source.DataSource`.
Where possible, you should extend one of our existing :class:`~flash.core.data.data_source.DataSource` classes.
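
The two-step contract described above can be sketched in plain Python.
``SketchDataSource`` and ``SquaresDataSource`` are hypothetical classes invented for this illustration (they do not import or reproduce the Flash classes); both methods default to returning their input, so a subclass only needs to override the step it cares about.

```python
from typing import Any, Dict, Iterable, List


class SketchDataSource:
    """Hypothetical stand-in for a data source; not the Flash class."""

    def load_data(self, data: Any) -> Iterable[Any]:
        # Default: the metadata already is the sequence of samples.
        return data

    def load_sample(self, sample: Any) -> Any:
        # Default: each element already is a sample.
        return sample

    def samples(self, data: Any) -> List[Any]:
        # Run the two steps in order: list the elements, then load each one.
        return [self.load_sample(item) for item in self.load_data(data)]


class SquaresDataSource(SketchDataSource):
    """Overrides only ``load_sample``; ``load_data`` keeps its default."""

    def load_sample(self, sample: int) -> Dict[str, int]:
        return {"input": sample, "target": sample * sample}


loaded = SquaresDataSource().samples([1, 2, 3])
```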

Let's start by implementing a ``TemplateNumpyDataSource``, which extends :class:`~flash.core.data.data_source.NumpyDataSource`.
The main :class:`~flash.core.data.data_source.DataSource` method that we have to implement is :meth:`~flash.core.data.data_source.DataSource.load_data`.
As we're extending the ``NumpyDataSource``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays).

We can also take the ``dataset`` argument.
Any attributes we set on ``dataset`` will be available on the :class:`~torch.utils.data.Dataset` generated by our :class:`~flash.core.data.data_source.DataSource`.
In this data source, we'll set the ``num_features`` attribute.

Here's the code for our ``TemplateNumpyDataSource.load_data`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateNumpyDataSource.load_data
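
For readers without the repository at hand, here is a minimal sketch of the same idea using only the standard library.
The free function ``load_data`` and the ``SimpleNamespace`` used as the ``dataset`` are assumptions made for this example, and plain lists stand in for numpy arrays.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Sequence, Tuple


def load_data(
    data: Tuple[Sequence[Sequence[float]], Sequence[int]], dataset: Any
) -> List[Dict[str, Any]]:
    """Hypothetical loader: pairs inputs with targets and annotates ``dataset``."""
    inputs, targets = data
    # Record the feature count so later stages can read it from the dataset.
    dataset.num_features = len(inputs[0])
    return [{"input": x, "target": y} for x, y in zip(inputs, targets)]


# ``SimpleNamespace`` stands in for the dataset object that would be passed in.
dataset = SimpleNamespace()
samples = load_data(([[5.1, 3.5, 1.4, 0.2], [6.2, 2.9, 4.3, 1.3]], [0, 1]), dataset)
```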

.. note:: Later, when we add :ref:`our DataModule implementation <contributing_data_module>`, we'll make ``num_features`` available to the user.

Sometimes you need to do something a bit more custom.
When creating a custom :class:`~flash.core.data.data_source.DataSource`, the type of the ``data`` argument is up to you.
For our template :class:`~flash.core.model.Task`, it would be cool if the user could provide a scikit-learn ``Bunch`` as the data source.
To achieve this, we'll add a ``TemplateSKLearnDataSource`` whose ``load_data`` expects a ``Bunch`` as input.
We extend our ``TemplateNumpyDataSource`` so that we can call ``super`` with the data and targets extracted from the ``Bunch``.
We perform two additional steps here to improve the user experience:

1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`.
2. We create and set a :class:`~flash.core.data.data_source.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` serializer, so the user doesn't need to provide them.

Here's the code for the ``TemplateSKLearnDataSource.load_data`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateSKLearnDataSource.load_data

We can customize the behaviour of our :meth:`~flash.core.data.data_source.DataSource.load_data` for different stages by prepending ``train``, ``val``, ``test``, or ``predict`` to the method name.
For our ``TemplateSKLearnDataSource``, we don't want to provide any targets to the model when predicting.
We can implement ``predict_load_data`` like this:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateSKLearnDataSource.predict_load_data
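
One way to picture the stage prefix is as a lookup that prefers ``<stage>_load_data`` and falls back to ``load_data``.
The sketch below illustrates that idea only; ``StagedSource`` and ``resolve`` are hypothetical names, and this is not a description of how Flash resolves the methods internally.

```python
from typing import Any, Callable


class StagedSource:
    """Hypothetical data source with one stage-specific override."""

    def load_data(self, data: Any) -> Any:
        # Used for every stage that has no more specific method.
        return {"input": data["x"], "target": data["y"]}

    def predict_load_data(self, data: Any) -> Any:
        # No targets are provided at prediction time.
        return {"input": data["x"]}


def resolve(source: Any, stage: str, name: str = "load_data") -> Callable[..., Any]:
    # Prefer "<stage>_<name>" if the object defines it, else fall back to "<name>".
    return getattr(source, f"{stage}_{name}", getattr(source, name))


source = StagedSource()
bunch = {"x": [1.0, 2.0], "y": [0, 1]}
train_out = resolve(source, "train")(bunch)
predict_out = resolve(source, "predict")(bunch)
```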

DataSource vs Dataset
~~~~~~~~~~~~~~~~~~~~~

A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`.
When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`.
A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.data_source.DataSource` for each stage (``train``, ``val``, ``test``, ``predict``) using the provided metadata (e.g. folder name, numpy array, etc.).

The output of the :meth:`~flash.core.data.data_source.DataSource.load_data` can just be a :class:`torch.utils.data.Dataset` instance.
If the library that your :class:`~flash.core.model.Task` is based on provides a custom dataset, you don't need to re-write it as a :class:`~flash.core.data.data_source.DataSource`.
For example, the :meth:`~flash.core.data.data_source.DataSource.load_data` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder.
Here's how it looks (from `video/classification/data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/video/classification/data.py>`_):

.. literalinclude:: ../../../flash/video/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: VideoClassificationPathsDataSource.load_data

Preprocess
^^^^^^^^^^

The :class:`~flash.core.data.process.Preprocess` object contains all the data transforms.
Internally we inject the :class:`~flash.core.data.process.Preprocess` transforms at several points along the pipeline.

Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method.
The :class:`~flash.core.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``.
These arguments can be provided by the user (when creating the :class:`~flash.core.data.data_module.DataModule`) to override the default transforms.
Any additional arguments are up to you.

Inside the ``__init__``, we make a call to super.
This is where we register our data sources.
Data sources should be given as a dictionary which maps data source name to data source object.
The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use :class:`~flash.core.data.data_source.DefaultDataSources` as the names.
In our case, we have both a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call ``"sklearn"``).

You should also provide a ``default_data_source``.
This is the name of the data source to use by default when predicting.
It'd be cool if we could get predictions just from a numpy array, so we'll use :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` as the default.

Here's our ``TemplatePreprocess.__init__``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.__init__

For our ``TemplatePreprocess``, we'll just configure a default ``to_tensor_transform``.
Let's first define the transform as a ``staticmethod``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.input_to_tensor

Our input samples will be dictionaries whose keys are in the :class:`~flash.core.data.data_source.DefaultDataKeys`.
You can map each key to different transforms using :class:`~flash.core.data.transforms.ApplyToKeys`.
Here's our ``default_transforms`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.default_transforms
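
The idea of mapping a key to a transform can be sketched without Flash as a small wrapper that applies a function to one entry of a sample dictionary and leaves the rest alone.
``apply_to_key`` is a hypothetical helper written for this illustration, and the literal ``"input"`` and ``"target"`` keys are assumptions; it is not the ``ApplyToKeys`` implementation.

```python
from typing import Any, Callable, Dict


def apply_to_key(
    key: str, transform: Callable[[Any], Any]
) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Hypothetical helper: wrap ``transform`` so it only touches ``sample[key]``."""

    def wrapped(sample: Dict[str, Any]) -> Dict[str, Any]:
        result = dict(sample)  # copy so the caller's sample is left untouched
        if key in result:
            result[key] = transform(result[key])
        return result

    return wrapped


# Convert only the inputs to floats; the target passes through unchanged.
to_floats = apply_to_key("input", lambda values: [float(v) for v in values])
sample = {"input": [1, 2, 3], "target": 2}
transformed = to_floats(sample)
```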

.. _contributing_data_module:

DataModule
^^^^^^^^^^

The :class:`~flash.core.data.data_module.DataModule` is responsible for creating the :class:`~torch.utils.data.DataLoader` and injecting the transforms for each stage.
When the user calls a ``from_*`` method (such as :meth:`~flash.core.data.data_module.DataModule.from_numpy`), the following steps take place:

#. The :meth:`~flash.core.data.data_module.DataModule.from_data_source` method is called with the name of the :class:`~flash.core.data.data_source.DataSource` to use and the inputs to provide to :meth:`~flash.core.data.data_source.DataSource.load_data` for each stage.
#. The :class:`~flash.core.data.process.Preprocess` is created from ``cls.preprocess_cls`` (if it wasn't provided by the user) with any provided transforms.
#. The :class:`~flash.core.data.data_source.DataSource` of the provided name is retrieved from the :class:`~flash.core.data.process.Preprocess`.
#. A :class:`~flash.core.data.auto_dataset.BaseAutoDataset` is created from the :class:`~flash.core.data.data_source.DataSource` for each stage.
#. The :class:`~flash.core.data.data_module.DataModule` is instantiated with the data sets.
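
The steps above can be mirrored in a small, self-contained sketch.
Every class here (``ListDataSource``, ``SketchPreprocess``, ``SketchDataModule``) is a hypothetical simplification written for this example, with plain lists standing in for the per-stage data sets; the comments point back to the numbered steps.

```python
from typing import Any, Dict, List, Optional


class ListDataSource:
    """Hypothetical data source: wraps each element as a sample dict."""

    def load_data(self, data: Any) -> List[Dict[str, Any]]:
        return [{"input": item} for item in data]


class SketchPreprocess:
    """Hypothetical preprocess holding named data sources."""

    def __init__(self) -> None:
        self.data_sources = {"numpy": ListDataSource()}

    def data_source_of_name(self, name: str) -> ListDataSource:
        return self.data_sources[name]


class SketchDataModule:
    preprocess_cls = SketchPreprocess

    def __init__(self, train_dataset: Optional[list], predict_dataset: Optional[list]) -> None:
        self.train_dataset = train_dataset
        self.predict_dataset = predict_dataset

    @classmethod
    def from_data_source(cls, name: str, train_data: Any = None, predict_data: Any = None,
                         preprocess: Optional[SketchPreprocess] = None) -> "SketchDataModule":
        # Step 2: create the preprocess unless the user provided one.
        preprocess = preprocess or cls.preprocess_cls()
        # Step 3: look up the data source by name.
        source = preprocess.data_source_of_name(name)
        # Step 4: build one data set per stage that received inputs.
        datasets = [None if d is None else source.load_data(d)
                    for d in (train_data, predict_data)]
        # Step 5: instantiate the data module with the data sets.
        return cls(*datasets)

    @classmethod
    def from_numpy(cls, train_data: Any = None, predict_data: Any = None) -> "SketchDataModule":
        # Step 1: forward to from_data_source with the data source name.
        return cls.from_data_source("numpy", train_data, predict_data)


datamodule = SketchDataModule.from_numpy(train_data=[1, 2])
```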

|
To create our ``TemplateData`` :class:`~flash.core.data.data_module.DataModule`, we first need to attach our preprocess class like this:

.. code-block:: python

    preprocess_cls = TemplatePreprocess

Since we provided a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` :class:`~flash.core.data.data_source.DataSource` in the ``TemplatePreprocess``, :meth:`~flash.core.data.data_module.DataModule.from_numpy` will now work with our ``TemplateData``.

If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need to write a ``from_*`` method for each.
Here's the ``from_sklearn`` method for our ``TemplateData``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.from_sklearn

The final step is to implement the ``num_features`` property for our ``TemplateData``.
This is just a convenience for the user that finds the ``num_features`` attribute on any of the data sets and returns it.
Here's the code:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.num_features
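
The lookup itself is simple enough to sketch in isolation.
``first_attribute`` is a hypothetical helper, and ``SimpleNamespace`` objects stand in for the per-stage data sets; a stage with no data is represented by ``None``.

```python
from types import SimpleNamespace
from typing import Any, Optional


def first_attribute(name: str, *datasets: Any) -> Optional[Any]:
    """Return ``name`` from the first dataset that defines it, else ``None``."""
    for dataset in datasets:
        value = getattr(dataset, name, None)
        if value is not None:
            return value
    return None


train_dataset = SimpleNamespace(num_features=4)
val_dataset = SimpleNamespace()  # the attribute was never set on this one
found = first_attribute("num_features", None, val_dataset, train_dataset)
```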

BaseVisualization
^^^^^^^^^^^^^^^^^

An optional step is to implement a :class:`~flash.core.data.base_viz.BaseVisualization`.
The :class:`~flash.core.data.base_viz.BaseVisualization` lets you control how data at various points in the pipeline can be visualized.
This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.

.. note::
    Don't worry about implementing it right away; you can always come back and add it later!

Here's the code for our ``TemplateVisualization`` which just prints the data:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :pyobject: TemplateVisualization

We can configure our custom visualization in the ``TemplateData`` using :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` like this:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.configure_data_fetcher

Postprocess
^^^^^^^^^^^

:class:`~flash.core.data.process.Postprocess` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here's the :class:`~flash.text.classification.data.TextClassificationPostProcess` which gets the logits from a ``SequenceClassifierOutput``:

.. literalinclude:: ../../../flash/text/classification/data.py
    :language: python
    :pyobject: TextClassificationPostProcess
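
To illustrate one of the other use cases mentioned above, inverse normalization, here is a minimal sketch on plain Python lists.
The ``normalize`` and ``denormalize`` functions and the ``mean`` and ``std`` values are assumptions chosen for the example; in a real task the inverse step would live in the post-processing stage and operate on tensors.

```python
from typing import List


def normalize(values: List[float], mean: float, std: float) -> List[float]:
    # Transform applied to the data before it reaches the model.
    return [(v - mean) / std for v in values]


def denormalize(values: List[float], mean: float, std: float) -> List[float]:
    # Inverse transform applied to the model output.
    return [v * std + mean for v in values]


pixels = [0.0, 0.5, 1.0]
normalized = normalize(pixels, mean=0.5, std=0.25)
restored = denormalize(normalized, mean=0.5, std=0.25)
```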

------

Now that you've got some data, it's time to :ref:`add some backbones for your task! <contributing_backbones>`