This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Cluster #36

Merged (20 commits) on Feb 1, 2021
Changes from 8 commits
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -21,6 +21,7 @@ Lightning Flash

reference/task
reference/image_classification
reference/image_embedder
reference/text_classification
reference/tabular_classification

60 changes: 30 additions & 30 deletions docs/source/reference/image_classification.rst
@@ -23,39 +23,39 @@ Use the :class:`~flash.vision.ImageClassifier` pretrained model for inference

.. code-block:: python

# import our libraries
from flash.vision import ImageClassifier

# Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Perform inference on a list of image files
predictions = model.predict([
"path/to/image_1.png",
"path/to/image_2.png",
"path/to/image_3.png",
])
print(predictions)

Or on a given dataset:

.. code-block:: python

# import our libraries
from flash import download_data
from flash.vision import ImageClassifier

# 1. Download dataset, save it under 'data' dir
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Perform inference on a folder of images
predictions = model.predict("data/hymenoptera_data/test/")
print(predictions)

For more advanced inference options, see :ref:`predictions`.

@@ -97,16 +97,16 @@ Now all we need is three lines of code to build and train our task!

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
backbone="resnet18",
backbone="resnet18",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
@@ -140,7 +140,7 @@ By default, we use a `ResNet-18 <https://arxiv.org/abs/1512.03385>`_ for image c

# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)
150 changes: 150 additions & 0 deletions docs/source/reference/image_embedder.rst
@@ -0,0 +1,150 @@

.. _image_embedder:

##############
Image Embedder
##############

********
The task
********
Image embedding encodes an image into a vector of image features, which can be used for tasks such as clustering, similarity
search, or classification.
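
For example, two embeddings can be compared with cosine similarity to judge how alike two images are. The snippet below is a minimal sketch: the image paths are hypothetical and it assumes :func:`~flash.vision.ImageEmbedder.predict` returns a single 1D embedding tensor per image.

.. code-block:: python

import torch
from flash.vision import ImageEmbedder

embedder = ImageEmbedder(backbone='swav-imagenet')

# Embed two images (hypothetical paths); each call is assumed to return a 1D tensor
emb_a = embedder.predict('path/to/image_a.png')
emb_b = embedder.predict('path/to/image_b.png')

# Cosine similarity close to 1.0 suggests the images are visually similar
similarity = torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=0)
print(similarity)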

------

*********
Inference
*********

The :class:`~flash.vision.ImageEmbedder` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.

Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:

.. testcode:: python

from flash.vision import ImageEmbedder

# 1. Load the ImageEmbedder with a pretrained backbone
embedder = ImageEmbedder(backbone='swav-imagenet')

# 2. Perform inference on an image file
embeddings = embedder.predict('path/to/image.png')
print(embeddings)

Or on a random batch of image tensors:

.. testcode:: python

import torch

# 2. Perform inference on a batch of random image tensors
images = torch.rand(32, 3, 224, 224)
embeddings = embedder.predict(images)
print(embeddings)

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********
To tailor this image embedder to your dataset, finetune it first.

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageEmbedder

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
embedder = ImageEmbedder(backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(embedder, datamodule=datamodule)

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_classification_model.pt")
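
The saved checkpoint can then be reloaded for inference. A minimal sketch, assuming the checkpoint path saved above and standard Lightning checkpoint loading, with a hypothetical image path:

.. code-block:: python

from flash.vision import ImageEmbedder

# Reload the finetuned embedder and embed a new image
embedder = ImageEmbedder.load_from_checkpoint("image_embedder_model.pt")
embeddings = embedder.predict('path/to/image.png')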

------

*********************
Changing the backbone
*********************
By default, we use the encoder from `SwAV <https://arxiv.org/pdf/2006.09882.pdf>`_ pretrained on Imagenet via contrastive learning. You can change the backbone used by the task by passing in a different backbone name.

.. note::

When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!

.. code-block:: python

# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)

# 2. build the task
embedder = ImageEmbedder(backbone="resnet34")

The backbones available are listed below; a short usage sketch follows the table.

.. list-table:: Backbones
:widths: 50 20 20
:header-rows: 1

* - backbone
- dataset
- training method
* - resnet18
- Imagenet
- supervised
* - resnet34
- Imagenet
- supervised
* - resnet50
- Imagenet
- supervised
* - resnet101
- Imagenet
- supervised
* - resnet152
- Imagenet
- supervised
* - swav-imagenet
- Imagenet
- self-supervised (clustering)
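
For instance, to switch the embedder to the Imagenet-supervised ResNet-50 backbone from the table above (a minimal sketch; the image path is hypothetical):

.. code-block:: python

from flash.vision import ImageEmbedder

# Use the supervised ResNet-50 backbone instead of the default SwAV encoder
embedder = ImageEmbedder(backbone="resnet50")
embeddings = embedder.predict('path/to/image.png')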

------

*************
API reference
*************

.. _image_embedder_class:

ImageEmbedder
---------------

.. autoclass:: flash.vision.ImageEmbedder
:members:
:exclude-members: forward

.. _image_embedder_data:
24 changes: 19 additions & 5 deletions flash/core/model.py
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
@@ -91,6 +92,17 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

@property
@contextmanager
def predict_context(self):
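# Temporarily put the model in eval mode with gradients disabled for prediction,
# then restore train mode and re-enable grad afterwards.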
try:
self.eval()
torch.set_grad_enabled(False)
yield
finally:
self.train()
torch.set_grad_enabled(True)

def predict(
self,
x: Any,
@@ -115,11 +127,13 @@ def predict(
The post-processed model predictions

"""
data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
predictions = self.forward(batch_x)
return data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
with self.predict_context:
data_pipeline = data_pipeline or self.data_pipeline
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
predictions = self.forward(batch_x)
output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
return output

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
33 changes: 33 additions & 0 deletions flash/model_map.py
@@ -0,0 +1,33 @@
from functools import partial

from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"


def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
model_config = {'model': simclr.encoder, 'emb_size': 2048}
return model_config


def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
model_config = {'model': swav.model, 'num_features': 3000}
return model_config


_models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}


def load_model(name):
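# Resolve a backbone name (e.g. 'swav-imagenet') to its pretrained model config,
# raising if bolts is unavailable or the name is unknown.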
if not _BOLTS_AVAILABLE:
raise MisconfigurationException("Bolts isn't installed. Please, use pip install `lightning-bolts`.")
if name in _models:
return _models[name]()
else:
raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
1 change: 1 addition & 0 deletions flash/vision/__init__.py
@@ -1,2 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ImageDetector
from flash.vision.embedding import ImageEmbedder
1 change: 1 addition & 0 deletions flash/vision/embedding/__init__.py
@@ -0,0 +1 @@
from flash.vision.embedding.image_embedder_model import ImageEmbedder