Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor preprocess_cls to preprocess, add Serializer, add DataPipelineState #229

Merged
merged 55 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
512ba6d
Initial commit
ethanwharris Apr 15, 2021
e3de28f
Initial commit
ethanwharris Apr 19, 2021
69074f7
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 19, 2021
4c58b35
Small fixes
ethanwharris Apr 19, 2021
1875474
Small fixes
ethanwharris Apr 19, 2021
b7559ca
Fix a small bug
ethanwharris Apr 20, 2021
641471e
Update docs
ethanwharris Apr 20, 2021
cfe047e
Update notebook
ethanwharris Apr 20, 2021
3274695
Update finetuning image classification
ethanwharris Apr 20, 2021
55d9e9e
Updates
ethanwharris Apr 20, 2021
78cdda6
Update docs and serializer mapping
ethanwharris Apr 20, 2021
5a6cb52
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 20, 2021
228a2e1
Fix missed merge conflicts
ethanwharris Apr 20, 2021
9e3fa8d
Fix some broken tests
ethanwharris Apr 20, 2021
9a6e59b
Fix a test
ethanwharris Apr 20, 2021
305e8d9
Fix a test
ethanwharris Apr 20, 2021
f06aeac
Fix some tests
ethanwharris Apr 20, 2021
9d591d3
Fix a test
ethanwharris Apr 20, 2021
eaf8742
Update examples
ethanwharris Apr 20, 2021
79409b7
Update examples
ethanwharris Apr 20, 2021
a6ac0b7
Pre-commit
ethanwharris Apr 20, 2021
6ad96df
Pre-commit
ethanwharris Apr 20, 2021
0c870cf
Update text classification
ethanwharris Apr 20, 2021
2a84f56
Add a test
ethanwharris Apr 20, 2021
bd176a3
Multi-label example initial commit
ethanwharris Apr 20, 2021
4cc3141
Add predict example for multi_label
ethanwharris Apr 20, 2021
d379151
Remove unused imports
ethanwharris Apr 20, 2021
e70ad82
Update predict example
ethanwharris Apr 20, 2021
10b5799
Update examples
ethanwharris Apr 20, 2021
4967ded
Add multi-label Labels suport
ethanwharris Apr 20, 2021
a408f60
Update test_classification
ethanwharris Apr 21, 2021
2b15226
Update .gitignore
ethanwharris Apr 21, 2021
7f65dcc
Add some tests
ethanwharris Apr 21, 2021
777db40
Fix broken test
ethanwharris Apr 21, 2021
86cd3d7
Update test_process
ethanwharris Apr 21, 2021
86d8a7f
Some docs updates
ethanwharris Apr 21, 2021
3f89cd3
Update docs
ethanwharris Apr 21, 2021
9672e55
Fix some tests
ethanwharris Apr 21, 2021
59e8374
Add back some process_cls
ethanwharris Apr 21, 2021
f9681a8
Add types
ethanwharris Apr 21, 2021
443a9c9
Update docs/source/general/data.rst
ethanwharris Apr 21, 2021
0ca9ec8
Update flash/data/process.py
ethanwharris Apr 21, 2021
245cc5c
Add comment
ethanwharris Apr 21, 2021
7dd16c5
Update example
ethanwharris Apr 21, 2021
fc52a68
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 21, 2021
9d24f98
Remove state checkpoint not needed
ethanwharris Apr 22, 2021
4c3f34b
Merge branch 'feature/multiple_return_types' of https://github.com/Py…
ethanwharris Apr 22, 2021
0ffc412
Fix doctest
ethanwharris Apr 22, 2021
3644764
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 22, 2021
b9a4f92
Update image_classification example
ethanwharris Apr 22, 2021
c54c895
Merge branch 'master' into feature/multiple_return_types
ethanwharris Apr 22, 2021
686823b
Update following fix
ethanwharris Apr 22, 2021
a66a40a
Fix num-workers in test_examples
ethanwharris Apr 22, 2021
3bbfb16
Update example predict
ethanwharris Apr 22, 2021
b1310e5
Better fix for windows error
ethanwharris Apr 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.DS_Store
.lock
lightning_logs

