From 8c07f871b4efa08bc90da6522c8de57cb677376e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 31 Jan 2021 14:01:11 -0500
Subject: [PATCH 01/16] image emb

---
 docs/source/reference/image_embedder.rst      | 187 ++++++++++++++++++
 flash/model_map.py                            |  37 ++++
 flash/vision/__init__.py                      |   1 +
 flash/vision/embedding/__init__.py            |   1 +
 .../vision/embedding/image_embedder_model.py  | 160 +++++++++++++
 5 files changed, 386 insertions(+)
 create mode 100644 docs/source/reference/image_embedder.rst
 create mode 100644 flash/model_map.py
 create mode 100644 flash/vision/embedding/__init__.py
 create mode 100644 flash/vision/embedding/image_embedder_model.py

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
new file mode 100644
index 0000000000..7026e3ae8f
--- /dev/null
+++ b/docs/source/reference/image_embedder.rst
@@ -0,0 +1,187 @@
+
+.. _image_embedder:
+
+##############
+Image Embedder
+##############
+
+********
+The task
+********
+Image embedding encodes an image into a feature vector. This feature vector can be used for any downstream task
+
+------
+
+*********
+Inference
+*********
+
+The :class:`~flash.vision.ImageEmbedder` is already pre-trained on [ImageNet](http://www.image-net.org/), a dataset of over 14 million images.
+
+TODO:
+
+Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`:
+
+.. code-block:: python
+
+    # import our libraries
+    from flash.text import TextClassifier
+
+    # Load finetuned task
+    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+
+    # 2. Perform inference from list of sequences
+    predictions = model.predict([
+        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
+        "The worst movie in the history of cinema.",
+        "I come from Bulgaria where it 's almost impossible to have a tornado."
+        "Very, very afraid"
+        "This guy has done a great job with this movie!",
+    ])
+    print(predictions)
+
+Or on a given dataset:
+
+.. code-block:: python
+
+    # import our libraries
+    from flash import download_data
+    from flash.text import TextClassifier
+
+    # 1. Download dataset, save it under 'data' dir
+    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+
+    # 2. Load finetuned task
+    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+
+    # 3. Perform inference from a csv file
+    predictions = model.predict("data/imdb/test.csv")
+    print(predictions)
+
+For more advanced inference options, see :ref:`predictions`.
+
+------
+
+**********
+Finetuning
+**********
+
+Lets say you wanted to develope a model that could determine whether an image contains **ants** or **bees**, using the hymenoptera dataset.
+Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.ImageClassificationData`.
+
+.. note:: The dataset contains ``train`` and ``validation`` folders, and then each folder contains a **bees** folder, with pictures of bees, and an **ants** folder with images of, you guessed it, ants.
+
+.. code-block::
+
+    hymenoptera_data
+    ├── train
+    │   ├── ants
+    │   │   ├── 0013035.jpg
+    │   │   ├── 1030023514_aad5c608f9.jpg
+    │   │   ...
+    │   └── bees
+    │       ├── 1092977343_cb42b38d62.jpg
+    │       ├── 1093831624_fb5fbe2308.jpg
+    │       ...
+    └── val
+        ├── ants
+        │   ├── 10308379_1b6c72e180.jpg
+        │   ├── 1053149811_f62a3410d3.jpg
+        │   ...
+        └── bees
+            ├── 1032546534_06907fe3b3.jpg
+            ├── 10870992_eebeeb3a12.jpg
+            ...
+
+
+Now all we need is three lines of code to build to train our task!
+
+.. code-block:: python
+
+  import flash
+  from flash import download_data
+  from flash.vision import ImageClassificationData, ImageClassifier
+
+    # 1. Download the data
+    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
+
+    # 2. Load the data
+    datamodule = ImageClassificationData.from_folders(
+        backbone="resnet18",
+        train_folder="data/hymenoptera_data/train/",
+        valid_folder="data/hymenoptera_data/val/",
+        test_folder="data/hymenoptera_data/test/",
+    )
+
+    # 3. Build the model
+    model = ImageClassifier(num_classes=datamodule.num_classes)
+
+    # 4. Create the trainer. Run once on data
+    trainer = flash.Trainer(max_epochs=1)
+
+    # 5. Train the model
+    trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
+
+    # 6. Test the model
+    trainer.test()
+
+    # 7. Save it!
+    trainer.save_checkpoint("image_classification_model.pt")
+
+------
+
+*********************
+Changing the backbone
+*********************
+By default, we use a `ResNet-18 `_ for image classification. You can change the model run by the task by passing in a different backbone.
+
+.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
+
+.. code-block:: python
+
+    # 1. organize the data
+    data = ImageClassificationData.from_folders(
+      backbone="resnet34",
+        train_folder="data/hymenoptera_data/train/",
+        valid_folder="data/hymenoptera_data/val/"
+    )
+
+    # 2. build the task
+    task = ImageClassifier(num_classes=2, backbone="resnet34")
+
+Available backbones:
+
+* resnet34
+* resnet50
+* resnet101
+* resnet152
+
+------
+
+*************
+API reference
+*************
+
+.. _image_classifier:
+
+ImageClassifier
+---------------
+
+.. autoclass:: flash.vision.ImageClassifier
+    :members:
+    :exclude-members: forward
+
+.. _image_classification_data:
+
+ImageClassificationData
+-----------------------
+
+.. autoclass:: flash.vision.ImageClassificationData
+
+.. automethod:: flash.vision.ImageClassificationData.from_filepaths
+
+.. automethod:: flash.vision.ImageClassificationData.from_folders
+
+
+
+
diff --git a/flash/model_map.py b/flash/model_map.py
new file mode 100644
index 0000000000..aff1628102
--- /dev/null
+++ b/flash/model_map.py
@@ -0,0 +1,37 @@
+
+
+def load_model(name):
+    if name == 'simclr-imagenet':
+        return load_simclr_imagenet()
+
+    if name == 'swav-imagenet':
+        return load_swav_imagenet()
+
+
+def load_simclr_imagenet():
+    from pl_bolts.models.self_supervised import SimCLR
+    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
+    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
+
+    model_config = {
+        'model': simclr.encoder,
+        'emb_size': 2048
+    }
+    return model_config
+
+
+def load_swav_imagenet():
+    from pl_bolts.models.self_supervised import SwAV
+
+    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
+    swav = SwAV.load_from_checkpoint(weight_path, strict=True)
+    model_config = {
+        'model': swav.model,
+        'num_features': 3000
+    }
+    return model_config
+
+models = {
+    # 'simclr-imagenet': load_simclr_imagenet,
+    'swav-imagenet': load_swav_imagenet
+}
diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 4bffeee822..6d7c326281 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -1,2 +1,3 @@
 from flash.vision.classification import ImageClassificationData, ImageClassifier
 from flash.vision.detection import ImageDetector
+from flash.vision.embedding import ImageEmbedder
diff --git a/flash/vision/embedding/__init__.py b/flash/vision/embedding/__init__.py
new file mode 100644
index 0000000000..5ba86a50cf
--- /dev/null
+++ b/flash/vision/embedding/__init__.py
@@ -0,0 +1 @@
+from flash.vision.embedding.image_embedder_model import ImageEmbedder
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
new file mode 100644
index 0000000000..56ea298199
--- /dev/null
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -0,0 +1,160 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Mapping, Sequence, Type, Union
+
+import torch
+import torchvision
+from pytorch_lightning.metrics import Accuracy
+from torch import nn
+from torch.nn import functional as F
+from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pl_bolts.models.self_supervised import SimCLR
+
+from flash import Task
+from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline
+from flash.model_map import models
+
+_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
+_resnet_feats = lambda model: model.fc.in_features  # noqa: E731
+
+_backbones = {
+    "resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
+    "resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
+    "resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
+    "resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
+    "resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
+}
+
+
+class ImageEmbedder(Task):
+    """Task that classifies images.
+
+    Args:
+        embedding_dim: Dimension of the embedded vector. None uses the default from the backbone
+        backbone: A model to use to extract image features.
+        pretrained: Use a pretrained backbone.
+        loss_fn: Loss function for training and finetuning, defaults to cross entropy.
+        optimizer: Optimizer to use for training and finetuning, defaults to `torch.optim.SGD`.
+        metrics: Metrics to compute for training and evaluation.
+        learning_rate: Learning rate to use for training, defaults to `1e-3`
+
+    Example::
+
+        from flash.vision import ImageEmbedder
+
+        embedder = ImageEmbedder(backbone='swav-imagenet')
+        image = torch.rand(32, 3, 32, 32)
+        embeddings = embedder(image)
+
+    Backbones available
+
+    .. list-table:: Backbones
+       :widths: 50 20 20
+       :header-rows: 1
+
+       * - backbone
+         - dataset
+         - training method
+       * - resnet18
+         - Imagenet
+         - supervised
+       * - resnet34
+         - Imagenet
+         - supervised
+       * - resnet50
+         - Imagenet
+         - supervised
+       * - resnet101
+         - Imagenet
+         - supervised
+       * - resnet152
+         - Imagenet
+         - supervised
+       * - swav-imagenet
+         - Imagenet
+         - self-supervised (clustering)
+
+    """
+
+    def __init__(
+        self,
+        embedding_dim=None,
+        backbone="swav-imagenet",
+        pretrained=True,
+        loss_fn: Callable = F.cross_entropy,
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
+        metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
+        learning_rate: float = 1e-3,
+    ):
+        super().__init__(
+            model=None,
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            metrics=metrics,
+            learning_rate=learning_rate,
+        )
+
+        self.save_hyperparameters()
+        self.backbone_name = backbone
+        self.embedding_dim = embedding_dim
+
+        if backbone in models:
+            config = models[backbone]()
+            self.backbone = config['model']
+            num_features = config['num_features']
+
+        elif backbone not in _backbones:
+            raise NotImplementedError(f"Backbone {backbone} is not yet supported")
+
+        else:
+            backbone_fn, split, num_feats = _backbones[backbone]
+            backbone = backbone_fn(pretrained=pretrained)
+            self.backbone = split(backbone)
+            num_features = num_feats(backbone)
+
+        if embedding_dim is None:
+            self.pooling = nn.Identity()
+            self.head = nn.Identity()
+        else:
+            self.pooling = nn.AdaptiveAvgPool2d((1, 1)),
+            self.head = nn.Sequential(
+                nn.Flatten(),
+                nn.Linear(num_features, embedding_dim),
+            )
+            rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
+
+    def forward(self, x) -> Any:
+        x = self.backbone(x)
+
+        # bolts ssl models return lists
+        if isinstance(x, tuple):
+            x = x[-1]
+
+        if len(x.size()) == 4 and self.embedding_dim is not None:
+            x = self.pooling(x)
+
+        x = self.head(x)
+        x = x.view(x.size(0), -1)
+        return x
+
+    @staticmethod
+    def default_pipeline() -> ImageClassificationDataPipeline:
+        return ImageClassificationData.default_pipeline()
+
+
+if __name__ == '__main__':
+    embedder = ImageEmbedder(backbone='resnet50')
+    image = torch.rand(32, 3, 128, 128)
+    embeddings = embedder(image)
+    print(embeddings.shape)

From 9d67c404e82f0daac44421acb78e41a0093babbc Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 1 Feb 2021 01:14:50 -0500
Subject: [PATCH 02/16] image emb

---
 docs/source/index.rst                         |   1 +
 .../source/reference/image_classification.rst |  60 ++++----
 docs/source/reference/image_embedder.rst      | 137 +++++-----------
 flash/core/model.py                           |   9 +-
 .../vision/embedding/image_embedder_model.py  |  33 +----
 5 files changed, 90 insertions(+), 150 deletions(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index cc1a7c5ffe..55768b8643 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,6 +21,7 @@ Lightning Flash
 
    reference/task
    reference/image_classification
+   reference/image_embedder
    reference/text_classification
    reference/tabular_classification
 
diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst
index f2f8efe875..1904f59702 100644
--- a/docs/source/reference/image_classification.rst
+++ b/docs/source/reference/image_classification.rst
@@ -23,39 +23,39 @@ Use the :class:`~flash.text.ImageClassificatier` pretrained model for inference
 
 .. code-block:: python
 
-  # import our libraries
-  from flash.text import TextClassifier
-
-  # Load finetuned task
-  model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
-
-  # 2. Perform inference from list of sequences
-  predictions = model.predict([
-      "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
-      "The worst movie in the history of cinema.",
-      "I come from Bulgaria where it 's almost impossible to have a tornado."
-      "Very, very afraid"
-      "This guy has done a great job with this movie!",
-  ])
-  print(predictions)
+    # import our libraries
+    from flash.text import TextClassifier
+
+    # Load finetuned task
+    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+
+    # 2. Perform inference from list of sequences
+    predictions = model.predict([
+        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
+        "The worst movie in the history of cinema.",
+        "I come from Bulgaria where it 's almost impossible to have a tornado."
+        "Very, very afraid"
+        "This guy has done a great job with this movie!",
+    ])
+    print(predictions)
 
 Or on a given dataset:
 
 .. code-block:: python
 
-  # import our libraries
-  from flash import download_data
-  from flash.text import TextClassifier
+    # import our libraries
+    from flash import download_data
+    from flash.text import TextClassifier
 
-  # 1. Download dataset, save it under 'data' dir
-  download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+    # 1. Download dataset, save it under 'data' dir
+    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
 
-  # 2. Load finetuned task
-  model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+    # 2. Load finetuned task
+    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
 
-  # 3. Perform inference from a csv file
-  predictions = model.predict("data/imdb/test.csv")
-  print(predictions)
+    # 3. Perform inference from a csv file
+    predictions = model.predict("data/imdb/test.csv")
+    print(predictions)
 
 For more advanced inference options, see :ref:`predictions`.
 
@@ -97,16 +97,16 @@ Now all we need is three lines of code to build to train our task!
 
 .. code-block:: python
 
-  import flash
-  from flash import download_data
-  from flash.vision import ImageClassificationData, ImageClassifier
+    import flash
+    from flash import download_data
+    from flash.vision import ImageClassificationData, ImageClassifier
 
     # 1. Download the data
     download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
 
     # 2. Load the data
     datamodule = ImageClassificationData.from_folders(
-      backbone="resnet18",
+        backbone="resnet18",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/",
        test_folder="data/hymenoptera_data/test/",
     )
@@ -140,7 +140,7 @@ By default, we use a `ResNet-18 `_ for image c
 
     # 1. organize the data
     data = ImageClassificationData.from_folders(
-      backbone="resnet34",
+        backbone="resnet34",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/"
     )
diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index 7026e3ae8f..f25ae72579 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -8,7 +8,8 @@ Image Embedder
 ********
 The task
 ********
-Image embedding encodes an image into a feature vector. This feature vector can be used for any downstream task
+Image embedding encodes an image into a vector of image features which can be used for anything like clustering, similarity
+search or classification.
 
 ------
 
@@ -18,44 +19,31 @@ Inference
 
 The :class:`~flash.vision.ImageEmbedder` is already pre-trained on [ImageNet](http://www.image-net.org/), a dataset of over 14 million images.
 
-TODO:
-
-Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`:
+Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.text.ImageEmbedder.predict`:
 
 .. code-block:: python
 
-    # import our libraries
-    from flash.text import TextClassifier
+    from flash.vision import ImageEmbedder
 
     # Load finetuned task
-    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
-
-    # 2. Perform inference from list of sequences
-    predictions = model.predict([
-        "Turgid dialogue, feeble characterization - Harvey Keitel a judge?.",
-        "The worst movie in the history of cinema.",
-        "I come from Bulgaria where it 's almost impossible to have a tornado."
-        "Very, very afraid"
-        "This guy has done a great job with this movie!",
-    ])
+    embedder = ImageEmbedder(backbone='swav-imagenet')
+
+    # 2. Perform inference on an image file
+    embeddings = model.predict('path/to/image.png')
     print(predictions)
 
-Or on a given dataset:
+Or on an image tensor
 
 .. code-block:: python
 
-    # import our libraries
-    from flash import download_data
-    from flash.text import TextClassifier
-
-    # 1. Download dataset, save it under 'data' dir
-    download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/')
+    from flash.vision import ImageEmbedder
 
-    # 2. Load finetuned task
-    model = ImageClassifier.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/image_classification_model.pt")
+    # Load finetuned task
+    embedder = ImageEmbedder(backbone='swav-imagenet')
 
-    # 3. Perform inference from a csv file
-    predictions = model.predict("data/imdb/test.csv")
+    # 2. Perform inference on an image file
+    images = torch.rand(32, 3, 224, 224)
+    embeddings = model.predict(images)
     print(predictions)
 
 For more advanced inference options, see :ref:`predictions`.
@@ -65,42 +53,13 @@ For more advanced inference options, see :ref:`predictions`.
 **********
 Finetuning
 **********
-
-Lets say you wanted to develope a model that could determine whether an image contains **ants** or **bees**, using the hymenoptera dataset.
-Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.ImageClassificationData`.
-
-.. note:: The dataset contains ``train`` and ``validation`` folders, and then each folder contains a **bees** folder, with pictures of bees, and an **ants** folder with images of, you guessed it, ants.
-
-.. code-block::
-
-    hymenoptera_data
-    ├── train
-    │   ├── ants
-    │   │   ├── 0013035.jpg
-    │   │   ├── 1030023514_aad5c608f9.jpg
-    │   │   ...
-    │   └── bees
-    │       ├── 1092977343_cb42b38d62.jpg
-    │       ├── 1093831624_fb5fbe2308.jpg
-    │       ...
-    └── val
-        ├── ants
-        │   ├── 10308379_1b6c72e180.jpg
-        │   ├── 1053149811_f62a3410d3.jpg
-        │   ...
-        └── bees
-            ├── 1032546534_06907fe3b3.jpg
-            ├── 10870992_eebeeb3a12.jpg
-            ...
-
-
-Now all we need is three lines of code to build to train our task!
+To tailor this image embedder to your dataset, finetune first.
 
 .. code-block:: python
 
-  import flash
-  from flash import download_data
-  from flash.vision import ImageClassificationData, ImageClassifier
+    import flash
+    from flash import download_data
+    from flash.vision import ImageClassificationData, ImageClassifier
 
     # 1. Download the data
     download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
@@ -114,7 +73,7 @@ Now all we need is three lines of code to build to train our task!
     )
 
     # 3. Build the model
-    model = ImageClassifier(num_classes=datamodule.num_classes)
+    embedder = ImageEmbedder()
 
     # 4. Create the trainer. Run once on data
     trainer = flash.Trainer(max_epochs=1)
@@ -133,7 +92,7 @@ Now all we need is three lines of code to build to train our task!
 *********************
 Changing the backbone
 *********************
-By default, we use a `ResNet-18 `_ for image classification. You can change the model run by the task by passing in a different backbone.
+By default, we use the encoder from `Swav `_ pretrained on Imagenet via contrastive learning. You can change the model run by the task by passing in a different backbone.
 
 .. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
 
@@ -149,12 +108,33 @@ By default, we use a `ResNet-18 `_ for image c
 
     # 1. organize the data
     data = ImageClassificationData.from_folders(
-      backbone="resnet34",
+        backbone="resnet34",
        train_folder="data/hymenoptera_data/train/",
        valid_folder="data/hymenoptera_data/val/"
     )
 
     # 2. build the task
     task = ImageClassifier(num_classes=2, backbone="resnet34")
 
-Available backbones:
-
-* resnet34
-* resnet50
-* resnet101
-* resnet152
+Backbones available
+
+.. list-table:: Backbones
+   :widths: 50 20 20
+   :header-rows: 1
+
+   * - backbone
+     - dataset
+     - training method
+   * - resnet18
+     - Imagenet
+     - supervised
+   * - resnet34
+     - Imagenet
+     - supervised
+   * - resnet50
+     - Imagenet
+     - supervised
+   * - resnet101
+     - Imagenet
+     - supervised
+   * - resnet152
+     - Imagenet
+     - supervised
+   * - swav-imagenet
+     - Imagenet
+     - self-supervised (clustering)
 
 ------
 
@@ -162,26 +142,13 @@ Available backbones:
 *************
 API reference
 *************
 
-.. _image_classifier:
+.. _image_embedder_class:
 
-ImageClassifier
+ImageEmbedder
 ---------------
 
-.. autoclass:: flash.vision.ImageClassifier
+.. autoclass:: flash.vision.ImageEmbedder
     :members:
     :exclude-members: forward
 
-.. _image_classification_data:
-
-ImageClassificationData
------------------------
-
-.. autoclass:: flash.vision.ImageClassificationData
-
-.. automethod:: flash.vision.ImageClassificationData.from_filepaths
-
-.. automethod:: flash.vision.ImageClassificationData.from_folders
-
-
-
-
+.. _image_embedder_data:
diff --git a/flash/core/model.py b/flash/core/model.py
index 51b1a87d12..53ae882d28 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -115,10 +115,11 @@ def predict(
         The post-processed model predictions
 
         """
-        data_pipeline = data_pipeline or self.data_pipeline
-        batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
-        batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
-        predictions = self.forward(batch_x)
+        with torch.no_grad():
+            data_pipeline = data_pipeline or self.data_pipeline
+            batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
+            batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
+            predictions = self.forward(batch_x)
         return data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 56ea298199..70fbb3f3dd 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -19,9 +19,8 @@
 from torch import nn
 from torch.nn import functional as F
 from pytorch_lightning.utilities.distributed import rank_zero_warn
-from pl_bolts.models.self_supervised import SimCLR
 
-from flash import Task
+from flash.core import Task
 from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline
 from flash.model_map import models
 
@@ -57,34 +56,6 @@ class ImageEmbedder(Task):
         image = torch.rand(32, 3, 32, 32)
         embeddings = embedder(image)
 
-    Backbones available
-
-    .. list-table:: Backbones
-       :widths: 50 20 20
-       :header-rows: 1
-
-       * - backbone
-         - dataset
-         - training method
-       * - resnet18
-         - Imagenet
-         - supervised
-       * - resnet34
-         - Imagenet
-         - supervised
-       * - resnet50
-         - Imagenet
-         - supervised
-       * - resnet101
-         - Imagenet
-         - supervised
-       * - resnet152
-         - Imagenet
-         - supervised
-       * - swav-imagenet
-         - Imagenet
-         - self-supervised (clustering)
-
     """
 
     def __init__(
@@ -156,5 +127,5 @@ def default_pipeline() -> ImageClassificationDataPipeline:
 if __name__ == '__main__':
     embedder = ImageEmbedder(backbone='resnet50')
     image = torch.rand(32, 3, 128, 128)
-    embeddings = embedder(image)
+    embeddings = embedder.predict('/Users/williamfalcon/Desktop/abcd.jpeg')
     print(embeddings.shape)

From b3bb46ddeef6096ea94bafe3ee95a962b6924575 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 08:48:23 +0000
Subject: [PATCH 03/16] add ImageEmbedder Pipeline

---
 flash/core/model.py                           |  1 +
 flash/model_map.py                            | 42 +++++++++----------
 .../vision/embedding/image_embedder_model.py  | 40 ++++++++++++++----
 flash_examples/predict/image_embedder.py      | 25 +++++++++++
 4 files changed, 79 insertions(+), 29 deletions(-)
 create mode 100644 flash_examples/predict/image_embedder.py

diff --git a/flash/core/model.py b/flash/core/model.py
index 53ae882d28..f2cb5771f2 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -115,6 +115,7 @@ def predict(
         The post-processed model predictions
 
         """
+        self.eval()
         with torch.no_grad():
             data_pipeline = data_pipeline or self.data_pipeline
             batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
diff --git a/flash/model_map.py b/flash/model_map.py
index aff1628102..6ef61b1a5b 100644
--- a/flash/model_map.py
+++ b/flash/model_map.py
@@ -1,37 +1,37 @@
+from pytorch_lightning.utilities import _BOLTS_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+if _BOLTS_AVAILABLE:
+    from pl_bolts.models.self_supervised import SimCLR, SwAV
+
+DEFAULT_URLS = {
+    "SimCLR": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt',
+    "SwAV": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar',
+}
 
 
 def load_model(name):
     if name == 'simclr-imagenet':
         return load_simclr_imagenet()
 
-    if name == 'swav-imagenet':
+    elif name == 'swav-imagenet':
         return load_swav_imagenet()
 
+    else:
+        raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
+
 
 def load_simclr_imagenet():
-    from pl_bolts.models.self_supervised import SimCLR
-    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt'
-    simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
-
-    model_config = {
-        'model': simclr.encoder,
-        'emb_size': 2048
-    }
+    simclr = SimCLR.load_from_checkpoint(DEFAULT_URLS["SimCLR"], strict=False)
+
+    model_config = {'model': simclr.encoder, 'emb_size': 2048}
     return model_config
 
 
 def load_swav_imagenet():
-    from pl_bolts.models.self_supervised import SwAV
-
-    weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar'
-    swav = SwAV.load_from_checkpoint(weight_path, strict=True)
-    model_config = {
-        'model': swav.model,
-        'num_features': 3000
-    }
+    swav = SwAV.load_from_checkpoint(DEFAULT_URLS["SwAV"], strict=True)
+    model_config = {'model': swav.model, 'num_features': 3000}
     return model_config
 
-models = {
-    # 'simclr-imagenet': load_simclr_imagenet,
-    'swav-imagenet': load_swav_imagenet
-}
+
+models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 70fbb3f3dd..0e5a9744b9 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -11,18 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Mapping, Sequence, Type, Union
+from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union
 
+import pytorch_lightning
 import torch
 import torchvision
 from pytorch_lightning.metrics import Accuracy
+from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
 from torch.nn import functional as F
-from pytorch_lightning.utilities.distributed import rank_zero_warn
 
 from flash.core import Task
-from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline
+from flash.core.data import TaskDataPipeline
+from flash.core.data.utils import _contains_any_tensor
 from flash.model_map import models
+from flash.vision.classification.data import _default_valid_transforms, _pil_loader, ImageClassificationData
 
 _resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
 _resnet_feats = lambda model: model.fc.in_features  # noqa: E731
@@ -36,6 +40,28 @@
 }
 
 
+class ImageEmbedderDataPipeline(TaskDataPipeline):
+
+    def __init__(self, valid_transform: Optional[Callable] = _default_valid_transforms, loader: Callable = _pil_loader):
+        self._valid_transform = valid_transform
+        self._loader = loader
+
+    def before_collate(self, samples: Any) -> Any:
+        if _contains_any_tensor(samples):
+            return samples
+
+        if isinstance(samples, str):
+            samples = [samples]
+
+        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
+            outputs = []
+            for sample in samples:
+                output = self._loader(sample)
+                outputs.append(self._valid_transform(output))
+            return outputs
+        raise MisconfigurationException("The samples should either be a tensor, a list of paths or a path.")
+
+
 class ImageEmbedder(Task):
     """Task that classifies images.
 
@@ -98,7 +124,6 @@ def __init__(
             self.pooling = nn.Identity()
             self.head = nn.Identity()
         else:
-            self.pooling = nn.AdaptiveAvgPool2d((1, 1)),
             self.head = nn.Sequential(
                 nn.Flatten(),
                 nn.Linear(num_features, embedding_dim),
             )
             rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
@@ -113,15 +138,14 @@ def forward(self, x) -> Any:
         x = x[-1]
 
         if len(x.size()) == 4 and self.embedding_dim is not None:
-            x = self.pooling(x)
+            x = x.mean(-1).mean(-1)
 
         x = self.head(x)
-        x = x.view(x.size(0), -1)
         return x
 
     @staticmethod
-    def default_pipeline() -> ImageClassificationDataPipeline:
-        return ImageClassificationData.default_pipeline()
+    def default_pipeline() -> ImageEmbedderDataPipeline:
+        return ImageEmbedderDataPipeline()
diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py
new file mode 100644
index 0000000000..317d26a221
--- /dev/null
+++ b/flash_examples/predict/image_embedder.py
@@ -0,0 +1,25 @@
+import torch
+
+from flash.vision import ImageEmbedder
+
+if __name__ == "__main__":
+
+    # 1. Create an ImageEmbedder with swav
+    # Check out https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
+    for backbone in ["resnet50", "swav-imagenet"]:
+        embedder = ImageEmbedder(backbone=backbone, embedding_dim=128)
+
+        # 2. Generate an embedding from an image path.
+        embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
+
+        # 3. Assert dimension
+        assert embeddings.shape == torch.Size((1, 128))
+
+        # 4. Create a tensor random image
+        random_img = torch.randn(1, 3, 32, 32)
+
+        # 5. Generate an embedding from this random image
+        embeddings = embedder.predict(random_img)
+
+        # 6. Assert dimension
+        assert embeddings.shape == torch.Size((1, 128))

From 27dfbc3e076b8b73cea5b6c8d7fe8860a86d4092 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 08:59:52 +0000
Subject: [PATCH 04/16] clean pooling

---
 .../vision/embedding/image_embedder_model.py  | 25 +++++++++------
 flash_examples/predict/image_embedder.py      | 31 ++++++++++---------
 2 files changed, 32 insertions(+), 24 deletions(-)

diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 0e5a9744b9..99b43e26ce 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -73,6 +73,7 @@ class ImageEmbedder(Task):
         optimizer: Optimizer to use for training and finetuning, defaults to `torch.optim.SGD`.
         metrics: Metrics to compute for training and evaluation.
         learning_rate: Learning rate to use for training, defaults to `1e-3`
+        pooling_fn: Function used to pool image to generate embeddings. (Default: torch.max)
 
     Example::
 
@@ -93,6 +94,7 @@ def __init__(
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
         metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
         learning_rate: float = 1e-3,
+        pooling_fn: Callable = torch.max
     ):
         super().__init__(
             model=None,
@@ -105,6 +107,8 @@ def __init__(
         self.save_hyperparameters()
         self.backbone_name = backbone
         self.embedding_dim = embedding_dim
+        assert pooling_fn in [torch.mean, torch.max]
+        self.pooling_fn = pooling_fn
 
         if backbone in models:
             config = models[backbone]()
@@ -121,7 +125,6 @@ def __init__(
             num_features = num_feats(backbone)
 
         if embedding_dim is None:
-            self.pooling = nn.Identity()
             self.head = nn.Identity()
         else:
             self.head = nn.Sequential(
@@ -130,6 +133,15 @@ def __init__(
             )
             rank_zero_warn('embedding_dim is not None. Remember to finetune first!')
 
+    def apply_pool(self, x):
+        if self.pooling_fn == torch.max:
+            x = self.pooling_fn(x, dim=-1)[0]
+            x = self.pooling_fn(x, dim=-1)[0]
+        else:
+            x = self.pooling_fn(x, dim=-1)
+            x = self.pooling_fn(x, dim=-1)
+        return x
+
     def forward(self, x) -> Any:
         x = self.backbone(x)
 
@@ -137,8 +149,8 @@ def forward(self, x) -> Any:
         if isinstance(x, tuple):
             x = x[-1]
 
-        if len(x.size()) == 4 and self.embedding_dim is not None:
-            x = x.mean(-1).mean(-1)
+        if x.dim() == 4 and self.embedding_dim is not None:
+            x = self.apply_pool(x)
 
         x = self.head(x)
         return x
@@ -146,10 +158,3 @@ def forward(self, x) -> Any:
     @staticmethod
     def default_pipeline() -> ImageEmbedderDataPipeline:
         return ImageEmbedderDataPipeline()
-
-
-if __name__ == '__main__':
-    embedder = ImageEmbedder(backbone='resnet50')
-    image = torch.rand(32, 3, 128, 128)
-    embeddings = embedder.predict('/Users/williamfalcon/Desktop/abcd.jpeg')
-    print(embeddings.shape)
diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py
index 317d26a221..f929b6e664 100644
--- a/flash_examples/predict/image_embedder.py
+++ b/flash_examples/predict/image_embedder.py
@@ -1,25 +1,28 @@
 import torch
 
+from flash.core.data import download_data
 from flash.vision import ImageEmbedder
 
 if __name__ == "__main__":
 
-    # 1. Create an ImageEmbedder with swav
-    # Check out https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
-    for backbone in ["resnet50", "swav-imagenet"]:
-        embedder = ImageEmbedder(backbone=backbone, embedding_dim=128)
+    # 1. Download the data
+    download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", 'data/')
 
-        # 2. Generate an embedding from an image path.
-        embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
+    # 2. Create an ImageEmbedder with swav trained on imagenet.
+    # Check out SWAV: https://pytorch-lightning-bolts.readthedocs.io/en/latest/self_supervised_models.html#swav
+    embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)
 
-        # 3. Assert dimension
-        assert embeddings.shape == torch.Size((1, 128))
+    # 3. Generate an embedding from an image path.
+    embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
 
-        # 4. Create a tensor random image
-        random_img = torch.randn(1, 3, 32, 32)
+    # 4. Assert dimension
+    assert embeddings.shape == torch.Size((1, 128))
 
-        # 5. Generate an embedding from this random image
-        embeddings = embedder.predict(random_img)
+    # 5. Create a tensor random image
+    random_image = torch.randn(1, 3, 32, 32)
 
-        # 6. Assert dimension
-        assert embeddings.shape == torch.Size((1, 128))
+    # 6. Generate an embedding from this random image
+    embeddings = embedder.predict(random_image)
+
+    # 7. Assert dimension
+    assert embeddings.shape == torch.Size((1, 128))

From b42ed054bfcef68a4deda263bd695540ec66f3fc Mon Sep 17 00:00:00 2001
From: chaton
Date: Mon, 1 Feb 2021 11:06:30 +0000
Subject: [PATCH 05/16] Update docs/source/reference/image_embedder.rst

Co-authored-by: Jirka Borovec

---
 docs/source/reference/image_embedder.rst | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index f25ae72579..5c4a9a288e 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -32,15 +32,10 @@ Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any
     embeddings = model.predict('path/to/image.png')
     print(predictions)
 
-Or on an image tensor
+Or on a random image tensor
 
 .. code-block:: python
 
-    from flash.vision import ImageEmbedder
-
-    # Load finetuned task
-    embedder = ImageEmbedder(backbone='swav-imagenet')
-
     # 2. Perform inference on an image file
     images = torch.rand(32, 3, 224, 224)
     embeddings = model.predict(images)
     print(predictions)

From 2df9ddb78cb5378a4839c97f63b6e0ce0d647a69 Mon Sep 17 00:00:00 2001
From: chaton
Date: Mon, 1 Feb 2021 11:08:46 +0000
Subject: [PATCH 06/16] Update docs/source/reference/image_embedder.rst

Co-authored-by: Jirka Borovec

---
 docs/source/reference/image_embedder.rst | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index 5c4a9a288e..33da522644 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -89,7 +89,9 @@ Changing the backbone
 *********************
 By default, we use the encoder from `Swav `_ pretrained on Imagenet via contrastive learning. You can change the model run by the task by passing in a different backbone.
 
-.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
+.. note::
+
+    When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
 
 .. code-block:: python

From 4259665cf33c82fd743b0fa481188d1b0b67601e Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 11:25:46 +0000
Subject: [PATCH 07/16] update on comments

---
 docs/source/reference/image_embedder.rst      | 17 ++++-----
 flash/core/model.py                           | 18 +++++++--
 flash/model_map.py                            | 38 +++++++++----------
 .../vision/embedding/image_embedder_model.py  |  6 +--
 4 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index f25ae72579..c9e06ec21d 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -19,9 +19,9 @@ Inference
 
 The :class:`~flash.vision.ImageEmbedder` is already pre-trained on [ImageNet](http://www.image-net.org/), a dataset of over 14 million images.
 
-Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.text.ImageEmbedder.predict`:
+Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:
 
-.. code-block:: python
+.. testcode:: python
 
     from flash.vision import ImageEmbedder
 
@@ -34,7 +34,7 @@ Use the :class:`~flash.text.ImageEmbedder` pretrained model for inference on any
 
 Or on an image tensor
 
-.. code-block:: python
+.. testcode:: python
 
     from flash.vision import ImageEmbedder
 
@@ -66,20 +66,19 @@ To tailor this image embedder to your dataset, finetune first.
 
     # 2. Load the data
     datamodule = ImageClassificationData.from_folders(
-        backbone="resnet18",
-        train_folder="data/hymenoptera_data/train/",
-        valid_folder="data/hymenoptera_data/val/",
-        test_folder="data/hymenoptera_data/test/",
+        train_folder="data/hymenoptera_data/train/",
+        valid_folder="data/hymenoptera_data/val/",
+        test_folder="data/hymenoptera_data/test/",
     )
 
     # 3. Build the model
-    embedder = ImageEmbedder()
+    embedder = ImageEmbedder(backbone="resnet18")
 
     # 4. Create the trainer. Run once on data
     trainer = flash.Trainer(max_epochs=1)
 
     # 5. Train the model
-    trainer.finetune(model, datamodule=datamodule, unfreeze_milestones=(0, 1))
+    trainer.finetune(model, datamodule=datamodule)
 
     # 6. Test the model
     trainer.test()
diff --git a/flash/core/model.py b/flash/core/model.py
index f2cb5771f2..fd3f547328 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from contextlib import contextmanager
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union
 
 import pytorch_lightning as pl
@@ -91,6 +92,17 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
         output = self.step(batch, batch_idx)
         self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)
 
+    @property
+    @contextmanager
+    def predict_context(self):
+        try:
+            self.eval()
+            torch.set_grad_enabled(False)
+            yield
+        finally:
+            self.train()
+            torch.set_grad_enabled(True)
+
     def predict(
         self,
         x: Any,
@@ -115,13 +127,13 @@ def predict(
         The post-processed model predictions
 
         """
-        self.eval()
-        with torch.no_grad():
+        with self.predict_context:
             data_pipeline = data_pipeline or self.data_pipeline
             batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
             batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
             predictions = self.forward(batch_x)
-        return data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
+            output = data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
+        return output
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
         return self.optimizer_cls(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
diff --git a/flash/model_map.py b/flash/model_map.py
index 6ef61b1a5b..dceecf332c 100644
--- a/flash/model_map.py
+++ b/flash/model_map.py
@@ -1,37 +1,33 @@
+from functools import partial
+
 from pytorch_lightning.utilities import _BOLTS_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _BOLTS_AVAILABLE:
     from pl_bolts.models.self_supervised import SimCLR, SwAV
 
-DEFAULT_URLS = {
-    "SimCLR": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt',
-    "SwAV": 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/swav/swav_imagenet/swav_imagenet.pth.tar',
-}
-
-
-def load_model(name):
-    if name == 'simclr-imagenet':
-        return load_simclr_imagenet()
-
-    elif name == 'swav-imagenet':
-        return load_swav_imagenet()
-
-    else:
-        raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
-
+ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"
 
-def load_simclr_imagenet():
-    simclr = SimCLR.load_from_checkpoint(DEFAULT_URLS["SimCLR"], strict=False)
+def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
+    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
     model_config = {'model': simclr.encoder, 'emb_size': 2048}
     return model_config
 
-def load_swav_imagenet():
-    swav = SwAV.load_from_checkpoint(DEFAULT_URLS["SwAV"], strict=True)
+def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
+    swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
     model_config = {'model': swav.model, 'num_features': 3000}
     return model_config
 
-models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}
+_models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}
+
+
+def load_model(name):
+    if not _BOLTS_AVAILABLE:
+        raise MisconfigurationException("Bolts isn't installed. Please, use pip install `lightning-bolts`.")
+    if name in _models:
+        return _models[name]()
+    else:
+        raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 99b43e26ce..0d70927046 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -25,7 +25,7 @@
 from flash.core import Task
 from flash.core.data import TaskDataPipeline
 from flash.core.data.utils import _contains_any_tensor
-from flash.model_map import models
+from flash.model_map import _models
 from flash.vision.classification.data import _default_valid_transforms, _pil_loader, ImageClassificationData
 
 _resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
@@ -110,8 +110,8 @@ def __init__(
         assert pooling_fn in [torch.mean, torch.max]
         self.pooling_fn = pooling_fn
 
-        if backbone in models:
-            config = models[backbone]()
+        if backbone in _models:
+            config = _models[backbone]()
             self.backbone = config['model']
             num_features = config['num_features']

From 2d72438fc5c3f01ef5416cf0f9e0675aa98634fd Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 17:56:36 +0000
Subject: [PATCH 08/16] update on comments

---
 docs/source/reference/image_embedder.rst        |  4 ++--
 flash/core/model.py                             |  1 +
 flash/vision/embedding/image_embedder_model.py  | 10 +++++-----
 flash/{ => vision/embedding}/model_map.py       |  7 +++----
 flash_examples/predict/image_embedder.py        |  8 ++++----
 5 files changed, 15 insertions(+), 15 deletions(-)
 rename flash/{ => vision/embedding}/model_map.py (85%)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index 426e61e9bf..1fc3a9330a 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -79,7 +79,7 @@ To tailor this image embedder to your dataset, finetune first.
     trainer.test()
 
     # 7. Save it!
-    trainer.save_checkpoint("image_classification_model.pt")
+    trainer.save_checkpoint("image_embedder_model.pt")
 
 ------
 
@@ -88,7 +88,7 @@ Changing the backbone
 *********************
 By default, we use the encoder from `Swav `_ pretrained on Imagenet via contrastive learning. You can change the model run by the task by passing in a different backbone.
 
-.. note:: 
+.. note::
 
     When changing the backbone, make sure you pass in the same backbone to the Task and the Data object!
diff --git a/flash/core/model.py b/flash/core/model.py
index fd3f547328..86b0378d8a 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -95,6 +95,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
     @property
     @contextmanager
     def predict_context(self):
+        # todo: Move this within Lightning.
         try:
             self.eval()
             torch.set_grad_enabled(False)
diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 0d70927046..9949dda4be 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -25,8 +25,8 @@
 from flash.core import Task
 from flash.core.data import TaskDataPipeline
 from flash.core.data.utils import _contains_any_tensor
-from flash.model_map import _models
 from flash.vision.classification.data import _default_valid_transforms, _pil_loader, ImageClassificationData
+from flash.vision.embedding.model_map import _load_model, _models
 
 _resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
 _resnet_feats = lambda model: model.fc.in_features  # noqa: E731
@@ -87,9 +87,9 @@ class ImageEmbedder(Task):
 
     def __init__(
         self,
-        embedding_dim=None,
-        backbone="swav-imagenet",
-        pretrained=True,
+        embedding_dim: Optional[int] = None,
+        backbone: str = "swav-imagenet",
+        pretrained: bool = True,
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
         metrics: Union[Callable, Mapping, Sequence, None] = (Accuracy()),
@@ -111,7 +111,7 @@ def __init__(
         self.pooling_fn = pooling_fn
 
         if backbone in _models:
-            config = _models[backbone]()
+            config = _load_model(backbone)
             self.backbone = config['model']
             num_features = config['num_features']
diff --git a/flash/model_map.py b/flash/vision/embedding/model_map.py
similarity index 85%
rename from flash/model_map.py
rename to flash/vision/embedding/model_map.py
index dceecf332c..84fd4d1e02 100644
--- a/flash/model_map.py
+++ b/flash/vision/embedding/model_map.py
@@ -24,10 +24,9 @@ def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/
 _models = {'simclr-imagenet': load_simclr_imagenet, 'swav-imagenet': load_swav_imagenet}
 
 
-def load_model(name):
+def _load_model(name):
     if not _BOLTS_AVAILABLE:
-        raise MisconfigurationException("Bolts isn't installed. Please, use pip install `lightning-bolts`.")
+        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")
     if name in _models:
         return _models[name]()
-    else:
-        raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
+    raise MisconfigurationException("Currently, only `simclr-imagenet` and `swav-imagenet` are supported.")
diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py
index f929b6e664..1bb429794a 100644
--- a/flash_examples/predict/image_embedder.py
+++ b/flash_examples/predict/image_embedder.py
@@ -15,8 +15,8 @@
     # 3. Generate an embedding from an image path.
     embeddings = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
 
-    # 4. Assert dimension
-    assert embeddings.shape == torch.Size((1, 128))
+    # 4. Print embeddings shape
+    print(embeddings.shape)
 
     # 5. Create a tensor random image
     random_image = torch.randn(1, 3, 32, 32)
@@ -24,5 +24,5 @@
     # 6. Generate an embedding from this random image
     embeddings = embedder.predict(random_image)
 
-    # 7. Assert dimension
-    assert embeddings.shape == torch.Size((1, 128))
+    # 7. Print embeddings shape
+    print(embeddings.shape)

From 00f10df00ddb199cbaa9f9a693330bf46d130423 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 18:06:32 +0000
Subject: [PATCH 09/16] removing todo

---
 flash/core/model.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index 86b0378d8a..fd3f547328 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -95,7 +95,6 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
     @property
     @contextmanager
     def predict_context(self):
-        # todo: Move this within Lightning.
         try:
             self.eval()
             torch.set_grad_enabled(False)

From 1c98f8d79e14c087f64470e2ba81b041254bcabb Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 18:25:20 +0000
Subject: [PATCH 10/16] change to decorator

---
 flash/core/model.py | 42 +++++++++++++++++++++++++-----------------
 1 file changed, 25 insertions(+), 17 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index fd3f547328..0eb86a176c 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -22,6 +22,25 @@
 from flash.core.utils import get_callable_dict
 
 
+def predict_context(func: Callable) -> Callable:
+    """
+    This decorator is used as context manager
+    to put model in eval mode before running predict and reset to train after.
+    """
+
+    def wrapper(self, *args, **kwargs) -> Any:
+        self.eval()
+        torch.set_grad_enabled(False)
+
+        result = func(self, *args, **kwargs)
+
+        self.train()
+        torch.set_grad_enabled(True)
+        return result
+
+    return wrapper
+
+
 class Task(pl.LightningModule):
     """A general Task.
 
@@ -92,17 +111,7 @@ def test_step(self, batch: Any, batch_idx: int) -> None:
         output = self.step(batch, batch_idx)
         self.log_dict({f"test_{k}": v for k, v in output["logs"].items()}, on_step=False, on_epoch=True, prog_bar=True)
 
-    @property
-    @contextmanager
-    def predict_context(self):
-        try:
-            self.eval()
-            torch.set_grad_enabled(False)
-            yield
-        finally:
-            self.train()
-            torch.set_grad_enabled(True)
-
+    @predict_context
     def predict(
         self,
         x: Any,
@@ -127,12 +136,11 @@ def predict(
         The post-processed model predictions
 
         """
-        with self.predict_context:
-            data_pipeline = data_pipeline or self.data_pipeline
-            batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
-            batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
-            predictions = self.forward(batch_x)
-            output = data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
+        data_pipeline = data_pipeline or self.data_pipeline
+        batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
+        batch_x, batch_y = batch if len(batch) == 2 else (batch, None)
+        predictions = self.forward(batch_x)
+        output = data_pipeline.uncollate_fn(predictions)  # TODO: pass batch and x
         return output

From 15b39daa78b2d4cc59ca94921a2d5afb25b4eef8 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 18:26:30 +0000
Subject: [PATCH 11/16] add comment

---
 flash/vision/embedding/image_embedder_model.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 9949dda4be..af82dc7fab 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -135,6 +135,7 @@ def __init__(
 
     def apply_pool(self, x):
         if self.pooling_fn == torch.max:
+            # torch.max also returns argmax
             x = self.pooling_fn(x, dim=-1)[0]
             x = self.pooling_fn(x, dim=-1)[0]
         else:

From 5587b41d1cbd9410003f2a214952c2d922f1c9d4 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 18:38:31 +0000
Subject: [PATCH 12/16] add `grad_enabled` to predict

---
 flash/core/model.py                      | 10 +++++++++-
 flash_examples/predict/image_embedder.py |  8 ++++----
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index 951fcb2129..0b574e9677 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -30,7 +30,7 @@ def predict_context(func: Callable) -> Callable:
 
     def wrapper(self, *args, **kwargs) -> Any:
         self.eval()
-        torch.set_grad_enabled(False)
+        torch.set_grad_enabled(kwargs.get("grad_enabled", False))
 
         result = func(self, *args, **kwargs)
 
@@ -119,19 +119,27 @@ def predict(
         skip_collate_fn: bool = False,
         dataloader_idx: Optional[int] = None,
         data_pipeline: Optional[DataPipeline] = None,
+        grad_enabled: bool = False,
     ) -> Any:
         """
         Predict function for raw data or processed data
 
         Args:
+
             x: Input to predict. Can be raw data or processed data.
+
             batch_idx: Batch index
+
             dataloader_idx: Dataloader index
+
             skip_collate_fn: Whether to skip the collate step.
                 this is required when passing data already processed
                 for the model, for example, data from a dataloader
+
             data_pipeline: Use this to override the current data pipeline
+
+            grad_enabled: Wether to activate gradients.
+
         Returns:
             The post-processed model predictions
 
diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py
index 1bb429794a..734e9d7fdf 100644
--- a/flash_examples/predict/image_embedder.py
+++ b/flash_examples/predict/image_embedder.py
@@ -19,10 +19,10 @@
     print(embeddings.shape)
 
     # 5. Create a tensor random image
-    random_image = torch.randn(1, 3, 32, 32)
+    random_image = torch.randn(1, 3, 32, 32, requires_grad=True)
 
-    # 6. Generate an embedding from this random image
-    embeddings = embedder.predict(random_image)
+    # 6. Generate an embedding from this random image. Can be used
+    embeddings = embedder.predict(random_image, grad_enabled=True)
 
     # 7. Print embeddings shape
-    print(embeddings.shape)
+    print(embeddings.shape, embeddings.requires_grad)

From 6250ce9f414a4b100ca30da97d0b5b8230e02afa Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 18:51:22 +0000
Subject: [PATCH 13/16] bolts break the doc

---
 flash/vision/embedding/model_map.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/flash/vision/embedding/model_map.py b/flash/vision/embedding/model_map.py
index 84fd4d1e02..5098954053 100644
--- a/flash/vision/embedding/model_map.py
+++ b/flash/vision/embedding/model_map.py
@@ -1,10 +1,12 @@
-from functools import partial
+from contextlib import suppress
+from typing import Type
 
 from pytorch_lightning.utilities import _BOLTS_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 if _BOLTS_AVAILABLE:
-    from pl_bolts.models.self_supervised import SimCLR, SwAV
+    with suppress(TypeError):
+        from pl_bolts.models.self_supervised import SimCLR, SwAV
 
 ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

From 0497fd657d5d557c5270d50e4a23d061e0675fc3 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 19:09:38 +0000
Subject: [PATCH 14/16] update

---
 docs/source/reference/image_embedder.rst | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index 1fc3a9330a..5092287b10 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -26,7 +26,7 @@ Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on a
     from flash.vision import ImageEmbedder
 
     # Load finetuned task
-    embedder = ImageEmbedder(backbone='swav-imagenet')
+    embedder = ImageEmbedder(backbone='resnet18')
 
     # 2. Perform inference on an image file
     embeddings = model.predict('path/to/image.png')
@@ -36,6 +36,8 @@ Or on a random image tensor
 
 .. testcode:: python
 
+    import torch
+
     # 2. Perform inference on an image file
     images = torch.rand(32, 3, 224, 224)
     embeddings = model.predict(images)

From 3e6a75fe32b8d7d502907e14d15121e553eb3b63 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 19:19:04 +0000
Subject: [PATCH 15/16] resolve doc tests

---
 docs/source/reference/image_embedder.rst | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/docs/source/reference/image_embedder.rst b/docs/source/reference/image_embedder.rst
index 5092287b10..de40b1807b 100644
--- a/docs/source/reference/image_embedder.rst
+++ b/docs/source/reference/image_embedder.rst
@@ -21,7 +21,7 @@ The :class:`~flash.vision.ImageEmbedder` is already pre-trained on [ImageNet](ht
 
 Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on any image tensor or image path using :func:`~flash.vision.ImageEmbedder.predict`:
 
-.. testcode:: python
+.. code-block:: python
 
     from flash.vision import ImageEmbedder
 
@@ -34,14 +34,13 @@ Use the :class:`~flash.vision.ImageEmbedder` pretrained model for inference on a
 
 Or on a random image tensor
 
-.. testcode:: python
+.. code-block:: python
 
+    # 2. Perform inference on an image file
     import torch
-
-    # 2. Perform inference on an image file
-  images = torch.rand(32, 3, 224, 224)
-  embeddings = model.predict(images)
-  print(predictions)
+    images = torch.rand(32, 3, 224, 224)
+    embeddings = model.predict(images)
+    print(predictions)
 
 For more advanced inference options, see :ref:`predictions`.
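The final patch below removes the short-lived ``grad_enabled`` flag again, so ``predict`` always runs with gradients disabled. For reference, here is a minimal, self-contained sketch of the resulting ``predict_context`` pattern; the ``TinyModel`` class is a hypothetical stand-in for a Flash ``Task``, and the ``try``/``finally`` is an addition not present in the patch (it guarantees that train mode and gradient tracking are restored even if prediction raises):

.. code-block:: python

    from typing import Any, Callable

    import torch
    from torch import nn


    def predict_context(func: Callable) -> Callable:
        """Run ``func`` in eval mode with gradients disabled, then restore state."""

        def wrapper(self, *args, **kwargs) -> Any:
            self.eval()
            torch.set_grad_enabled(False)
            try:
                return func(self, *args, **kwargs)
            finally:
                self.train()
                torch.set_grad_enabled(True)

        return wrapper


    class TinyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 2)

        def forward(self, x):
            return self.layer(x)

        @predict_context
        def predict(self, x):
            return self(x)


    model = TinyModel()
    out = model.predict(torch.rand(8, 4))
    assert not out.requires_grad  # gradients were off during predict
    assert model.training  # train mode was restored afterwards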
From 356d02dd0db0afb5506ca9d6eb78017d221ca366 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Feb 2021 19:29:35 +0000
Subject: [PATCH 16/16] remove grad enabled to False

---
 flash/core/model.py                      | 5 +----
 flash_examples/predict/image_embedder.py | 8 ++++----
 2 files changed, 5 insertions(+), 8 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index 0b574e9677..7c4e2a2f1b 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -30,7 +30,7 @@ def predict_context(func: Callable) -> Callable:
 
     def wrapper(self, *args, **kwargs) -> Any:
         self.eval()
-        torch.set_grad_enabled(kwargs.get("grad_enabled", False))
+        torch.set_grad_enabled(False)
 
         result = func(self, *args, **kwargs)
 
@@ -119,7 +119,6 @@ def predict(
         skip_collate_fn: bool = False,
         dataloader_idx: Optional[int] = None,
         data_pipeline: Optional[DataPipeline] = None,
-        grad_enabled: bool = False,
     ) -> Any:
         """
         Predict function for raw data or processed data
@@ -138,8 +137,6 @@ def predict(
 
             data_pipeline: Use this to override the current data pipeline
 
-            grad_enabled: Wether to activate gradients.
-
         Returns:
             The post-processed model predictions
 
diff --git a/flash_examples/predict/image_embedder.py b/flash_examples/predict/image_embedder.py
index 734e9d7fdf..3463258a12 100644
--- a/flash_examples/predict/image_embedder.py
+++ b/flash_examples/predict/image_embedder.py
@@ -19,10 +19,10 @@
     print(embeddings.shape)
 
     # 5. Create a tensor random image
-    random_image = torch.randn(1, 3, 32, 32, requires_grad=True)
+    random_image = torch.randn(1, 3, 32, 32)
 
-    # 6. Generate an embedding from this random image. Can be used
-    embeddings = embedder.predict(random_image, grad_enabled=True)
+    # 6. Generate an embedding from this random image.
+    embeddings = embedder.predict(random_image)
 
     # 7. Print embeddings shape
-    print(embeddings.shape, embeddings.requires_grad)
+    print(embeddings.shape)
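A note on the pooling introduced in patch 04 and documented in patch 11: ``torch.max`` reduces one dimension at a time and returns a ``(values, indices)`` tuple, which is why ``apply_pool`` indexes ``[0]`` after each reduction, while ``torch.mean`` returns the values directly. A runnable sketch of the same logic outside the class:

.. code-block:: python

    import torch


    def apply_pool(x: torch.Tensor, pooling_fn=torch.max) -> torch.Tensor:
        """Collapse a (B, C, H, W) feature map to (B, C), as in ImageEmbedder."""
        if pooling_fn is torch.max:
            # torch.max returns a (values, indices) tuple, hence the [0]
            x = pooling_fn(x, dim=-1)[0]
            x = pooling_fn(x, dim=-1)[0]
        else:
            x = pooling_fn(x, dim=-1)
            x = pooling_fn(x, dim=-1)
        return x


    features = torch.rand(32, 2048, 7, 7)  # e.g. a ResNet-50 feature map
    print(apply_pool(features).shape)              # torch.Size([32, 2048])
    print(apply_pool(features, torch.mean).shape)  # torch.Size([32, 2048])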
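The backbone registry in ``image_embedder_model.py`` stores a ``(constructor, head splitter, feature-dim getter)`` triple per ResNet. The sketch below shows how one such triple resolves to a headless encoder; it passes ``pretrained=False`` only to skip the weight download, and the printed shapes assume a standard torchvision ResNet-18:

.. code-block:: python

    import torch
    import torchvision
    from torch import nn

    # Same shape as the registry entries in image_embedder_model.py:
    # build the model, cut off avgpool + fc, and read the feature dimension.
    _resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
    _resnet_feats = lambda model: model.fc.in_features  # noqa: E731

    backbone_fn, split, num_feats = (torchvision.models.resnet18, _resnet_backbone, _resnet_feats)

    full = backbone_fn(pretrained=False)
    encoder = split(full)
    num_features = num_feats(full)

    print(num_features)  # 512 for resnet18
    print(encoder(torch.rand(2, 3, 224, 224)).shape)  # torch.Size([2, 512, 7, 7])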
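``ImageEmbedderDataPipeline.before_collate`` (patch 03) dispatches on the input type: tensors pass through untouched, a single path is wrapped into a list, and lists of paths are loaded and transformed. A simplified standalone version follows; the loader and transform here are assumed stand-ins for flash's private ``_pil_loader`` and ``_default_valid_transforms``, and the tensor check simplifies ``_contains_any_tensor``:

.. code-block:: python

    from typing import Any

    import torch
    from PIL import Image
    from torchvision import transforms

    # Assumed stand-ins for flash's private helpers.
    _valid_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])


    def _loader(path: str) -> Image.Image:
        with open(path, "rb") as f:
            return Image.open(f).convert("RGB")


    def before_collate(samples: Any) -> Any:
        # Tensors pass straight through.
        if torch.is_tensor(samples):
            return samples
        # A single path becomes a one-element list.
        if isinstance(samples, str):
            samples = [samples]
        # Lists of paths are loaded and transformed one by one.
        if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
            return [_valid_transform(_loader(p)) for p in samples]
        raise TypeError("samples should be a tensor, a path, or a list of paths")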
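Finally, since the docs motivate embeddings with clustering and similarity search, a short usage sketch: cosine similarity between two embedding vectors produced by ``ImageEmbedder.predict``. The image paths are placeholders — substitute files that exist locally (the first matches the path used in ``flash_examples``):

.. code-block:: python

    import torch.nn.functional as F

    from flash.vision import ImageEmbedder

    embedder = ImageEmbedder(backbone="swav-imagenet", embedding_dim=128)

    # Placeholder paths -- point these at real image files.
    a = embedder.predict('data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg')
    b = embedder.predict('path/to/another_image.jpg')

    # Cosine similarity in [-1, 1]; higher means more visually similar.
    print(F.cosine_similarity(a, b))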