Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Docs cleanup and migrate to testcode (#293)
Browse files Browse the repository at this point in the history
* Migrate to testcode

* Update

* Updates

* Fixes

* Updates

* Updates

* Updates

* Updates

* Add finetuning

* Updates

* Updates

* Update training.rst

* small fix

* small fix

* Updates

* Updates

* Updates

* Updates

* Updates

* Fixes

* Update object detection docs

* Updates

* Updates

* Add video docs

* Fix doctest

* fixes

* Fixes

* Fix

* Update
  • Loading branch information
ethanwharris authored May 13, 2021
1 parent 7d8d159 commit 1f50b3f
Show file tree
Hide file tree
Showing 24 changed files with 467 additions and 1,106 deletions.
76 changes: 76 additions & 0 deletions docs/source/common/finetuning_example.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
To use a Task for finetuning:

1. Load your data and organize it using a DataModule customized for the task (example: :class:`~flash.vision.ImageClassificationData`).
2. Choose and initialize your Task which has state-of-the-art backbones built in (example: :class:`~flash.vision.ImageClassifier`).
3. Init a :class:`flash.core.trainer.Trainer`.
4. Choose a finetune strategy (example: "freeze") and call :func:`flash.core.trainer.Trainer.finetune` with your data.
5. Save your finetuned model.

|
Here's an example of finetuning.

.. testcode:: finetune

from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1)

# 4. Finetune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

.. testoutput:: finetune
:hide:

...

Using a finetuned model
-----------------------
Once you've finetuned, use the model to predict:

.. testcode:: finetune

# Serialize predictions as labels, automatically inferred from the training data in part 2.
model.serializer = Labels()

predictions = model.predict(["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg", "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg"])
print(predictions)

.. testoutput:: finetune

['bees', 'ants']

Or you can use the saved model for prediction anywhere you want!

.. code-block:: python
from flash.vision import ImageClassifier
# load finetuned checkpoint
model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
predictions = model.predict('path/to/your/own/image.png')
19 changes: 19 additions & 0 deletions docs/source/common/image_backbones.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
Available backbones:

* resnet18 (default)
* resnet34
* resnet50
* resnet101
* resnet152
* resnext50_32x4d
* resnext101_32x8d
* mobilenet_v2
* vgg11
* vgg13
* vgg16
* vgg19
* densenet121
* densenet169
* densenet161
* swav-imagenet
* `TIMM <https://rwightman.github.io/pytorch-image-models/>`_ (130+ PyTorch Image Models)
9 changes: 9 additions & 0 deletions docs/source/common/object_detection_backbones.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Available backbones:

* resnet18
* resnet34
* resnet50
* resnet101
* resnet152
* resnext50_32x4d
* resnext101_32x8d
49 changes: 49 additions & 0 deletions docs/source/common/training_example.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
To train a task from scratch:

1. Load your data and organize it using a DataModule customized for the task (example: :class:`~flash.vision.ImageClassificationData`).
2. Choose and initialize your Task (setting ``pretrained=False``) which has state-of-the-art backbones built in (example: :class:`~flash.vision.ImageClassifier`).
3. Init a :class:`flash.core.trainer.Trainer` or a :class:`pytorch_lightning.trainer.Trainer`.
4. Call :func:`flash.core.trainer.Trainer.fit` with your data set.
5. Save your trained model.

|
Here's an example:

.. testcode:: training

from pytorch_lightning import seed_everything

import flash
from flash.core.classification import Labels
from flash.data.utils import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# set the random seeds.
seed_everything(42)

# 1. Download and organize the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1)

# 4. Train the model
trainer.fit(model, datamodule=datamodule)

# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")

.. testoutput:: training
:hide:

...
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
"torch": ("https://pytorch.org/docs/stable/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"PIL": ("https://pillow.readthedocs.io/en/stable/", None),
"pytorch_lightning": ("https://pytorch-lightning.readthedocs.io/en/stable/", None),
}

# -- Options for HTML output -------------------------------------------------
Expand Down
40 changes: 25 additions & 15 deletions docs/source/custom_task.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ which is stored as numpy arrays.
1. Imports
----------

We first import everything we're going to use and set the random seed using :func:`~pytorch_lightning.utilities.seed.seed_everything`.

.. code:: python
.. testcode:: custom_task

from typing import Any, Callable, Dict, List, Optional, Tuple

