Commit cb7f906 (parent 65c658b): Adds a template task and docs (#306)

25 changed files with 1,468 additions and 0 deletions.
.. _template:

########
Template
########

********
The task
********

Here you should add a description of your task. For example:

Classification is the task of assigning one of a number of classes to each data point.
The :class:`~flash.template.TemplateSKLearnClassifier` is a :class:`~flash.core.model.Task` for classifying the datasets included with scikit-learn.
------

*********
Inference
*********

Here, you should add a short intro to your predict example, and then use ``literalinclude`` to add it.

.. note:: We skip the first 14 lines as they are just the copyright notice.

Our predict example uses a model pre-trained on the Iris data.

.. literalinclude:: ../../../flash_examples/predict/template.py
    :language: python
    :lines: 14-

For more advanced inference options, see :ref:`predictions`.
------

********
Training
********

In this section, we briefly describe the data, and then ``literalinclude`` our finetuning example.

We'll train on Fisher's classic Iris data.
It contains 150 records with four features (sepal length, sepal width, petal length, and petal width) across three classes (species of Iris: setosa, virginica, and versicolor).

All we need now is to train our task!

.. literalinclude:: ../../../flash_examples/finetuning/template.py
    :language: python
    :lines: 14-
------

*************
API reference
*************

We usually include the API reference for the :class:`~flash.core.model.Task` and :class:`~flash.core.data.data_module.DataModule`.
You can optionally add the other classes you've implemented.
To add the API reference, use the ``autoclass`` directive.

.. _template_classifier:

TemplateSKLearnClassifier
-------------------------

.. autoclass:: flash.template.TemplateSKLearnClassifier
    :members:
    :exclude-members: forward

.. _template_data:

TemplateData
------------

.. autoclass:: flash.template.TemplateData
.. _contributing_backbones:

*************
The Backbones
*************

Now that you've got a way of loading data, you should implement some backbones to use with your :class:`~flash.core.model.Task`.
Create a :any:`FlashRegistry <registry>` to use with your :class:`~flash.core.model.Task` in `backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/backbones.py>`_.

The registry allows you to register backbones for your task that can be selected by the user.
The backbones can come from anywhere as long as you can register a function that loads the backbone.
Furthermore, the user can add their own models to the existing backbones, without having to write their own :class:`~flash.core.model.Task`!

You can create a registry like this:

.. code-block:: python

    TEMPLATE_BACKBONES = FlashRegistry("backbones")
Let's add a simple MLP backbone to our registry.
We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our :class:`~flash.core.model.Task`).
You can use any name for the function, although we use ``load_{model name}`` by convention.
Here's the code:

.. literalinclude:: ../../../flash/template/classification/backbones.py
    :language: python
    :pyobject: load_mlp_128

Here's another example with a slightly more complex model:

.. literalinclude:: ../../../flash/template/classification/backbones.py
    :language: python
    :pyobject: load_mlp_128_256

Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_:

.. literalinclude:: ../../../flash/image/backbones.py
    :language: python
    :pyobject: load_simclr_imagenet
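To see the register-then-look-up flow end to end in a self-contained way, here's a minimal sketch. ``SimpleRegistry`` and ``load_identity`` are hypothetical stand-ins for ``FlashRegistry`` and a real backbone loader; they illustrate the pattern only, not the actual Flash implementation:

```python
from typing import Callable, Dict, Tuple


class SimpleRegistry:
    """Hypothetical, minimal stand-in for FlashRegistry: maps names to loader functions."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._loaders: Dict[str, Callable] = {}

    def register(self, fn: Callable, name: str) -> Callable:
        """Register ``fn`` under ``name`` so the user can select it later."""
        self._loaders[name] = fn
        return fn

    def get(self, name: str) -> Callable:
        """Look up a previously registered loader by name."""
        return self._loaders[name]


TEMPLATE_BACKBONES = SimpleRegistry("backbones")


def load_identity(num_features: int) -> Tuple[Callable, int]:
    """A trivial 'backbone' loader: return the backbone and its output size."""
    return (lambda x: x), num_features


TEMPLATE_BACKBONES.register(load_identity, name="identity")

# The Task would look the loader up by name and call it to build the backbone:
backbone, out_size = TEMPLATE_BACKBONES.get("identity")(num_features=4)
```

The important property is that the user only supplies a *name*; the registry decouples backbone selection from the :class:`~flash.core.model.Task` itself.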
------

Once you've got some data and some backbones, :ref:`implement your task! <contributing_task>`
.. _contributing_data:

********
The Data
********

The first step to contributing a task is to implement the classes we need to load some data.
Inside `data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/data.py>`_ you should implement:

#. some :class:`~flash.core.data.data_source.DataSource` classes *(optional)*
#. a :class:`~flash.core.data.process.Preprocess`
#. a :class:`~flash.core.data.data_module.DataModule`
#. a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)*
#. a :class:`~flash.core.data.process.Postprocess` *(optional)*

DataSource
^^^^^^^^^^

The :class:`~flash.core.data.data_source.DataSource` class contains the logic for loading data from different sources such as folders, files, tensors, etc.
Every Flash :class:`~flash.core.data.data_module.DataModule` can be instantiated with :meth:`~flash.core.data.data_module.DataModule.from_datasets`.
For each additional way you want the user to be able to instantiate your :class:`~flash.core.data.data_module.DataModule`, you'll need to create a :class:`~flash.core.data.data_source.DataSource`.
Each :class:`~flash.core.data.data_source.DataSource` has two methods:

- :meth:`~flash.core.data.data_source.DataSource.load_data` takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata.
- :meth:`~flash.core.data.data_source.DataSource.load_sample` then takes as input a single element from the output of ``load_data`` and returns a sample.

By default, these methods just return their input, so you don't need both a :meth:`~flash.core.data.data_source.DataSource.load_data` and a :meth:`~flash.core.data.data_source.DataSource.load_sample` to create a :class:`~flash.core.data.data_source.DataSource`.
Where possible, you should override one of our existing :class:`~flash.core.data.data_source.DataSource` classes.
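As a self-contained sketch of that two-method contract (the class and sample layout here are hypothetical, not the real Flash base class), a data source might look like:

```python
from typing import List, Sequence


class FolderLikeDataSource:
    """Hypothetical sketch: load_data produces sample metadata, load_sample materializes one sample."""

    def load_data(self, data: Sequence[str]) -> List[dict]:
        # ``data`` is dataset metadata, here a list of (pretend) file names.
        # We emit one metadata dict per sample; nothing is actually loaded yet.
        return [{"input": name, "target": idx % 2} for idx, name in enumerate(data)]

    def load_sample(self, sample: dict) -> dict:
        # Called once per element of load_data's output; here "loading" is just upper-casing.
        return {"input": sample["input"].upper(), "target": sample["target"]}


source = FolderLikeDataSource()
metadata = source.load_data(["a.png", "b.png"])
samples = [source.load_sample(m) for m in metadata]
```

The split matters because ``load_data`` runs once per stage while ``load_sample`` can run lazily per item, which is what lets Flash defer expensive loading to ``__getitem__`` time.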
Let's start by implementing a ``TemplateNumpyDataSource``, which overrides :class:`~flash.core.data.data_source.NumpyDataSource`.
The main :class:`~flash.core.data.data_source.DataSource` method that we have to implement is :meth:`~flash.core.data.data_source.DataSource.load_data`.
As we're extending the ``NumpyDataSource``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays).

We can also take the ``dataset`` argument.
Any attributes we set on ``dataset`` will be available on the :class:`~torch.utils.data.Dataset` generated by our :class:`~flash.core.data.data_source.DataSource`.
In this data source, we'll set the ``num_features`` attribute.

Here's the code for our ``TemplateNumpyDataSource.load_data`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateNumpyDataSource.load_data
.. note:: Later, when we add :ref:`our DataModule implementation <contributing_data_module>`, we'll make ``num_features`` available to the user.

Sometimes you need to do something a bit more custom.
When creating a custom :class:`~flash.core.data.data_source.DataSource`, the type of the ``data`` argument is up to you.
For our template :class:`~flash.core.model.Task`, it would be cool if the user could provide a scikit-learn ``Bunch`` as the data source.
To achieve this, we'll add a ``TemplateSKLearnDataSource`` whose ``load_data`` expects a ``Bunch`` as input.
We override our ``TemplateNumpyDataSource`` so that we can call ``super`` with the data and targets extracted from the ``Bunch``.
We perform two additional steps here to improve the user experience:

1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`.
2. We create and set a :class:`~flash.core.data.data_source.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` serializer, so the user doesn't need to provide them.

Here's the code for the ``TemplateSKLearnDataSource.load_data`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateSKLearnDataSource.load_data
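A ``Bunch`` behaves like an object with ``data`` and ``target`` attributes, so the delegation described above can be pictured with this dependency-free sketch (both classes are hypothetical stand-ins, with ``SimpleNamespace`` playing the role of a real ``Bunch``):

```python
from types import SimpleNamespace
from typing import Any, List


class NumpyLikeDataSource:
    """Hypothetical base: expects ``data`` as an (examples, targets) tuple."""

    def load_data(self, data: Any, dataset: Any = None) -> List[dict]:
        examples, targets = data
        if dataset is not None:
            # Attributes set here become available on the generated dataset.
            dataset.num_features = len(examples[0])
        return [{"input": x, "target": y} for x, y in zip(examples, targets)]


class BunchDataSource(NumpyLikeDataSource):
    """Hypothetical: accepts a Bunch-like object and delegates to the numpy loader."""

    def load_data(self, data: Any, dataset: Any = None) -> List[dict]:
        if dataset is not None:
            # Step 1 from above: expose num_classes for the DataModule.
            dataset.num_classes = len(set(data.target))
        # Extract the arrays from the Bunch and reuse the parent's logic.
        return super().load_data((data.data, data.target), dataset=dataset)


bunch = SimpleNamespace(data=[[1.0, 2.0], [3.0, 4.0]], target=[0, 1])
dataset = SimpleNamespace()
samples = BunchDataSource().load_data(bunch, dataset=dataset)
```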
We can customize the behaviour of our :meth:`~flash.core.data.data_source.DataSource.load_data` for different stages by prepending ``train``, ``val``, ``test``, or ``predict``.
For our ``TemplateSKLearnDataSource``, we don't want to provide any targets to the model when predicting.
We can implement ``predict_load_data`` like this:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateSKLearnDataSource.predict_load_data
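One way to picture the stage-prefix mechanism is a ``getattr`` dispatch with a fallback. This dispatcher is an illustrative sketch only; the real Flash internals differ:

```python
from typing import Any, List


class StagedDataSource:
    """Hypothetical sketch: resolve ``{stage}_load_data`` if defined, else fall back to ``load_data``."""

    def load_data(self, data: Any) -> List[dict]:
        examples, targets = data
        return [{"input": x, "target": y} for x, y in zip(examples, targets)]

    def predict_load_data(self, data: Any) -> List[dict]:
        # At predict time there are no targets, so samples contain only inputs.
        return [{"input": x} for x in data]

    def load_for_stage(self, stage: str, data: Any) -> List[dict]:
        # Prefer a stage-specific override, e.g. predict_load_data.
        loader = getattr(self, f"{stage}_load_data", self.load_data)
        return loader(data)


source = StagedDataSource()
train_samples = source.load_for_stage("train", ([[1.0], [2.0]], [0, 1]))  # falls back to load_data
predict_samples = source.load_for_stage("predict", [[3.0]])               # uses predict_load_data
```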
DataSource vs Dataset
~~~~~~~~~~~~~~~~~~~~~

A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`.
When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`.
A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.data_source.DataSource` for each stage (``train``, ``val``, ``test``, ``predict``) using the provided metadata (e.g. folder name, numpy array, etc.).

The output of :meth:`~flash.core.data.data_source.DataSource.load_data` can just be a :class:`torch.utils.data.Dataset` instance.
If the library that your :class:`~flash.core.model.Task` is based on provides a custom dataset, you don't need to re-write it as a :class:`~flash.core.data.data_source.DataSource`.
For example, the :meth:`~flash.core.data.data_source.DataSource.load_data` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder.
Here's how it looks (from `video/classification/data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/video/classification/data.py>`_):

.. literalinclude:: ../../../flash/video/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: VideoClassificationPathsDataSource.load_data
Preprocess
^^^^^^^^^^

The :class:`~flash.core.data.process.Preprocess` object contains all the data transforms.
Internally, we inject the :class:`~flash.core.data.process.Preprocess` transforms at several points along the pipeline.

Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method.
The :class:`~flash.core.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``.
These arguments can be provided by the user (when creating the :class:`~flash.core.data.data_module.DataModule`) to override the default transforms.
Any additional arguments are up to you.

Inside the ``__init__``, we make a call to ``super``.
This is where we register our data sources.
Data sources should be given as a dictionary which maps data source name to data source object.
The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use :class:`~flash.core.data.data_source.DefaultDataSources` as the names.
In our case, we have both a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call ``"sklearn"``).

You should also provide a ``default_data_source``.
This is the name of the data source to use by default when predicting.
It'd be cool if we could get predictions just from a numpy array, so we'll use :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` as the default.

Here's our ``TemplatePreprocess.__init__``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.__init__
For our ``TemplatePreprocess``, we'll just configure a default ``to_tensor_transform``.
Let's first define the transform as a ``staticmethod``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.input_to_tensor

Our input samples will be dictionaries whose keys are in the :class:`~flash.core.data.data_source.DefaultDataKeys`.
You can map each key to different transforms using :class:`~flash.core.data.transforms.ApplyToKeys`.
Here's our ``default_transforms`` method:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplatePreprocess.default_transforms
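To make the per-key idea concrete without any Flash imports, here's a sketch of an ``ApplyToKeys``-style helper. ``ApplyToKey`` is a hypothetical, plain-Python stand-in for the real transform, and the "to tensor" step is simulated with a float conversion:

```python
from typing import Callable, Dict


class ApplyToKey:
    """Hypothetical sketch of ApplyToKeys: apply a transform to one key of a sample dict, leave the rest untouched."""

    def __init__(self, key: str, transform: Callable) -> None:
        self.key = key
        self.transform = transform

    def __call__(self, sample: Dict) -> Dict:
        out = dict(sample)  # don't mutate the caller's sample
        out[self.key] = self.transform(out[self.key])
        return out


# Stand-in for a to_tensor_transform: convert the "input" values to floats.
to_float = ApplyToKey("input", lambda xs: [float(x) for x in xs])

sample = {"input": [1, 2, 3], "target": 0}
transformed = to_float(sample)
```

The point is that transforms target a single key of the sample dictionary, so the ``target`` entry passes through untouched.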
.. _contributing_data_module:

DataModule
^^^^^^^^^^

The :class:`~flash.core.data.data_module.DataModule` is responsible for creating the :class:`~torch.utils.data.DataLoader` and injecting the transforms for each stage.
When the user calls a ``from_*`` method (such as :meth:`~flash.core.data.data_module.DataModule.from_numpy`), the following steps take place:

#. The :meth:`~flash.core.data.data_module.DataModule.from_data_source` method is called with the name of the :class:`~flash.core.data.data_source.DataSource` to use and the inputs to provide to :meth:`~flash.core.data.data_source.DataSource.load_data` for each stage.
#. The :class:`~flash.core.data.process.Preprocess` is created from ``cls.preprocess_cls`` (if it wasn't provided by the user) with any provided transforms.
#. The :class:`~flash.core.data.data_source.DataSource` of the provided name is retrieved from the :class:`~flash.core.data.process.Preprocess`.
#. A :class:`~flash.core.data.auto_dataset.BaseAutoDataset` is created from the :class:`~flash.core.data.data_source.DataSource` for each stage.
#. The :class:`~flash.core.data.data_module.DataModule` is instantiated with the data sets.
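The steps above can be sketched end to end in plain Python. Every class here is a simplified, hypothetical stand-in for the corresponding Flash class, kept only to show the order of operations:

```python
from typing import Any, Dict, List


class SketchDataSource:
    """Stand-in DataSource."""

    def load_data(self, data: Any) -> List[dict]:
        examples, targets = data
        return [{"input": x, "target": y} for x, y in zip(examples, targets)]


class SketchPreprocess:
    """Stand-in Preprocess: holds named data sources."""

    def __init__(self) -> None:
        self._data_sources = {"numpy": SketchDataSource()}

    def data_source_of_name(self, name: str) -> SketchDataSource:
        return self._data_sources[name]


class SketchDataModule:
    """Stand-in DataModule built from per-stage datasets."""

    def __init__(self, datasets: Dict[str, List[dict]]) -> None:
        self.datasets = datasets

    @classmethod
    def from_data_source(cls, name, train_data=None, val_data=None, preprocess=None):
        preprocess = preprocess or SketchPreprocess()   # step 2: create the Preprocess
        source = preprocess.data_source_of_name(name)   # step 3: retrieve the DataSource
        datasets = {                                    # step 4: one dataset per stage
            stage: source.load_data(data)
            for stage, data in {"train": train_data, "val": val_data}.items()
            if data is not None
        }
        return cls(datasets)                            # step 5: instantiate the DataModule


dm = SketchDataModule.from_data_source("numpy", train_data=([[1.0], [2.0]], [0, 1]))
```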
To create our ``TemplateData`` :class:`~flash.core.data.data_module.DataModule`, we first need to attach our preprocess class like this:

.. code-block:: python

    preprocess_cls = TemplatePreprocess

Since we provided a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` :class:`~flash.core.data.data_source.DataSource` in the ``TemplatePreprocess``, :meth:`~flash.core.data.data_module.DataModule.from_numpy` will now work with our ``TemplateData``.

If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need to write a ``from_*`` method for each.
Here's the ``from_sklearn`` method for our ``TemplateData``:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.from_sklearn
The final step is to implement the ``num_features`` property for our ``TemplateData``.
This is just a convenience for the user; it finds the ``num_features`` attribute on any of the data sets and returns it.
Here's the code:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.num_features
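In plain Python, that attribute lookup can be pictured like this (a hedged sketch with hypothetical names, not the real implementation):

```python
from types import SimpleNamespace
from typing import Any, Optional


class SketchData:
    """Hypothetical DataModule holding per-stage datasets."""

    def __init__(self, train_dataset: Any = None, val_dataset: Any = None, test_dataset: Any = None) -> None:
        self._train_dataset = train_dataset
        self._val_dataset = val_dataset
        self._test_dataset = test_dataset

    @property
    def num_features(self) -> Optional[int]:
        # Return num_features from whichever dataset defines it, if any.
        for ds in (self._train_dataset, self._val_dataset, self._test_dataset):
            value = getattr(ds, "num_features", None)
            if value is not None:
                return value
        return None


dm = SketchData(train_dataset=SimpleNamespace(num_features=4))
```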
BaseVisualization
^^^^^^^^^^^^^^^^^

An optional step is to implement a :class:`~flash.core.data.base_viz.BaseVisualization`.
The :class:`~flash.core.data.base_viz.BaseVisualization` lets you control how data at various points in the pipeline can be visualized.
This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms.

.. note::

    Don't worry about implementing it right away; you can always come back and add it later!

Here's the code for our ``TemplateVisualization``, which just prints the data:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :pyobject: TemplateVisualization

We can configure our custom visualization in the ``TemplateData`` using :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` like this:

.. literalinclude:: ../../../flash/template/classification/data.py
    :language: python
    :dedent: 4
    :pyobject: TemplateData.configure_data_fetcher
Postprocess
^^^^^^^^^^^

:class:`~flash.core.data.process.Postprocess` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here's the :class:`~flash.text.classification.data.TextClassificationPostProcess`, which gets the logits from a ``SequenceClassifierOutput``:

.. literalinclude:: ../../../flash/text/classification/data.py
    :language: python
    :pyobject: TextClassificationPostProcess

------

Now that you've got some data, it's time to :ref:`add some backbones for your task! <contributing_backbones>`