-
Notifications
You must be signed in to change notification settings - Fork 212
Adds a template task and docs #306
Changes from all commits
f5e3c49
4cccd79
838e40c
49de5a0
7735349
ce2108f
04b3b96
f79f909
9834a47
53a1ba2
2694f46
28b5eec
c552635
65f9bdd
cc3001a
b4102f0
4ae69fa
3bcf221
eb7c3e4
839c99a
e2df1ee
afe8142
bee8bdd
382c2cb
3a24117
acd302e
907c927
0af0d28
9a2be0e
9740580
b6d57a2
084eb6e
e390d32
166dd4d
2f52577
0c0780c
3fdba4a
fa6ba79
7b201b3
96df2c2
9a9cfd4
fe2cff7
8167884
ad976f4
fecb316
b4d952c
c7b7806
23f2f20
ba83757
1c43ec9
5aa7cf5
4d35762
36a6538
17085fb
4550e04
5ebc71e
71de79c
0850333
4fd0344
c21e816
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
|
||
.. _template: | ||
|
||
######## | ||
Template | ||
######## | ||
|
||
******** | ||
The task | ||
******** | ||
|
||
Here you should add a description of your task. For example: | ||
Classification is the task of assigning one of a number of classes to each data point. | ||
The :class:`~flash.template.TemplateSKLearnClassifier` is a :class:`~flash.core.model.Task` for classifying the datasets included with scikit-learn. | ||
|
||
------ | ||
|
||
********* | ||
Inference | ||
********* | ||
|
||
Here, you should add a short intro to your predict example, and then use ``literalinclude`` to add it. | ||
|
||
.. note:: We skip the first 14 lines as they are just the copyright notice. | ||
|
||
Our predict example uses a model pre-trained on the Iris data. | ||
|
||
.. literalinclude:: ../../../flash_examples/predict/template.py | ||
:language: python | ||
:lines: 14- | ||
|
||
For more advanced inference options, see :ref:`predictions`. | ||
|
||
------ | ||
|
||
******** | ||
Training | ||
******** | ||
|
||
In this section, we briefly describe the data, and then ``literalinclude`` our finetuning example. | ||
|
||
Now we'll train on Fisher's classic Iris data. | ||
It contains 150 records with four features (sepal length, sepal width, petal length, and petal width) in three classes (species of Iris: setosa, virginica and versicolor). | ||
|
||
Now all we need is to train our task! | ||
|
||
.. literalinclude:: ../../../flash_examples/finetuning/template.py | ||
:language: python | ||
:lines: 14- | ||
|
||
------ | ||
|
||
************* | ||
API reference | ||
************* | ||
|
||
We usually include the API reference for the :class:`~flash.core.model.Task` and :class:`~flash.core.data.data_module.DataModule`. | ||
You can optionally add the other classes you've implemented. | ||
To add the API reference, use the ``autoclass`` directive. | ||
|
||
.. _template_classifier: | ||
|
||
TemplateSKLearnClassifier | ||
------------------------- | ||
|
||
.. autoclass:: flash.template.TemplateSKLearnClassifier | ||
:members: | ||
:exclude-members: forward | ||
|
||
.. _template_data: | ||
|
||
TemplateData | ||
------------ | ||
|
||
.. autoclass:: flash.template.TemplateData |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
.. _contributing_backbones: | ||
|
||
************* | ||
The Backbones | ||
************* | ||
|
||
Now that you've got a way of loading data, you should implement some backbones to use with your :class:`~flash.core.model.Task`. | ||
Create a :any:`FlashRegistry <registry>` to use with your :class:`~flash.core.model.Task` in `backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/backbones.py>`_. | ||
|
||
The registry allows you to register backbones for your task that can be selected by the user. | ||
The backbones can come from anywhere as long as you can register a function that loads the backbone. | ||
Furthermore, the user can add their own models to the existing backbones, without having to write their own :class:`~flash.core.model.Task`! | ||
|
||
You can create a registry like this: | ||
|
||
.. code-block:: python | ||
|
||
TEMPLATE_BACKBONES = FlashRegistry("backbones") | ||
|
||
Let's add a simple MLP backbone to our registry. | ||
We need a function that creates the backbone and returns it along with the output size (so that we can create the model head in our :class:`~flash.core.model.Task`). | ||
You can use any name for the function, although we use ``load_{model name}`` by convention. | ||
Here's the code: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/backbones.py | ||
:language: python | ||
:pyobject: load_mlp_128 | ||
|
||
Here's another example with a slightly more complex model: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/backbones.py | ||
:language: python | ||
:pyobject: load_mlp_128_256 | ||
|
||
Here's a more advanced example, which adds ``SimCLR`` to the ``IMAGE_CLASSIFIER_BACKBONES``, from `flash/image/backbones.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/image/backbones.py>`_: | ||
|
||
.. literalinclude:: ../../../flash/image/backbones.py | ||
:language: python | ||
:pyobject: load_simclr_imagenet | ||
ethanwharris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
------ | ||
|
||
Once you've got some data and some backbones, :ref:`implement your task! <contributing_task>` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
.. _contributing_data: | ||
|
||
******** | ||
The Data | ||
******** | ||
|
||
The first step to contributing a task is to implement the classes we need to load some data. | ||
Inside `data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/template/classification/data.py>`_ you should implement: | ||
|
||
#. some :class:`~flash.core.data.data_source.DataSource` classes *(optional)* | ||
#. a :class:`~flash.core.data.process.Preprocess` | ||
#. a :class:`~flash.core.data.data_module.DataModule` | ||
#. a :class:`~flash.core.data.base_viz.BaseVisualization` *(optional)* | ||
#. a :class:`~flash.core.data.process.Postprocess` *(optional)* | ||
|
||
DataSource | ||
^^^^^^^^^^ | ||
|
||
The :class:`~flash.core.data.data_source.DataSource` class contains the logic for data loading from different sources such as folders, files, tensors, etc. | ||
Every Flash :class:`~flash.core.data.data_module.DataModule` can be instantiated with :meth:`~flash.core.data.data_module.DataModule.from_datasets`. | ||
For each additional way you want the user to be able to instantiate your :class:`~flash.core.data.data_module.DataModule`, you'll need to create a :class:`~flash.core.data.data_source.DataSource`. | ||
Each :class:`~flash.core.data.data_source.DataSource` has 2 methods: | ||
|
||
- :meth:`~flash.core.data.data_source.DataSource.load_data` takes some dataset metadata (e.g. a folder name) as input and produces a sequence or iterable of samples or sample metadata. | ||
- :meth:`~flash.core.data.data_source.DataSource.load_sample` then takes as input a single element from the output of ``load_data`` and returns a sample. | ||
|
||
By default these methods just return their input, so you don't need both a :meth:`~flash.core.data.data_source.DataSource.load_data` and a :meth:`~flash.core.data.data_source.DataSource.load_sample` to create a :class:`~flash.core.data.data_source.DataSource`. | ||
Where possible, you should override one of our existing :class:`~flash.core.data.data_source.DataSource` classes. | ||
|
||
Let's start by implementing a ``TemplateNumpyDataSource``, which overrides :class:`~flash.core.data.data_source.NumpyDataSource`. | ||
The main :class:`~flash.core.data.data_source.DataSource` method that we have to implement is :meth:`~flash.core.data.data_source.DataSource.load_data`. | ||
As we're extending the ``NumpyDataSource``, we expect the same ``data`` argument (in this case, a tuple containing data and corresponding target arrays). | ||
|
||
We can also take the dataset argument. | ||
Any attributes we set on ``dataset`` will be available on the :class:`~torch.utils.data.Dataset` generated by our :class:`~flash.core.data.data_source.DataSource`. | ||
In this data source, we'll set the ``num_features`` attribute. | ||
|
||
Here's the code for our ``TemplateNumpyDataSource.load_data`` method: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateNumpyDataSource.load_data | ||
|
||
.. note:: Later, when we add :ref:`our DataModule implementation <contributing_data_module>`, we'll make ``num_features`` available to the user. | ||
|
||
Sometimes you need to something a bit more custom. | ||
When creating a custom :class:`~flash.core.data.data_source.DataSource`, the type of the ``data`` argument is up to you. | ||
For our template :class:`~flash.core.data.model.Task`, it would be cool if the user could provide a scikit-learn ``Bunch`` as the data source. | ||
To achieve this, we'll add a ``TemplateSKLearnDataSource`` whose ``load_data`` expects a ``Bunch`` as input. | ||
We override our ``TemplateNumpyDataSource`` so that we can call ``super`` with the data and targets extracted from the ``Bunch``. | ||
We perform two additional steps here to improve the user experience: | ||
|
||
1. We set the ``num_classes`` attribute on the ``dataset``. If ``num_classes`` is set, it is automatically made available as a property of the :class:`~flash.core.data.data_module.DataModule`. | ||
2. We create and set a :class:`~flash.core.data.data_source.LabelsState`. The labels provided here will be shared with the :class:`~flash.core.classification.Labels` serializer, so the user doesn't need to provide them. | ||
|
||
Here's the code for the ``TemplateSKLearnDataSource.load_data`` method: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateSKLearnDataSource.load_data | ||
|
||
We can customize the behaviour of our :meth:`~flash.core.data.data_source.DataSource.load_data` for different stages, by prepending `train`, `val`, `test`, or `predict`. | ||
For our ``TemplateSKLearnDataSource``, we don't want to provide any targets to the model when predicting. | ||
We can implement ``predict_load_data`` like this: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateSKLearnDataSource.predict_load_data | ||
|
||
DataSource vs Dataset | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me rephrase it to see if I understand it correctly: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It may also be useful to understand how it is different from torch.utils.DataLoader, since Dataset only requires getitem, but Dataloader also does some preprocessing, although I think does not distinguish between training, validation ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The high-level view is this:
So DataSource, Preprocess, DataPipeline is really just a different way of creating a DataSet and DataLoader (not a replacement). Can't speak to similarity with Fast AI as I'm not very familiar with it. Hope that helps! |
||
~~~~~~~~~~~~~~~~~~~~~ | ||
|
||
A :class:`~flash.core.data.data_source.DataSource` is not the same as a :class:`torch.utils.data.Dataset`. | ||
ethanwharris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
When a ``from_*`` method is called on your :class:`~flash.core.data.data_module.DataModule`, it gets the :class:`~flash.core.data.data_source.DataSource` to use from the :class:`~flash.core.data.process.Preprocess`. | ||
A :class:`~torch.utils.data.Dataset` is then created from the :class:`~flash.core.data.data_source.DataSource` for each stage (`train`, `val`, `test`, `predict`) using the provided metadata (e.g. folder name, numpy array etc.). | ||
|
||
The output of the :meth:`~flash.core.data.data_source.DataSource.load_data` can just be a :class:`torch.utils.data.Dataset` instance. | ||
If the library that your :class:`~flash.core.data.model.Task` is based on provides a custom dataset, you don't need to re-write it as a :class:`~flash.core.data.data_source.DataSource`. | ||
For example, the :meth:`~flash.core.data.data_source.DataSource.load_data` of the ``VideoClassificationPathsDataSource`` just creates an :class:`~pytorchvideo.data.EncodedVideoDataset` from the given folder. | ||
Here's how it looks (from `video/classification.data.py <https://github.com/PyTorchLightning/lightning-flash/blob/master/flash/video/classification/data.py>`_): | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we could give a simpler example for something like |
||
.. literalinclude:: ../../../flash/video/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: VideoClassificationPathsDataSource.load_data | ||
|
||
Preprocess | ||
ethanwharris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
^^^^^^^^^^ | ||
|
||
The :class:`~flash.core.data.process.Preprocess` object contains all the data transforms. | ||
Internally we inject the :class:`~flash.core.data.process.Preprocess` transforms at several points along the pipeline. | ||
|
||
Defining the standard transforms (typically at least a ``to_tensor_transform`` should be defined) for your :class:`~flash.core.data.process.Preprocess` is as simple as implementing the ``default_transforms`` method. | ||
The :class:`~flash.core.data.process.Preprocess` must take ``train_transform``, ``val_transform``, ``test_transform``, and ``predict_transform`` arguments in the ``__init__``. | ||
These arguments can be provided by the user (when creating the :class:`~flash.core.data.data_module.DataModule`) to override the default transforms. | ||
Any additional arguments are up to you. | ||
|
||
Inside the ``__init__``, we make a call to super. | ||
This is where we register our data sources. | ||
Data sources should be given as a dictionary which maps data source name to data source object. | ||
The name can be anything, but if you want to take advantage of our built-in ``from_*`` classmethods, you should use :class:`~flash.core.data.data_source.DefaultDataSources` as the names. | ||
In our case, we have both a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` and a custom scikit-learn data source (which we'll call `"sklearn"`). | ||
|
||
You should also provide a ``default_data_source``. | ||
This is the name of the data source to use by default when predicting. | ||
It'd be cool if we could get predictions just from a numpy array, so we'll use :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` as the default. | ||
|
||
Here's our ``TemplatePreprocess.__init__``: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplatePreprocess.__init__ | ||
|
||
For our ``TemplatePreprocess``, we'll just configure a default ``to_tensor_transform``. | ||
Let's first define the transform as a ``staticmethod``: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplatePreprocess.input_to_tensor | ||
|
||
Our inputs samples will be dictionaries whose keys are in the :class:`~flash.core.data.data_source.DefaultDataKeys`. | ||
You can map each key to different transforms using :class:`~flash.core.data.transforms.ApplyToKeys`. | ||
Here's our ``default_transforms`` method: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplatePreprocess.default_transforms | ||
|
||
.. _contributing_data_module: | ||
|
||
DataModule | ||
^^^^^^^^^^ | ||
|
||
The :class:`~flash.core.data.data_module.DataModule` is responsible for creating the :class:`~torch.utils.data.DataLoader` and injecting the transforms for each stage. | ||
When the user calls a ``from_*`` method (such as :meth:`~flash.core.data.data_module.DataModule.from_numpy`), the following steps take place: | ||
|
||
#. The :meth:`~flash.core.data.data_module.DataModule.from_data_source` method is called with the name of the :class:`~flash.core.data.data_source.DataSource` to use and the inputs to provide to :meth:`~flash.core.data.data_source.DataSource.load_data` for each stage. | ||
#. The :class:`~flash.core.data.process.Preprocess` is created from ``cls.preprocess_cls`` (if it wasn't provided by the user) with any provided transforms. | ||
#. The :class:`~flash.core.data.data_source.DataSource` of the provided name is retrieved from the :class:`~flash.core.data.process.Preprocess`. | ||
#. A :class:`~flash.core.data.auto_dataset.BaseAutoDataset` is created from the :class:`~flash.core.data.data_source.DataSource` for each stage. | ||
#. The :class:`~flash.core.data.data_module.DataModule` is instantiated with the data sets. | ||
|
||
| | ||
|
||
To create our ``TemplateData`` :class:`~flash.core.data.data_module.DataModule`, we first need to attach out preprocess class like this: | ||
|
||
.. code-block:: python | ||
|
||
preprocess_cls = TemplatePreprocess | ||
|
||
Since we provided a :attr:`~flash.core.data.data_source.DefaultDataSources.NUMPY` :class:`~flash.core.data.data_source.DataSource` in the ``TemplatePreprocess``, :meth:`~flash.core.data.data_module.DataModule.from_numpy` will now work with our ``TemplateData``. | ||
|
||
If you've defined a fully custom :class:`~flash.core.data.data_source.DataSource` (like our ``TemplateSKLearnDataSource``), then you will need to write a ``from_*`` method for each. | ||
Here's the ``from_sklearn`` method for our ``TemplateData``: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateData.from_sklearn | ||
|
||
The final step is to implement the ``num_features`` property for our ``TemplateData``. | ||
This is just a convenience for the user that finds the ``num_features`` attribute on any of the data sets and returns it. | ||
Here's the code: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateData.num_features | ||
|
||
BaseVisualization | ||
^^^^^^^^^^^^^^^^^ | ||
|
||
An optional step is to implement a :class:`~flash.core.data.base_viz.BaseVisualization`. | ||
The :class:`~flash.core.data.base_viz.BaseVisualization` lets you control how data at various points in the pipeline can be visualized. | ||
This is extremely useful for debugging purposes, allowing users to view their data and understand the impact of their transforms. | ||
|
||
.. note:: | ||
Don't worry about implementing it right away, you can always come back and add it later! | ||
|
||
Here's the code for our ``TemplateVisualization`` which just prints the data: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:pyobject: TemplateVisualization | ||
|
||
We can configure our custom visualization in the ``TemplateData`` using :meth:`~flash.core.data.data_module.DataModule.configure_data_fetcher` like this: | ||
|
||
.. literalinclude:: ../../../flash/template/classification/data.py | ||
:language: python | ||
:dedent: 4 | ||
:pyobject: TemplateData.configure_data_fetcher | ||
|
||
Postprocess | ||
^^^^^^^^^^^ | ||
|
||
:class:`~flash.core.data.process.Postprocess` contains any transforms that need to be applied *after* the model. | ||
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc. | ||
As an example, here's the :class:`~text.classification.data.TextClassificationPostProcess` which gets the logits from a ``SequenceClassifierOutput``: | ||
|
||
.. literalinclude:: ../../../flash/text/classification/data.py | ||
:language: python | ||
:pyobject: TextClassificationPostProcess | ||
|
||
ethanwharris marked this conversation as resolved.
Show resolved
Hide resolved
|
||
------ | ||
|
||
Now that you've got some data, it's time to :ref:`add some backbones for your task! <contributing_backbones>` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Include link to images to make your description better.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just tabular data, so I'm not sure what images we would show here