Expand Down Expand Up @@ -50,13 +51,13 @@ you will likely need to override the ``__init__``, ``forward``, and the ``{train
It's best practice in flash for the data to be provide as a dictionary which maps string keys to their values. The
``{train,val,test,predict}_step`` methods need to be overridden to extract the data from the input dictionary.

Example::
.. testcode:: custom_task

class RegressionTask(flash.Task):

def __init__(self, num_inputs, learning_rate=0.2, metrics=None):
# what kind of model do we want?
model = nn.Linear(num_inputs, 1)
model = torch.nn.Linear(num_inputs, 1)

# what loss function do we want?
loss_fn = torch.nn.functional.mse_loss
Expand Down Expand Up @@ -119,9 +120,9 @@ is called for training, validation, and testing) or override ``training_step``,
individually. These methods behave identically to PyTorch Lightning’s
`methods <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#methods>`__.

Here is the pseudo code behind :class:`~flash.core.model.Task` step.
Here is the pseudo code behind :class:`~flash.core.model.Task` step:

Example::
.. code:: python
def step(self, batch: Any, batch_idx: int) -> Any:
"""
Expand All @@ -144,7 +145,7 @@ loading the train data (``if self.training:``), the ``NumpyDataSource`` sets the
optional ``dataset`` argument. Any attributes that are set on the optional ``dataset`` argument will also be set on the
generated ``dataset``.

Example::
.. testcode:: custom_task

class NumpyDataSource(DataSource[Tuple[ND, ND]]):

Expand Down Expand Up @@ -186,7 +187,7 @@ The recommended way to define a custom :class:`~flash.data.process.Preprocess` i
- Override the ``{train,val,test,predict}_default_transforms`` methods to specify the default transforms to use in each stage (these will be used if the transforms passed in the ``__init__`` are ``None``).
- Transforms are given as a mapping from hook name to callable transforms. You should use :class:`~flash.data.transforms.ApplyToKeys` to apply each transform only to specific keys in the data dictionary.

Example::
.. testcode:: custom_task

class NumpyPreprocess(Preprocess):

Expand Down Expand Up @@ -253,7 +254,7 @@ data source whose name is in :class:`~flash.data.data_source.DefaultDataSources`
``DataModule.from_*`` method that provides the expected inputs. So in this case, there is the
:meth:`~flash.data.data_module.DataModule.from_numpy` that will use our numpy data source.

Example::
.. testcode:: custom_task

class NumpyDataModule(flash.DataModule):

Expand All @@ -273,25 +274,29 @@ dataset <https://scikit-learn.org/stable/datasets/toy_dataset.html#diabetes-data
Like any Flash Task, we can fit our model using the ``flash.Trainer`` by
supplying the task itself, and the associated data:

.. code:: python
.. testcode:: custom_task

x, y = datasets.load_diabetes(return_X_y=True)
datamodule = NumpyDataModule.from_numpy(x, y)
model = RegressionTask(num_inputs=datamodule.train_dataset.num_inputs)

trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20)
trainer = flash.Trainer(max_epochs=20, progress_bar_refresh_rate=20, checkpoint_callback=False)
trainer.fit(model, datamodule=datamodule)

.. testoutput:: custom_task
:hide:

...


5. Predicting
-------------

With a trained model we can now perform inference. Here we will use a
few examples from the test set of our data:
With a trained model we can now perform inference. Here we will use a few examples from the test set of our data:

.. code:: python
.. testcode:: custom_task

predict_data = torch.tensor([
predict_data = np.array([
[ 0.0199, 0.0507, 0.1048, 0.0701, -0.0360, -0.0267, -0.0250, -0.0026, 0.0037, 0.0403],
[-0.0128, -0.0446, 0.0606, 0.0529, 0.0480, 0.0294, -0.0176, 0.0343, 0.0702, 0.0072],
[ 0.0381, 0.0507, 0.0089, 0.0425, -0.0428, -0.0210, -0.0397, -0.0026, -0.0181, 0.0072],
Expand All @@ -301,4 +306,9 @@ few examples from the test set of our data:

predictions = model.predict(predict_data)
print(predictions)
# out: [tensor([188.9760]), tensor([196.1777]), tensor([161.3590]), tensor([130.7312]), tensor([149.0340])]

We get the following output:

.. testoutput:: custom_task

[tensor([189.1198]), tensor([196.0839]), tensor([161.2461]), tensor([130.7591]), tensor([149.1780])]
2 changes: 1 addition & 1 deletion docs/source/general/callback.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Callback
Flash Callback
**************

:class:`~flash.data.callback.FlashCallback` is an extension of the PyTorch Lightning `Callback <https://pytorch-lightning.readthedocs.io/en/latest/extensions/callbacks.html>`__.
:class:`~flash.data.callback.FlashCallback` is an extension of :class:`pytorch_lightning.callbacks.Callback`.

A callback is a self-contained program that can be reused across projects.

Expand Down
Loading

0 comments on commit 1f50b3f

Please sign in to comment.