This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState #229

Merged 55 commits on Apr 22, 2021
512ba6d
Initial commit
ethanwharris Apr 15, 2021
e3de28f
Initial commit
ethanwharris Apr 19, 2021
69074f7
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 19, 2021
4c58b35
Small fixes
ethanwharris Apr 19, 2021
1875474
Small fixes
ethanwharris Apr 19, 2021
b7559ca
Fix a small bug
ethanwharris Apr 20, 2021
641471e
Update docs
ethanwharris Apr 20, 2021
cfe047e
Update notebook
ethanwharris Apr 20, 2021
3274695
Update finetuning image classification
ethanwharris Apr 20, 2021
55d9e9e
Updates
ethanwharris Apr 20, 2021
78cdda6
Update docs and serializer mapping
ethanwharris Apr 20, 2021
5a6cb52
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 20, 2021
228a2e1
Fix missed merge conflicts
ethanwharris Apr 20, 2021
9e3fa8d
Fix some broken tests
ethanwharris Apr 20, 2021
9a6e59b
Fix a test
ethanwharris Apr 20, 2021
305e8d9
Fix a test
ethanwharris Apr 20, 2021
f06aeac
Fix some tests
ethanwharris Apr 20, 2021
9d591d3
Fix a test
ethanwharris Apr 20, 2021
eaf8742
Update examples
ethanwharris Apr 20, 2021
79409b7
Update examples
ethanwharris Apr 20, 2021
a6ac0b7
Pre-commit
ethanwharris Apr 20, 2021
6ad96df
Pre-commit
ethanwharris Apr 20, 2021
0c870cf
Update text classification
ethanwharris Apr 20, 2021
2a84f56
Add a test
ethanwharris Apr 20, 2021
bd176a3
Multi-label example initial commit
ethanwharris Apr 20, 2021
4cc3141
Add predict example for multi_label
ethanwharris Apr 20, 2021
d379151
Remove unused imports
ethanwharris Apr 20, 2021
e70ad82
Update predict example
ethanwharris Apr 20, 2021
10b5799
Update examples
ethanwharris Apr 20, 2021
4967ded
Add multi-label Labels suport
ethanwharris Apr 20, 2021
a408f60
Update test_classification
ethanwharris Apr 21, 2021
2b15226
Update .gitignore
ethanwharris Apr 21, 2021
7f65dcc
Add some tests
ethanwharris Apr 21, 2021
777db40
Fix broken test
ethanwharris Apr 21, 2021
86cd3d7
Update test_process
ethanwharris Apr 21, 2021
86d8a7f
Some docs updates
ethanwharris Apr 21, 2021
3f89cd3
Update docs
ethanwharris Apr 21, 2021
9672e55
Fix some tests
ethanwharris Apr 21, 2021
59e8374
Add back some process_cls
ethanwharris Apr 21, 2021
f9681a8
Add types
ethanwharris Apr 21, 2021
443a9c9
Update docs/source/general/data.rst
ethanwharris Apr 21, 2021
0ca9ec8
Update flash/data/process.py
ethanwharris Apr 21, 2021
245cc5c
Add comment
ethanwharris Apr 21, 2021
7dd16c5
Update example
ethanwharris Apr 21, 2021
fc52a68
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 21, 2021
9d24f98
Remove state checkpoint not needed
ethanwharris Apr 22, 2021
4c3f34b
Merge branch 'feature/multiple_return_types' of https://github.com/Py…
ethanwharris Apr 22, 2021
0ffc412
Fix doctest
ethanwharris Apr 22, 2021
3644764
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 22, 2021
b9a4f92
Update image_classification example
ethanwharris Apr 22, 2021
c54c895
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 22, 2021
686823b
Update following fix
ethanwharris Apr 22, 2021
a66a40a
Fix num-workers in test_examples
ethanwharris Apr 22, 2021
3bbfb16
Update example predict
ethanwharris Apr 22, 2021
b1310e5
Better fix for windows error
ethanwharris Apr 22, 2021
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,4 @@
.DS_Store
.lock
lightning_logs

@@ -149,3 +150,4 @@ xsum
coco128
wmt_en_ro
kinetics
movie_posters
12 changes: 6 additions & 6 deletions docs/source/custom_task.rst
@@ -115,13 +115,13 @@ We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.dat
This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
:class:`~flash.data.data_module.DataModule` from x, y numpy arrays.

Here is how it would look like:
Here is how it would look:

Example::

x, y = ...
preprocess_cls = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess_cls)
preprocess = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess)

Here is the ``NumpyDataModule`` implementation:

@@ -140,12 +140,12 @@ Example::
cls,
x: ND,
y: ND,
preprocess_cls: Preprocess = NumpyPreprocess,
preprocess: Preprocess = None,
batch_size: int = 64,
num_workers: int = 0
):

preprocess = preprocess_cls()
preprocess = preprocess or NumpyPreprocess()

x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=.20, random_state=0)
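
The change in this hunk replaces a class argument (``preprocess_cls``) with an optional, already constructed instance (``preprocess``) that falls back to a default. A minimal sketch of that pattern follows, using hypothetical stand-in classes (``SketchPreprocess``, ``ScaledPreprocess``, ``SketchDataModule``) rather than the Flash classes, so it runs without Flash installed:

```python
from typing import List, Optional


class SketchPreprocess:
    """Hypothetical stand-in for a preprocess object (not the Flash class)."""

    def per_sample_transform(self, sample: float) -> float:
        return sample


class ScaledPreprocess(SketchPreprocess):
    """Carries its own configuration, which a bare class argument could not."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def per_sample_transform(self, sample: float) -> float:
        return sample * self.scale


class SketchDataModule:
    def __init__(self, data: List[float], preprocess: SketchPreprocess) -> None:
        self.preprocess = preprocess
        self.data = [preprocess.per_sample_transform(x) for x in data]

    @classmethod
    def from_xy_dataset(
        cls, x: List[float], preprocess: Optional[SketchPreprocess] = None
    ) -> "SketchDataModule":
        # Same idiom as the diff: use the given instance, else build the default.
        preprocess = preprocess or SketchPreprocess()
        return cls(x, preprocess)


default_dm = SketchDataModule.from_xy_dataset([1.0, 2.0])
scaled_dm = SketchDataModule.from_xy_dataset([1.0, 2.0], ScaledPreprocess(scale=10.0))
print(default_dm.data)
print(scaled_dm.data)
```

Passing an instance lets the caller configure the preprocess (here, ``scale``) before handing it over, which a class argument cannot express without extra keyword plumbing.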
@@ -180,7 +180,7 @@ It allows the user much more granular control over their data processing flow.

.. note::

Why introducing :class:`~flash.data.process.Preprocess` ?
Why introduce :class:`~flash.data.process.Preprocess` ?

The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or
to deploy the model in a production environment compared to traditional
45 changes: 33 additions & 12 deletions docs/source/general/data.rst
@@ -29,7 +29,9 @@ Here are common terms you need to be familiar with:
The :class:`~flash.data.process.Preprocess` hooks cover from data-loading to model forwarding.
* - :class:`~flash.data.process.Postprocess`
- The :class:`~flash.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.data.process.Postprocess` hooks covers from model outputs to predictions export.
The :class:`~flash.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.data.process.Serializer`
- The :class:`~flash.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.data.process.Postprocess`) to the desired output format during prediction.

*******************************************
How to use out-of-the-box flashdatamodules
@@ -49,7 +51,9 @@ However, after model training, it requires a lot of engineering overhead to make
Usually, extra processing logic should be added to bridge the gap between training data and raw data.

The :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` classes can be used to
store the data as well as the preprocessing and postprocessing transforms.
store the data as well as the preprocessing and postprocessing transforms. The :class:`~flash.data.process.Serializer`
class provides the logic for converting :class:`~flash.data.process.Postprocess` outputs to the desired predict format
(e.g. classes, labels, probabilities, etc.).

By providing a series of hooks that can be overridden with custom data processing logic,
the user has much more granular control over their data processing flow.
@@ -122,7 +126,7 @@ Example::
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
preprocess_cls=CustomImageClassificationPreprocess
preprocess=CustomImageClassificationPreprocess(),
)


@@ -157,7 +161,7 @@ Example::
val_folder="./data/val",
test_folder="./data/test",
predict_folder="./data/predict",
preprocess=preprocess
preprocess=preprocess,
)

model = ImageClassifier(...)
@@ -190,6 +194,7 @@ Example::
**kwargs
):

# Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()

# {stage}_load_data_input will be given to your
@@ -291,6 +296,18 @@ ___________
:members:


----------

.. _serializer:

Serializer
___________


.. autoclass:: flash.data.process.Serializer
:members:


----------

.. _datapipeline:
@@ -414,16 +431,18 @@ Example::
predictions = lightning_module(data)


Postprocess
___________
Postprocess and Serializer
__________________________


Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`.
The Flash :class:`~flash.data.data_pipeline.DataPipeline` will behind the scenes execute the :class:`~flash.data.process.Postprocess` hooks.
Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
:class:`~flash.data.data_pipeline.DataPipeline` will execute the :class:`~flash.data.process.Postprocess` hooks and the
:class:`~flash.data.process.Serializer` behind the scenes.

First, the ``per_batch_transform`` hooks will be applied on the batch predictions.
Then the ``uncollate`` will split the batch into individual predictions.
Finally, the ``per_sample_transform`` will be applied on each prediction.
First, the :meth:`~flash.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.data.process.Postprocess.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
Finally, the :meth:`~flash.data.process.Serializer.serialize` method will be called to serialize the predictions.

.. note:: The transform can be applied either on device or ``CPU``.

@@ -438,7 +457,9 @@ Example::

samples = uncollate(batch)

return [per_sample_transform(sample) for sample in samples]
samples = [per_sample_transform(sample) for sample in samples]
# only if serializers are enabled.
return [serialize(sample) for sample in samples]

predictions = lightning_module(data)
return uncollate_fn(predictions)
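
The ordering described above (``per_batch_transform``, then ``uncollate``, then ``per_sample_transform``, then an optional ``serialize``) can be sketched as a small, self-contained function. This is only an illustration under stated assumptions: ``run_postprocess`` and ``argmax`` are hypothetical helpers, plain Python lists stand in for tensors, and nothing here is the actual ``DataPipeline`` implementation.

```python
from typing import Any, Callable, List, Optional


def run_postprocess(
    batch: Any,
    per_batch_transform: Callable[[Any], Any],
    uncollate: Callable[[Any], List[Any]],
    per_sample_transform: Callable[[Any], Any],
    serialize: Optional[Callable[[Any], Any]] = None,
) -> List[Any]:
    """Apply the hooks in the documented order (hypothetical helper)."""
    batch = per_batch_transform(batch)  # 1. whole-batch transform
    samples = uncollate(batch)  # 2. split the batch into samples
    samples = [per_sample_transform(s) for s in samples]  # 3. per-sample transform
    if serialize is None:  # 4. serialize only when a serializer is set
        return samples
    return [serialize(s) for s in samples]


def argmax(sample: List[float]) -> int:
    """Index of the largest score, used here as a toy 'classes' serializer."""
    return max(range(len(sample)), key=sample.__getitem__)


batch = [[0.1, 2.0], [3.0, -1.0]]  # two samples, two class scores each
raw = run_postprocess(batch, lambda b: b, list, lambda s: s)
classes = run_postprocess(batch, lambda b: b, list, lambda s: s, serialize=argmax)
print(raw)
print(classes)
```

Keeping serialization as the last, optional step means the same post-processed samples can be exported in different formats by swapping only the ``serialize`` callable.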
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -22,6 +22,7 @@ Lightning Flash
reference/task
reference/image_classification
reference/image_embedder
reference/multi_label_classification
reference/summarization
reference/text_classification
reference/tabular_classification
212 changes: 212 additions & 0 deletions docs/source/reference/multi_label_classification.rst
@@ -0,0 +1,212 @@

.. _multi_label_classification:

################################
Multi-label Image Classification
################################

********
The task
********
Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality. In this example, we will look at the task of trying to predict the movie genres from an image of the movie poster.

------

********
The data
********
The data we will use in this example is a subset of the awesome movie poster genre prediction data set from the paper "Movie Genre Classification based on Poster Images with Deep Neural Networks" by Wei-Ta Chu and Hung-Jui Guo, resized to 128 by 128.
Take a look at their paper (and please consider citing their paper if you use the data) here: `www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ <https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/>`_.

------

*********
Inference
*********

The :class:`~flash.vision.ImageClassifier` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.

We can use the :class:`~flash.vision.ImageClassifier` model (pretrained on our data) for inference on any image, given as a file path, using :func:`~flash.vision.ImageClassifier.predict`.
We can also add a simple visualisation by extending :class:`~flash.data.base_viz.BaseVisualization`, like this:

.. code-block:: python

    # import our libraries
    from typing import Any

    import torchvision.transforms.functional as T
    from torchvision.utils import make_grid

    from flash import Trainer
    from flash.data.base_viz import BaseVisualization
    from flash.data.utils import download_data
    from flash.vision import ImageClassificationData, ImageClassifier

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")

    # 2. Define our custom visualisation and datamodule
    class CustomViz(BaseVisualization):

        def show_per_batch_transform(self, batch: Any, _):
            images = batch[0]
            image = make_grid(images, nrow=2)
            image = T.to_pil_image(image, 'RGB')
            image.show()

    # 3. Load the model from a checkpoint
    model = ImageClassifier.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/image_classification_multi_label_model.pt",
    )

    # 4a. Predict the genres of a few movie posters!
    predictions = model.predict([
        "data/movie_posters/val/tt0361500.jpg",
        "data/movie_posters/val/tt0361748.jpg",
        "data/movie_posters/val/tt0362478.jpg",
    ])
    print(predictions)

    # 4b. Or generate predictions with a whole folder!
    datamodule = ImageClassificationData.from_folders(
        predict_folder="data/movie_posters/predict/",
        data_fetcher=CustomViz(),
        preprocess=model.preprocess,
    )
    predictions = Trainer().predict(model, datamodule=datamodule)
    print(predictions)

    # 5. Show some data!
    datamodule.show_predict_batch()

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********

Now let's look at how we can finetune a model on the movie poster data.
Once we download the data using :func:`~flash.data.utils.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.ImageClassificationData`.

.. note:: The dataset contains ``train`` and ``val`` folders, and each folder contains the images and a ``metadata.csv`` file which stores the labels.

.. code-block::

    movie_posters
    ├── train
    │   ├── metadata.csv
    │   ├── tt0084058.jpg
    │   ├── tt0084867.jpg
    │   ...
    └── val
        ├── metadata.csv
        ├── tt0200465.jpg
        ├── tt0326965.jpg
        ...

The ``metadata.csv`` files in each folder contain our labels, so we need to create a function (``load_data``) to extract the list of images and associated labels:

.. code-block:: python

    # import our libraries
    import os
    from typing import List, Tuple

    import pandas as pd
    import torch

    genres = [
        "Action", "Adventure", "Animation", "Biography", "Comedy", "Crime", "Documentary", "Drama", "Family",
        "Fantasy", "History", "Horror", "Music", "Musical", "Mystery", "N/A", "News", "Reality-TV", "Romance",
        "Sci-Fi", "Short", "Sport", "Thriller", "War", "Western"
    ]

    def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
        metadata = pd.read_csv(os.path.join(root, data, "metadata.csv"))

        images = []
        labels = []
        for _, row in metadata.iterrows():
            images.append(os.path.join(root, data, row['Id'] + ".jpg"))
            labels.append([int(row[genre]) for genre in genres])

        return images, labels

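
To make the output of ``load_data`` concrete, here is a small sketch of the same row-to-label conversion on an in-memory CSV. It assumes the layout implied by the function above (an ``Id`` column plus one 0/1 column per genre); the ids, the values, the shortened genre list, and the ``parse_metadata`` helper are all made up for illustration, and the standard-library ``csv`` module replaces pandas so the sketch has no dependencies.

```python
import csv
import io
from typing import List, Tuple

GENRES = ["Action", "Comedy", "Drama"]  # shortened, illustrative genre list

# Hypothetical metadata.csv contents; the ids and values are made up.
METADATA = """Id,Action,Comedy,Drama
tt0000001,1,0,1
tt0000002,0,1,0
"""


def parse_metadata(text: str, folder: str) -> Tuple[List[str], List[List[int]]]:
    """Return image paths and one multi-hot label list per row."""
    images: List[str] = []
    labels: List[List[int]] = []
    for row in csv.DictReader(io.StringIO(text)):
        images.append(f"{folder}/{row['Id']}.jpg")
        # One 0/1 entry per genre, in the fixed order given by GENRES.
        labels.append([int(row[genre]) for genre in GENRES])
    return images, labels


images, labels = parse_metadata(METADATA, "data/movie_posters/train")
print(images)
print(labels)
```

Each label list has the same length as the genre list, and several entries can be 1 at once, which is what distinguishes multi-label targets from a single class index.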
Our ``load_data`` function returns a list of image paths and the associated lists of labels. The :class:`~flash.vision.classification.data.ImageClassificationPreprocess` then handles loading and augmenting the images for us!
Now all we need is three lines of code to build and train our task!

.. note:: We need to set ``multi_label=True`` in both our :class:`~flash.Task` and our :class:`~flash.data.process.Serializer` to use a binary cross entropy loss and to process outputs correctly.
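
The sketch below illustrates why both settings matter. It is a hedged, pure-Python illustration and not the Flash implementation: ``sigmoid``, ``binary_cross_entropy`` and ``to_labels`` are hypothetical helpers, and the 0.5 threshold is an assumption chosen for the example. In the multi-label case each class gets an independent sigmoid probability, so the loss is computed per class and any number of labels can be returned; in the single-label case only the highest-scoring class is returned.

```python
import math
from typing import List, Sequence


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def binary_cross_entropy(logits: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross entropy over independent per-class probabilities."""
    total = 0.0
    for logit, target in zip(logits, targets):
        p = sigmoid(logit)
        total += -(target * math.log(p) + (1 - target) * math.log(1.0 - p))
    return total / len(logits)


def to_labels(
    logits: Sequence[float],
    names: Sequence[str],
    multi_label: bool,
    threshold: float = 0.5,  # assumed cut-off, for illustration only
) -> List[str]:
    if multi_label:
        # Keep every class whose independent probability clears the threshold.
        return [name for name, logit in zip(names, logits) if sigmoid(logit) > threshold]
    # Single-label: keep only the highest-scoring class.
    best = max(range(len(logits)), key=lambda i: logits[i])
    return [names[best]]


names = ["Action", "Comedy", "Drama"]
logits = [2.0, -3.0, 1.0]
print(to_labels(logits, names, multi_label=True))
print(to_labels(logits, names, multi_label=False))
print(binary_cross_entropy(logits, [1, 0, 1]))
```

With these example logits, the multi-label path keeps both classes with positive logits, while the single-label path keeps only the top one.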

.. code-block:: python

    import flash
    from flash.core.classification import Labels
    from flash.core.finetuning import FreezeUnfreeze
    from flash.data.utils import download_data
    from flash.vision import ImageClassificationData, ImageClassifier
    from flash.vision.classification.data import ImageClassificationPreprocess

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/")

    # 2. Load the data
    ImageClassificationPreprocess.image_size = (128, 128)

    train_filepaths, train_labels = load_data('train')
    val_filepaths, val_labels = load_data('val')
    test_filepaths, test_labels = load_data('test')

    datamodule = ImageClassificationData.from_filepaths(
        train_filepaths=train_filepaths,
        train_labels=train_labels,
        val_filepaths=val_filepaths,
        val_labels=val_labels,
        test_filepaths=test_filepaths,
        test_labels=test_labels,
        preprocess=ImageClassificationPreprocess(),
    )

    # 3. Build the model
    model = ImageClassifier(
        backbone="resnet18",
        num_classes=len(genres),
        multi_label=True,
    )

    # 4. Create the trainer.
    trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)

    # 5. Train the model
    trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

    # 6a. Predict what's on a few images!

    # Serialize predictions as labels.
    model.serializer = Labels(genres, multi_label=True)

    predictions = model.predict([
        "data/movie_posters/val/tt0361500.jpg",
        "data/movie_posters/val/tt0361748.jpg",
        "data/movie_posters/val/tt0362478.jpg",
    ])
    print(predictions)

    # 6b. Or generate predictions with a whole folder!
    datamodule = ImageClassificationData.from_folders(
        predict_folder="data/movie_posters/predict/",
        preprocess=model.preprocess,
    )
    predictions = trainer.predict(model, datamodule=datamodule)
    print(predictions)

    # 7. Save it!
    trainer.save_checkpoint("image_classification_multi_label_model.pt")

------

For more backbone options, see :ref:`image_classification`.