This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Cluster (#36)
* image emb

* image emb

* add ImageEmbedder Pipeline

* clean pooling

* Update docs/source/reference/image_embedder.rst

Co-authored-by: Jirka Borovec <[email protected]>

* Update docs/source/reference/image_embedder.rst

Co-authored-by: Jirka Borovec <[email protected]>

* update on comments

* update on comments

* removing todo

* change to decorator

* add comment

* add `grad_enabled` to predict

* bolts break the doc

* update

* resolve doc tests

* remove grad enabled to False

Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
3 people authored Feb 1, 2021
1 parent 48ec088 commit 008bc75
Showing 9 changed files with 435 additions and 31 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ Lightning Flash

reference/task
reference/image_classification
reference/image_embedder
reference/text_classification
reference/tabular_classification

60 changes: 30 additions & 30 deletions docs/source/reference/image_classification.rst
@@ -23,39 +23,39 @@ Use the :class:`~flash.text.ImageClassificatier` pretrained model for inference

.. code-block:: python

    # import our libraries
    from flash.vision import ImageClassifier

    # 1. Load finetuned task
    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

    # 2. Perform inference from list of sequences
    predictions = model.predict([
        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
        "The worst movie in the history of cinema.",
        "I come from Bulgaria where it 's almost impossible to have a tornado.",
        "Very, very afraid",
        "This guy has done a great job with this movie!",
    ])
    print(predictions)

Or on a given dataset:

.. code-block:: python

    # import our libraries
    from flash import download_data
    from flash.vision import ImageClassifier

    # 1. Download dataset, save it under 'data' dir
    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')

    # 2. Load finetuned task
    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")

    # 3. Perform inference from a csv file
    predictions = model.predict("data/imdb/test.csv")
    print(predictions)

For more advanced inference options, see :ref:`predictions`.

@@ -97,16 +97,16 @@ Now all we need is three lines of code to build to train our task!

.. code-block:: python

    import flash
    from flash import download_data
    from flash.vision import ImageClassificationData, ImageClassifier

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

    # 2. Load the data
    datamodule = ImageClassificationData.from_folders(
        backbone="resnet18",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
@@ -140,7 +140,7 @@ By default, we use a `ResNet-18 <https://arxiv.org/abs/1512.03385>`_ for image c
    # 1. organize the data
    data = ImageClassificationData.from_folders(
        backbone="resnet34",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/"
    )
151 changes: 151 additions & 0 deletions docs/source/reference/image_embedder.rst
@@ -0,0 +1,151 @@

.. _image_embedder:

##############
Image Embedder
##############

********
The task
********
Image embedding encodes an image into a vector of image features, which can then be used for downstream tasks such as
clustering, similarity search, or classification.
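
As a rough illustration of the similarity-search use case, the sketch below ranks a few embedding vectors by cosine similarity to a query vector. The tiny 3-dimensional vectors and the helper names ``cosine_similarity`` and ``most_similar`` are made up for this example and are not part of Flash; in practice the vectors would come from the embedder's ``predict`` output.

```python
import math
from typing import Dict, List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


def most_similar(
    query: Sequence[float],
    gallery: Dict[str, Sequence[float]],
    top_k: int = 1,
) -> List[Tuple[str, float]]:
    """Return the ``top_k`` gallery entries ranked by similarity to ``query``."""
    scored = [(name, cosine_similarity(query, vec)) for name, vec in gallery.items()]
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:top_k]


# Hypothetical embeddings; real ones are much longer.
gallery = {
    "ant.png": [0.9, 0.1, 0.0],
    "bee.png": [0.1, 0.9, 0.2],
}
query = [0.8, 0.2, 0.0]
ranked = most_similar(query, gallery, top_k=2)
print(ranked[0][0])  # name of the closest gallery image
```

The same ranking idea carries over to clustering, where vectors that score high against each other end up in the same group.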

------

*********
Inference
*********

The :class:`~flash.vision.ImageEmbedder` is already pre-trained on `ImageNet <http://www.image-net.org/>`_, a dataset of over 14 million images.

Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:

.. code-block:: python

    from flash.vision import ImageEmbedder

    # 1. Load the pretrained task
    embedder = ImageEmbedder(backbone='resnet18')

    # 2. Perform inference on an image file
    embeddings = embedder.predict('path/to/image.png')
    print(embeddings)

Or on a random image tensor:

.. code-block:: python

    import torch

    # 2. Perform inference on a batch of random image tensors
    images = torch.rand(32, 3, 224, 224)
    embeddings = embedder.predict(images)
    print(embeddings)

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********
To tailor this image embedder to your dataset, finetune first.

.. code-block:: python

    import flash
    from flash import download_data
    from flash.vision import ImageClassificationData, ImageEmbedder

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')

    # 2. Load the data
    datamodule = ImageClassificationData.from_folders(
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
    )

    # 3. Build the model
    embedder = ImageEmbedder(backbone="resnet18")

    # 4. Create the trainer. Run once on data
    trainer = flash.Trainer(max_epochs=1)

    # 5. Train the model
    trainer.finetune(embedder, datamodule=datamodule)

    # 6. Test the model
    trainer.test()

    # 7. Save it!
    trainer.save_checkpoint("image_embedder_model.pt")

------

*********************
Changing the backbone
*********************
By default, we use the encoder from `SwAV <https://arxiv.org/pdf/2006.09882.pdf>`_ pretrained on ImageNet via contrastive learning. You can change the model run by the task by passing in a different backbone.

.. note::

    When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!

.. code-block:: python

    # 1. organize the data
    data = ImageClassificationData.from_folders(
        backbone="resnet34",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/"
    )

    # 2. build the task
    embedder = ImageEmbedder(backbone="resnet34")

Backbones available:

.. list-table:: Backbones
   :widths: 50 20 20
   :header-rows: 1

   * - backbone
     - dataset
     - training method
   * - resnet18
     - Imagenet
     - supervised
   * - resnet34
     - Imagenet
     - supervised
   * - resnet50
     - Imagenet
     - supervised
   * - resnet101
     - Imagenet
     - supervised
   * - resnet152
     - Imagenet
     - supervised
   * - swav-imagenet
     - Imagenet
     - self-supervised (clustering)
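
One way to keep the Task and the Data object on the same backbone is to validate the name once and reuse that single value. The sketch below copies the table above into a dictionary; ``BACKBONES`` and ``check_backbone`` are hypothetical names for this illustration and are not provided by Flash.

```python
from typing import Dict, Tuple

# Backbone name -> (dataset, training method), copied from the table above.
BACKBONES: Dict[str, Tuple[str, str]] = {
    "resnet18": ("Imagenet", "supervised"),
    "resnet34": ("Imagenet", "supervised"),
    "resnet50": ("Imagenet", "supervised"),
    "resnet101": ("Imagenet", "supervised"),
    "resnet152": ("Imagenet", "supervised"),
    "swav-imagenet": ("Imagenet", "self-supervised (clustering)"),
}


def check_backbone(name: str) -> str:
    """Return ``name`` unchanged if it is a listed backbone, else raise."""
    if name not in BACKBONES:
        options = ", ".join(sorted(BACKBONES))
        raise ValueError(f"unknown backbone {name!r}; expected one of: {options}")
    return name


# Validate once, then pass this same value to both the Data object and the Task.
backbone = check_backbone("resnet34")
dataset, method = BACKBONES[backbone]
print(backbone, dataset, method)
```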

------

*************
API reference
*************

.. _image_embedder_class:

ImageEmbedder
-------------

.. autoclass:: flash.vision.ImageEmbedder
    :members:
    :exclude-members: forward

.. _image_embedder_data:
29 changes: 28 additions & 1 deletion flash/core/model.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union

import pytorch_lightning as pl
@@ -21,6 +22,25 @@
from flash.core.utils import get_callable_dict


def predict_context(func: Callable) -> Callable:
    """
    This decorator is used as context manager
    to put model in eval mode before running predict and reset to train after.
    """

    def wrapper(self, *args, **kwargs) -> Any:
        self.eval()
        torch.set_grad_enabled(False)

        result = func(self, *args, **kwargs)

        self.train()
        torch.set_grad_enabled(True)
        return result

    return wrapper
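
Two cases the wrapper above does not cover are an exception raised inside ``predict`` and a caller whose model was already in eval mode: it always switches back to train mode with gradients enabled. A minimal sketch of a variant that restores the state it found is given below. It is an illustration only, not code from this commit. ``predict_context_safe`` and ``DummyModel`` are hypothetical names; ``DummyModel`` mimics the ``training`` / ``train()`` / ``eval()`` interface of ``torch.nn.Module`` so the example runs without PyTorch, and gradient handling is left out.

```python
import functools
from typing import Any, Callable


def predict_context_safe(func: Callable) -> Callable:
    """Run ``func`` in eval mode, then restore the mode the model was in before."""

    @functools.wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> Any:
        was_training = self.training  # remember the caller's mode
        self.eval()
        try:
            return func(self, *args, **kwargs)
        finally:
            # Runs on success and on error, so the mode is always restored.
            self.train(was_training)

    return wrapper


class DummyModel:
    """Hypothetical stand-in exposing the train/eval interface of torch.nn.Module."""

    def __init__(self) -> None:
        self.training = True

    def train(self, mode: bool = True) -> "DummyModel":
        self.training = mode
        return self

    def eval(self) -> "DummyModel":
        return self.train(False)

    @predict_context_safe
    def predict(self, x: Any) -> Any:
        if x is None:
            raise ValueError("nothing to predict")
        # Return the input together with the mode seen inside predict.
        return x, self.training


model = DummyModel()
result = model.predict(3)
print(result, model.training)
```

With a real ``torch.nn.Module``, gradient tracking could be handled in the same spirit by running the call under ``torch.no_grad()``, which restores the previous gradient mode on exit.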


class Task(pl.LightningModule):
    """A general Task.
@@ -91,6 +111,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
        output = self.step(batch, batch_idx)
        self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)

    @predict_context
    def predict(
        self,
        x: Any,
@@ -103,12 +124,17 @@ def predict(
        Predict function for raw data or processed data
        Args:
            x: Input to predict. Can be raw data or processed data.
            batch_idx: Batch index
            dataloader_idx: Dataloader index
            skip_collate_fn: Whether to skip the collate step.
                this is required when passing data already processed
                for the model, for example, data from a dataloader
            data_pipeline: Use this to override the current data pipeline
        Returns:
@@ -119,7 +145,8 @@ def predict(
        batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
        batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
        predictions = self.forward(batch_x)
        output = data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
        return output

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
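
The body of ``predict`` follows a collate, forward, uncollate pattern. The sketch below mimics that flow with plain Python lists so it can run on its own. ``ToyPipeline`` and ``toy_forward`` are hypothetical stand-ins: only the method names ``collate_fn`` and ``uncollate_fn`` are taken from the diff, and nothing here reflects the real data pipeline beyond those names. The ``(batch_x, batch_y)`` unpacking step is omitted for brevity.

```python
from typing import Any, Callable, List, Sequence


class ToyPipeline:
    """Hypothetical pipeline; only the two method names come from the diff."""

    def collate_fn(self, samples: Sequence[float]) -> List[List[float]]:
        # Turn raw samples into a "batch": one single-feature row per sample.
        return [[float(s)] for s in samples]

    def uncollate_fn(self, batch: List[List[float]]) -> List[float]:
        # Turn batched model output back into one value per sample.
        return [row[0] for row in batch]


def toy_forward(batch: List[List[float]]) -> List[List[float]]:
    """Stand-in for the model's forward pass: doubles every value."""
    return [[2.0 * value for value in row] for row in batch]


def predict(
    x: Any,
    pipeline: ToyPipeline,
    forward: Callable[[List[List[float]]], List[List[float]]],
    skip_collate_fn: bool = False,
) -> List[float]:
    # Already-batched input (for example from a dataloader) skips collation.
    batch = x if skip_collate_fn else pipeline.collate_fn(x)
    predictions = forward(batch)
    return pipeline.uncollate_fn(predictions)


raw_output = predict([1, 2, 3], ToyPipeline(), toy_forward)
batched_output = predict([[1.0], [2.0]], ToyPipeline(), toy_forward, skip_collate_fn=True)
print(raw_output, batched_output)
```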
1 change: 1 addition & 0 deletions flash/vision/__init__.py
@@ -1,2 +1,3 @@
from flash.vision.classification import ImageClassificationData, ImageClassifier
from flash.vision.detection import ImageDetector
from flash.vision.embedding import ImageEmbedder
1 change: 1 addition & 0 deletions flash/vision/embedding/__init__.py
@@ -0,0 +1 @@
from flash.vision.embedding.image_embedder_model import ImageEmbedder