Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Cluster #36

Merged
merged 20 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Lightning Flash

reference/task
reference/image_classification
reference/image_embedder
reference/text_classification
reference/tabular_classification

Expand Down
60 changes: 30 additions & 30 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,39 @@ Use the :class:`~flash.text.ImageClassificatier` pretrained model for inference

.. code-block:: python

# import our libraries
from flash.vision import ImageClassifier

# 1. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Perform inference from list of sequences
# NOTE(review): these inputs are text sentences, apparently carried over from the
# text classification docs -- presumably they should be image paths; confirm.
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    # Each item needs its own trailing comma; without it adjacent string
    # literals are silently concatenated into a single list element.
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid",
    "This guy has done a great job with this movie!",
])
print(predictions)
# import our libraries
from flash.vision import ImageClassifier

# 1. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Perform inference from list of sequences
# NOTE(review): these inputs are text sentences, apparently carried over from the
# text classification docs -- presumably they should be image paths; confirm.
predictions = model.predict([
    "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
    "The worst movie in the history of cinema.",
    # Each item needs its own trailing comma; without it adjacent string
    # literals are silently concatenated into a single list element.
    "I come from Bulgaria where it 's almost impossible to have a tornado.",
    "Very, very afraid",
    "This guy has done a great job with this movie!",
])
print(predictions)

Or on a given dataset:

.. code-block:: python

# import our libraries
from flash import download_data
from flash.text import TextClassifier
# import our libraries
from flash import download_data
from flash.text import TextClassifier

# 1. Download dataset, save it under 'data' dir
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 1. Download dataset, save it under 'data' dir
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 2. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Perform inference from a csv file
predictions = model.predict("data/imdb/test.csv")
print(predictions)
# 3. Perform inference from a csv file
predictions = model.predict("data/imdb/test.csv")
print(predictions)

For more advanced inference options, see :ref:`predictions`.

Expand Down Expand Up @@ -97,16 +97,16 @@ Now all we need is three lines of code to build to train our task!

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
backbone="resnet18",
backbone="resnet18",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
Expand Down Expand Up @@ -140,7 +140,7 @@ By default, we use a `ResNet-18 <https://arxiv.org/abs/1512.03385>`_ for image c

# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)
Expand Down
154 changes: 154 additions & 0 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@

.. _image_embedder:

##############
Image Embedder
##############

********
The task
********
Image embedding encodes an image into a vector of image features that can be used for tasks such as clustering, similarity
search, or classification.

------

*********
Inference
*********

The :class:`~flash.vision.ImageEmbedder` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.

Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python
tchaton marked this conversation as resolved.
Show resolved Hide resolved

from flash.vision import ImageEmbedder

# 1. Load the pretrained task
embedder = ImageEmbedder(backbone='swav-imagenet')

# 2. Perform inference on an image file
# (call predict on the `embedder` built above and print the value it returned)
embeddings = embedder.predict('path/to/image.png')
print(embeddings)

Or on an image tensor:

.. code-block:: python

import torch
from flash.vision import ImageEmbedder

# 1. Load the pretrained task
embedder = ImageEmbedder(backbone='swav-imagenet')

# 2. Perform inference on an image tensor
images = torch.rand(32, 3, 224, 224)
# (call predict on the `embedder` built above and print the value it returned)
embeddings = embedder.predict(images)
print(embeddings)

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********
To tailor this image embedder to your dataset, finetune first.

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
    backbone="resnet18",
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/",
    test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
embedder = ImageEmbedder()

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model (pass the embedder built in step 3)
trainer.finetune(embedder, datamodule=datamodule, unfreeze_milestones=(0, 1))

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_embedder_model.pt")
tchaton marked this conversation as resolved.
Show resolved Hide resolved

------

*********************
Changing the backbone
*********************
By default, we use the encoder from `Swav <https://arxiv.org/pdf/2006.09882.pdf>`_ pretrained on Imagenet via contrastive learning. You can change the model run by the task by passing in a different backbone.

.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
tchaton marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

