This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* image emb
* image emb
* add ImageEmbedder Pipeline
* clean pooling
* Update docs/source/reference/image_embedder.rst Co-authored-by: Jirka Borovec <[email protected]>
* Update docs/source/reference/image_embedder.rst Co-authored-by: Jirka Borovec <[email protected]>
* update on comments
* update on comments
* removing todo
* change to decorator
* add comment
* add `grad_enabled` to predict
* bolts break the doc
* update
* resolve doc tests
* remove grad enabled to False

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 48ec088, commit 008bc75
Showing 9 changed files with 435 additions and 31 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
|
||
.. _image_embedder: | ||
|
||
############## | ||
Image Embedder | ||
############## | ||
|
||
******** | ||
The task | ||
******** | ||
Image embedding encodes an image into a vector of image features which can be used for anything like clustering, similarity | ||
search or classification. | ||
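
As a rough illustration of the similarity-search use case mentioned above, the sketch below ranks a small gallery of vectors against a query by cosine similarity. The vectors are random stand-ins rather than real image embeddings, and the helper names `cosine_similarity` and `most_similar` are hypothetical and not part of Flash.

```python
import math
import random


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def most_similar(query, gallery):
    """Return the index of the gallery vector most similar to the query."""
    scores = [cosine_similarity(query, g) for g in gallery]
    return max(range(len(gallery)), key=scores.__getitem__)


# Stand-in "embeddings": random 8-dimensional vectors (assumption; real ones
# would come from an embedding model).
random.seed(0)
gallery = [[random.gauss(0, 1) for _ in range(8)] for _ in range(5)]

# The query is a rescaled copy of gallery item 3; cosine similarity ignores
# vector length, so item 3 should still be the best match.
query = [2.0 * v for v in gallery[3]]
best = most_similar(query, gallery)
```

Cosine similarity is only one possible choice here; Euclidean distance on normalized vectors gives the same ranking.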
|
||
------ | ||
|
||
********* | ||
Inference | ||
********* | ||
|
||
The :class:`~flash.vision.ImageEmbedder` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images. | |
|
||
Use the pretrained :class:`~flash.vision.ImageEmbedder` for inference on any image tensor or image path with :meth:`~flash.vision.ImageEmbedder.predict`: | |
|
||
.. code-block:: python | ||
from flash.vision import ImageEmbedder | |
# 1. Load the pretrained task | |
embedder = ImageEmbedder(backbone='resnet18') | |
# 2. Perform inference on an image file | |
embeddings = embedder.predict('path/to/image.png') | |
print(embeddings) | |
Or on a batch of random image tensors: | |
|
||
.. code-block:: python | ||
import torch | |
# 2. Perform inference on a batch of 32 random RGB images (224 x 224) | |
images = torch.rand(32, 3, 224, 224) | |
embeddings = embedder.predict(images) | |
print(embeddings) | |
For more advanced inference options, see :ref:`predictions`. | ||
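
The commit message mentions pooling. One common way to turn a backbone's feature map into a fixed-length embedding is global average pooling, sketched below in plain Python. This is an illustration under the assumption of a channels-first `(C, H, W)` layout, not a confirmed description of how `ImageEmbedder` pools internally; `global_average_pool` is a hypothetical helper.

```python
def global_average_pool(feature_map):
    """Average each channel's H x W grid down to a single number.

    feature_map is a nested list laid out as [channel][row][column].
    """
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Hypothetical feature map with 2 channels of size 2 x 2.
feature_map = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
]

# One value per channel, regardless of the spatial size of the input.
embedding = global_average_pool(feature_map)
```

Because the output has one entry per channel, the embedding length depends only on the backbone's channel count and not on the input image size.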
|
||
------ | ||
|
||
********** | ||
Finetuning | ||
********** | ||
To tailor this image embedder to your dataset, finetune first. | ||
|
||
.. code-block:: python | ||
import flash | |
from flash import download_data | |
from flash.vision import ImageClassificationData, ImageEmbedder | |
# 1. Download the data | |
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/') | |
# 2. Load the data | |
datamodule = ImageClassificationData.from_folders( | |
train_folder="data/hymenoptera_data/train/", | |
valid_folder="data/hymenoptera_data/val/", | |
test_folder="data/hymenoptera_data/test/", | |
) | |
# 3. Build the model | |
embedder = ImageEmbedder(backbone="resnet18") | |
# 4. Create the trainer. Run once on data | |
trainer = flash.Trainer(max_epochs=1) | |
# 5. Train the model | |
trainer.finetune(embedder, datamodule=datamodule) | |
# 6. Test the model | |
trainer.test() | |
# 7. Save it! | |
trainer.save_checkpoint("image_embedder_model.pt") | |
------ | ||
|
||
********************* | ||
Changing the backbone | ||
********************* | ||
By default, we use the encoder from `SwAV <https://arxiv.org/pdf/2006.09882.pdf>`_, pretrained on ImageNet with self-supervised learning (contrasting cluster assignments). You can change the model run by the task by passing in a different backbone. | |
|
||
.. note:: | ||
|
||
When changing the backbone, make sure you pass in the same backbone to the Task and the Data object! | ||
|
||
.. code-block:: python | ||
from flash.vision import ImageClassificationData, ImageEmbedder | |
# 1. Organize the data | |
data = ImageClassificationData.from_folders( | |
backbone="resnet34", | |
train_folder="data/hymenoptera_data/train/", | |
valid_folder="data/hymenoptera_data/val/" | |
) | |
# 2. Build the task with the same backbone as the data object | |
task = ImageEmbedder(backbone="resnet34") | |
The following backbones are available: | |
|
||
.. list-table:: Backbones | ||
:widths: 50 20 20 | ||
:header-rows: 1 | ||
|
||
* - backbone | ||
- dataset | ||
- training method | ||
* - resnet18 | ||
- Imagenet | ||
- supervised | ||
* - resnet34 | ||
- Imagenet | ||
- supervised | ||
* - resnet50 | ||
- Imagenet | ||
- supervised | ||
* - resnet101 | ||
- Imagenet | ||
- supervised | ||
* - resnet152 | ||
- Imagenet | ||
- supervised | ||
* - swav-imagenet | ||
- Imagenet | ||
- self-supervised (clustering) | ||
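
The table above can also be kept as a small lookup for checking a backbone name before building a task. This is a minimal sketch: `BACKBONES` and `describe_backbone` are hypothetical names for illustration and not part of the Flash API; the entries are copied from the table.

```python
# Backbone name -> (pretraining dataset, training method), as listed in the table.
BACKBONES = {
    "resnet18": ("Imagenet", "supervised"),
    "resnet34": ("Imagenet", "supervised"),
    "resnet50": ("Imagenet", "supervised"),
    "resnet101": ("Imagenet", "supervised"),
    "resnet152": ("Imagenet", "supervised"),
    "swav-imagenet": ("Imagenet", "self-supervised (clustering)"),
}


def describe_backbone(name):
    """Return a one-line description, or raise ValueError for unknown names."""
    try:
        dataset, method = BACKBONES[name]
    except KeyError:
        raise ValueError(
            f"Unknown backbone {name!r}; choose one of {sorted(BACKBONES)}"
        ) from None
    return f"{name}: pretrained on {dataset} ({method})"


summary = describe_backbone("resnet34")
```

Failing early on a misspelled name is cheaper than discovering the mismatch between the task and the data object during training.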
|
||
------ | ||
|
||
************* | ||
API reference | ||
************* | ||
|
||
.. _image_embedder_class: | ||
|
||
ImageEmbedder | ||
--------------- | ||
|
||
.. autoclass:: flash.vision.ImageEmbedder | ||
:members: | ||
:exclude-members: forward | ||
|
||
.. _image_embedder_data: |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from flash.vision.classification import ImageClassificationData, ImageClassifier | ||
from flash.vision.detection import ImageDetector | ||
from flash.vision.embedding import ImageEmbedder |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from flash.vision.embedding.image_embedder_model import ImageEmbedder |