Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Cluster #36

Merged
merged 20 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Lightning Flash

reference/task
reference/image_classification
reference/image_embedder
reference/text_classification
reference/tabular_classification

Expand Down
60 changes: 30 additions & 30 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,39 @@ Use the :class:`~flash.text.ImageClassificatier` pretrained model for inference

.. code-block:: python

# import our libraries
from flash.text import TextClassifier

# Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)
# import our libraries
from flash.text import TextClassifier

# Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 2. Perform inference from list of sequences
predictions = model.predict([
"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
"The worst movie in the history of cinema.",
"I come from Bulgaria where it 's almost impossible to have a tornado."
"Very, very afraid"
"This guy has done a great job with this movie!",
])
print(predictions)

Or on a given dataset:

.. code-block:: python

# import our libraries
from flash import download_data
from flash.text import TextClassifier
# import our libraries
from flash import download_data
from flash.text import TextClassifier

# 1. Download dataset, save it under 'data' dir
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
# 1. Download dataset, save it under 'data' dir
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

# 2. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
# 2. Load finetuned task
model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

# 3. Perform inference from a csv file
predictions = model.predict("data/imdb/test.csv")
print(predictions)
# 3. Perform inference from a csv file
predictions = model.predict("data/imdb/test.csv")
print(predictions)

For more advanced inference options, see :ref:`predictions`.

Expand Down Expand Up @@ -97,16 +97,16 @@ Now all we need is three lines of code to build to train our task!

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier
import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
backbone="resnet18",
backbone="resnet18",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
Expand Down Expand Up @@ -140,7 +140,7 @@ By default, we use a `ResNet-18 <https://arxiv.org/abs/1512.03385>`_ for image c

# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)
Expand Down
151 changes: 151 additions & 0 deletions docs/source/reference/image_embedder.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@

.. _image_embedder:

##############
Image Embedder
##############

********
The task
********
Image embedding encodes an image into a vector of image features which can be used for anything like clustering, similarity
search or classification.

------

*********
Inference
*********

The :class:`~flash.vision.ImageEmbedder` is already pre-trained on [ImageNet](http://www.image-net.org/), a dataset of over 14 million images.

Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:

.. code-block:: python
tchaton marked this conversation as resolved.
Show resolved Hide resolved

from flash.vision import ImageEmbedder

# Load finetuned task
embedder = ImageEmbedder(backbone='resnet18')

# 2. Perform inference on an image file
embeddings = model.predict('path/to/image.png')
print(predictions)

Or on a random image tensor

.. code-block:: python

# 2. Perform inference on an image file
import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix indentation

images = torch.rand(32, 3, 224, 224)
embeddings = model.predict(images)
print(predictions)

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********
To tailor this image embedder to your dataset, finetune first.

.. code-block:: python

import flash
from flash import download_data
from flash.vision import ImageClassificationData, ImageClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

# 2. Load the data
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/",
test_folder="data/hymenoptera_data/test/",
)

# 3. Build the model
embedder = ImageEmbedder(backbone="resnet18")

# 4. Create the trainer. Run once on data
trainer = flash.Trainer(max_epochs=1)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule)

# 6. Test the model
trainer.test()

# 7. Save it!
trainer.save_checkpoint("image_embedder_model.pt")

------

*********************
Changing the backbone
*********************
By default, we use the encoder from `Swav <https://arxiv.org/pdf/2006.09882.pdf>`_ pretrained on Imagenet via contrastive learning. You can change the model run by the task by passing in a different backbone.

.. note::

When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!

.. code-block:: python

# 1. organize the data
data = ImageClassificationData.from_folders(
backbone="resnet34",
train_folder="data/hymenoptera_data/train/",
valid_folder="data/hymenoptera_data/val/"
)

# 2. build the task
task = ImageClassifier(num_classes=2, backbone="resnet34")

Backbones available

.. list-table:: Backbones
:widths: 50 20 20
:header-rows: 1

* - backbone
- dataset
- training method
* - resnet18
- Imagenet
- supervised
* - resnet34
- Imagenet
- supervised
* - resnet50
- Imagenet
- supervised
* - resnet101
- Imagenet
- supervised
* - resnet152
- Imagenet
- supervised
* - swav-imagenet
- Imagenet
- self-supervised (clustering)

------

*************
API reference
*************

.. _image_embedder_class:

ImageEmbedder
---------------

.. autoclass:: flash.vision.ImageEmbedder
:members:
:exclude-members: forward

.. _image_embedder_data:
29 changes: 28 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
Expand All @@ -21,6 +22,25 @@
from flash.core.utils import get_callable_dict


def predict_context(func: Callable) -> Callable:
"""
This decorator is used as context manager
to put model in eval mode before running predict and reset to train after.
"""

def wrapper(self, *args, **kwargs) -> Any:
self.eval()
torch.set_grad_enabled(False)

result = func(self, *args, **kwargs)

self.train()
torch.set_grad_enabled(True)
return result

return wrapper


class Task(pl.LightningModule):
"""A general Task.

Expand Down Expand Up @@ -91,6 +111,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
output = self.step(batch, batch_idx)
self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

@predict_context
def predict(
self,
x: Any,
Expand All @@ -103,12 +124,17 @@ def predict(
Predict function for raw data or processed data

Args:

x: Input to predict. Can be raw data or processed data.

batch_idx: Batch index

dataloader_idx: Dataloader index

skip_collate_fn: Whether to skip the collate step.
this is required when passing data already processed
for the model, for example, data from a dataloader

data_pipeline: Use this to override the current data pipeline

Returns:
Expand All @@ -119,7 +145,8 @@ def predict(
batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
predictions = self.forward(batch_x)
return data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x
return output

def configure_optimizers(self) -> torch.optim.Optimizer:
return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
Expand Down
1 change: 1 addition & 0 deletions flash/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ImageDetector
from flash.vision.embedding import ImageEmbedder
1 change: 1 addition & 0 deletions flash/vision/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from flash.vision.embedding.image_embedder_model import ImageEmbedder
Loading