# 1. organize the data
data = ImageClassificationData.from_folders(
    backbone="resnet34",
    train_folder="data/hymenoptera_data/train/",
    valid_folder="data/hymenoptera_data/val/"
)

# 2. build the task (same backbone name as the data object above)
task = ImageEmbedder(backbone="resnet34")

Backbones available

.. list-table:: Backbones
:widths: 50 20 20
:header-rows: 1

* - backbone
- dataset
- training method
* - resnet18
- Imagenet
- supervised
* - resnet34
- Imagenet
- supervised
* - resnet50
- Imagenet
- supervised
* - resnet101
- Imagenet
- supervised
* - resnet152
- Imagenet
- supervised
* - swav-imagenet
- Imagenet
- self-supervised (clustering)

------

*************
API reference
*************

.. _image_embedder_class:

ImageEmbedder
---------------

.. autoclass:: flash.vision.ImageEmbedder
:members:
:exclude-members: forward

.. _image_embedder_data:
10 changes: 6 additions & 4 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,12 @@ def predict(
The post-processed model predictions

"""
data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
predictions = self.forward(batch_x)
self.eval()
with torch.no_grad():
tchaton marked this conversation as resolved.
Show resolved Hide resolved
data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
predictions = self.forward(batch_x)
return data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x

def configure_optimizers(self) -> torch.optim.Optimizer:
Expand Down
37 changes: 37 additions & 0 deletions flash/model_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

# Default checkpoint locations, keyed by the name of the model class that loads them.
DEFAULT_URLS = {
"SimCLR": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt',
"SwAV": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar',
}


def load_model(name):
    """Build the model config for the pretrained backbone called ``name``.

    Args:
        name: Backbone identifier, either ``'simclr-imagenet'`` or ``'swav-imagenet'``.

    Returns:
        The config dict returned by the matching loader function.

    Raises:
        MisconfigurationException: If ``name`` is not one of the two supported identifiers.
    """
    if name == 'simclr-imagenet':
        return load_simclr_imagenet()

    if name == 'swav-imagenet':
        return load_swav_imagenet()

    # Nothing matched: report the supported identifiers to the caller.
    raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")


def load_simclr_imagenet(path_or_url=None):
    """Load a SimCLR checkpoint and return its encoder as a model config.

    Args:
        path_or_url: Checkpoint location passed to ``SimCLR.load_from_checkpoint``.
            When ``None`` (the default), ``DEFAULT_URLS["SimCLR"]`` is used, so
            existing zero-argument callers behave as before.

    Returns:
        A dict holding the loaded encoder under ``'model'`` and the value 2048
        under ``'emb_size'``.
    """
    if path_or_url is None:
        path_or_url = DEFAULT_URLS["SimCLR"]

    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)

    # NOTE(review): the SwAV loader reports its size under 'num_features' while
    # this one uses 'emb_size' -- confirm which key callers actually read.
    return {'model': simclr.encoder, 'emb_size': 2048}


def load_swav_imagenet(path_or_url=None):
    """Load a SwAV checkpoint and return its backbone as a model config.

    Args:
        path_or_url: Checkpoint location passed to ``SwAV.load_from_checkpoint``.
            When ``None`` (the default), ``DEFAULT_URLS["SwAV"]`` is used, so
            existing zero-argument callers behave as before.

    Returns:
        A dict holding the loaded ``swav.model`` under ``'model'`` and the value
        3000 under ``'num_features'``.
    """
    if path_or_url is None:
        path_or_url = DEFAULT_URLS["SwAV"]

    swav = SwAV.load_from_checkpoint(path_or_url, strict=True)

    # NOTE(review): the SimCLR loader reports its size under 'emb_size' while
    # this one uses 'num_features' -- confirm which key callers actually read.
    return {'model': swav.model, 'num_features': 3000}


# Lookup from backbone name to the loader function that returns its model config.
models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ImageDetector
from flash.vision.embedding import ImageEmbedder
1 change: 1 addition & 0 deletions flash/vision/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.vision.embedding.image_embedder_model import ImageEmbedder
Loading