Expand Down Expand Up @@ -149,3 +150,4 @@ xsum
coco128
wmt_en_ro
kinetics
movie_posters
12 changes: 6 additions & 6 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ We will define a custom ``NumpyDataModule`` class subclassing :class:`~flash.dat
This ``NumpyDataModule`` class will provide a ``from_xy_dataset`` helper ``classmethod`` to instantiate
:class:`~flash.data.data_module.DataModule` from x, y numpy arrays.

Here is how it would look like:
Here is how it would look:

Example::

x, y = ...
preprocess_cls = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess_cls)
preprocess = ...
datamodule = NumpyDataModule.from_xy_dataset(x, y, preprocess)

Here is the ``NumpyDataModule`` implementation:

Expand All @@ -140,12 +140,12 @@ Example::
cls,
x: ND,
y: ND,
preprocess_cls: Preprocess = NumpyPreprocess,
preprocess: Preprocess = None,
batch_size: int = 64,
num_workers: int = 0
):

preprocess = preprocess_cls()
preprocess = preprocess or NumpyPreprocess()

x_train, x_test, y_train, y_test = train_test_split(
x, y, test_size=.20, random_state=0)
Expand Down Expand Up @@ -180,7 +180,7 @@ It allows the user much more granular control over their data processing flow.

.. note::

Why introducing :class:`~flash.data.process.Preprocess` ?
Why introduce :class:`~flash.data.process.Preprocess` ?

The :class:`~flash.data.process.Preprocess` object reduces the engineering overhead to make inference on raw data or
to deploy the model in production environnement compared to traditional
Expand Down
45 changes: 33 additions & 12 deletions docs/source/general/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ Here are common terms you need to be familiar with:
The :class:`~flash.data.process.Preprocess` hooks covers from data-loading to model forwarding.
* - :class:`~flash.data.process.Postprocess`
- The :class:`~flash.data.process.Postprocess` provides a simple hook-based API to encapsulate your post-processing logic.
The :class:`~flash.data.process.Postprocess` hooks covers from model outputs to predictions export.
The :class:`~flash.data.process.Postprocess` hooks cover from model outputs to predictions export.
* - :class:`~flash.data.process.Serializer`
- The :class:`~flash.data.process.Serializer` provides a single ``serialize`` method that is used to convert model outputs (after the :class:`~flash.data.process.Postprocess`) to the desired output format during prediction.

*******************************************
How to use out-of-the-box flashdatamodules
Expand All @@ -49,7 +51,9 @@ However, after model training, it requires a lot of engineering overhead to make
Usually, extra processing logic should be added to bridge the gap between training data and raw data.

The :class:`~flash.data.process.Preprocess` and :class:`~flash.data.process.Postprocess` classes can be used to
store the data as well as the preprocessing and postprocessing transforms.
store the data as well as the preprocessing and postprocessing transforms. The :class:`~flash.data.process.Serializer`
class provides the logic for converting :class:`~flash.data.process.Postprocess` outputs to the desired predict format
(e.g. classes, labels, probabilites, etc.).

By providing a series of hooks that can be overridden with custom data processing logic,
the user has much more granular control over their data processing flow.
Expand Down Expand Up @@ -122,7 +126,7 @@ Example::
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
preprocess_cls=CustomImageClassificationPreprocess
preprocess=CustomImageClassificationPreprocess(),
)


Expand Down Expand Up @@ -157,7 +161,7 @@ Example::
val_folder="./data/val",
test_folder="./data/test",
predict_folder="./data/predict",
preprocess=preprocess
preprocess=preprocess,
)

model = ImageClassifier(...)
Expand Down Expand Up @@ -190,6 +194,7 @@ Example::
**kwargs
):

# Set a custom ``Preprocess`` if none was provided
preprocess = preprocess or cls.preprocess_cls()

# {stage}_load_data_input will be given to your
Expand Down Expand Up @@ -291,6 +296,18 @@ ___________
:members:


----------

.. _serializer:

Serializer
___________


.. autoclass:: flash.data.process.Serializer
:members:


----------

.. _datapipeline:
Expand Down Expand Up @@ -414,16 +431,18 @@ Example::
predictions = lightning_module(data)


Postprocess
___________
Postprocess and Serializer
__________________________


Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`.
The Flash :class:`~flash.data.data_pipeline.DataPipeline` will behind the scenes execute the :class:`~flash.data.process.Postprocess` hooks.
Once the predictions have been generated by the Flash :class:`~flash.core.model.Task`, the Flash
:class:`~flash.data.data_pipeline.DataPipeline` will execute the :class:`~flash.data.process.Postprocess` hooks and the
:class:`~flash.data.process.Serializer` behind the scenes.

