Docs cleanup and migrate to testcode (#293)
* Migrate to testcode * Update * Updates * Fixes * Updates * Updates * Updates * Updates * Add finetuning * Updates * Updates * Update training.rst * small fix * small fix * Updates * Updates * Updates * Updates * Updates * Fixes * Update object detection docs * Updates * Updates * Add video docs * Fix doctest * fixes * Fixes * Fix * Update
1 parent 7d8d159 · commit 1f50b3f
Showing 24 changed files with 467 additions and 1,106 deletions.
@@ -0,0 +1,76 @@

To use a Task for finetuning:

1. Load your data and organize it using a DataModule customized for the task (example: :class:`~flash.vision.ImageClassificationData`).
2. Choose and initialize your Task, which has state-of-the-art backbones built in (example: :class:`~flash.vision.ImageClassifier`).
3. Init a :class:`flash.core.trainer.Trainer`.
4. Choose a finetune strategy (example: "freeze") and call :func:`flash.core.trainer.Trainer.finetune` with your data.
5. Save your finetuned model.

|

Here's an example of finetuning:

.. testcode:: finetune

    from pytorch_lightning import seed_everything

    import flash
    from flash.core.classification import Labels
    from flash.data.utils import download_data
    from flash.vision import ImageClassificationData, ImageClassifier

    # set the random seeds
    seed_everything(42)

    # 1. Download and organize the data
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 2. Build the model using the desired Task
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

    # 3. Create the trainer (run one epoch for demo)
    trainer = flash.Trainer(max_epochs=1)

    # 4. Finetune the model
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")

    # 5. Save the model!
    trainer.save_checkpoint("image_classification_model.pt")

.. testoutput:: finetune
    :hide:

    ...

Using a finetuned model
-----------------------

Once you've finetuned, use the model to predict:

.. testcode:: finetune

    # Serialize predictions as labels, automatically inferred from the training data in part 2.
    model.serializer = Labels()

    predictions = model.predict([
        "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
        "data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
    ])
    print(predictions)

.. testoutput:: finetune

    ['bees', 'ants']

Or you can use the saved model for prediction anywhere you want!

.. code-block:: python

    from flash.vision import ImageClassifier

    # load the finetuned checkpoint
    model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

    predictions = model.predict("path/to/your/own/image.png")
@@ -0,0 +1,19 @@

Available backbones:

* resnet18 (default)
* resnet34
* resnet50
* resnet101
* resnet152
* resnext50_32x4d
* resnext101_32x8d
* mobilenet_v2
* vgg11
* vgg13
* vgg16
* vgg19
* densenet121
* densenet169
* densenet161
* swav-imagenet
* `TIMM <https://rwightman.github.io/pytorch-image-models/>`_ (130+ PyTorch Image Models)
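
Any of the names above can be passed to the Task's ``backbone`` argument. A minimal sketch, assuming this list belongs to :class:`~flash.vision.ImageClassifier` and reusing the ``backbone`` and ``num_classes`` arguments from the finetuning example above:

.. code-block:: python

    from flash.vision import ImageClassifier

    # swap the default resnet18 for another backbone from the list above
    model = ImageClassifier(backbone="resnext50_32x4d", num_classes=2)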
@@ -0,0 +1,9 @@

Available backbones:

* resnet18
* resnet34
* resnet50
* resnet101
* resnet152
* resnext50_32x4d
* resnext101_32x8d
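
A minimal sketch of selecting one of these backbones. The Task below is an assumption: this resnet-only list is presumed to belong to the object detection docs, and :class:`~flash.vision.ObjectDetector` is presumed to accept a ``backbone`` argument in the same way as the classifier above.

.. code-block:: python

    from flash.vision import ObjectDetector

    # hypothetical usage: pick one of the resnet-family backbones listed above
    model = ObjectDetector(num_classes=2, backbone="resnet50")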
@@ -0,0 +1,49 @@

To train a task from scratch:

1. Load your data and organize it using a DataModule customized for the task (example: :class:`~flash.vision.ImageClassificationData`).
2. Choose and initialize your Task (setting ``pretrained=False``), which has state-of-the-art backbones built in (example: :class:`~flash.vision.ImageClassifier`).
3. Init a :class:`flash.core.trainer.Trainer` or a :class:`pytorch_lightning.trainer.Trainer`.
4. Call :func:`flash.core.trainer.Trainer.fit` with your dataset.
5. Save your trained model.

|

Here's an example:

.. testcode:: training

    from pytorch_lightning import seed_everything

    import flash
    from flash.core.classification import Labels
    from flash.data.utils import download_data
    from flash.vision import ImageClassificationData, ImageClassifier

    # set the random seeds
    seed_everything(42)

    # 1. Download and organize the data
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        val_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 2. Build the model using the desired Task
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, pretrained=False)

    # 3. Create the trainer (run one epoch for demo)
    trainer = flash.Trainer(max_epochs=1)

    # 4. Train the model
    trainer.fit(model, datamodule=datamodule)

    # 5. Save the model!
    trainer.save_checkpoint("image_classification_model.pt")

.. testoutput:: training
    :hide:

    ...