First, the ``per_batch_transform`` hooks will be applied on the batch predictions.
Then the ``uncollate`` will split the batch into individual predictions.
Finally, the ``per_sample_transform`` will be applied on each prediction.
First, the :meth:`~flash.data.process.Postprocess.per_batch_transform` hooks will be applied on the batch predictions.
Then, the :meth:`~flash.data.process.Postprocess.uncollate` will split the batch into individual predictions.
Next, the :meth:`~flash.data.process.Postprocess.per_sample_transform` will be applied on each prediction.
Finally, the :meth:`~flash.data.process.Serializer.serialize` method will be called to serialize the predictions.

.. note:: The transform can be applied either on device or ``CPU``.

Expand All @@ -438,7 +457,9 @@ Example::

samples = uncollate(batch)

return [per_sample_transform(sample) for sample in samples]
samples = [per_sample_transform(sample) for sample in samples]
# only if serializers are enabled.
return [serialize(sample) for sample in samples]

predictions = lightning_module(data)
return uncollate_fn(predictions)
4 changes: 1 addition & 3 deletions docs/source/general/predictions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Predict on a single sample of data

You can pass in a sample of data (image file path, a string of text, etc) to the :func:`~flash.core.model.Task.predict` method.


.. code-block:: python

from flash import Trainer
Expand Down Expand Up @@ -51,5 +51,3 @@ Predict on a csv file
# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)


4 changes: 2 additions & 2 deletions docs/source/general/training.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such

# train on 1 GPU
flash.Trainer(gpus=1)

* Training on multiple GPUs

.. code-block:: python
Expand All @@ -60,7 +60,7 @@ Flash tasks supports many advanced training functionalities out-of-the-box, such

# train on gpu 1, 3, 5 (3 gpus total)
flash.Trainer(gpus=[1, 3, 5])

* Using mixed precision training

.. code-block:: python
Expand Down
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Lightning Flash
reference/task
reference/image_classification
reference/image_embedder
reference/multi_label_classification
reference/summarization
reference/text_classification
reference/tabular_classification
Expand Down
4 changes: 2 additions & 2 deletions docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ For getting started with Deep Learning

Easy to learn
^^^^^^^^^^^^^
If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!
If you are just getting started with deep learning, Flash offers common deep learning tasks you can use out-of-the-box in a few lines of code, no math, fancy nn.Modules or research experience required!

Easy to scale
^^^^^^^^^^^^^
Expand Down Expand Up @@ -70,7 +70,7 @@ You can install flash using pip or conda:
Tasks
=====

Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.
Flash is comprised of a collection of Tasks. The Flash tasks are laser-focused objects designed to solve a well-defined type of problem, using state-of-the-art methods.

The Flash tasks contain all the relevant information to solve the task at hand- the number of class labels you want to predict, number of columns in your dataset, as well as details on the model architecture used such as loss function, optimizers, etc.

Expand Down
Loading