add swav and simclr models to imageclassifier + backbone reorg (#68)
* add swav and simclr models to imageclassifier

* pep8

* yapf

* reorg

* pep8

* isort

* fix doctest

* fix pytest

* simclr fix

* fix doctest

* Remove suppress

Co-authored-by: Carlos Mocholí <[email protected]>
teddykoker and carmocca authored Feb 5, 2021
1 parent aeb0063 commit cffcbd9
Showing 6 changed files with 69 additions and 68 deletions.
68 changes: 62 additions & 6 deletions flash/vision/backbones.py
@@ -13,9 +13,67 @@
# limitations under the License.
from typing import Tuple

import torch.nn as nn
import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn

if _BOLTS_AVAILABLE:
    from pl_bolts.models.self_supervised import SimCLR, SwAV

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

MOBILENET_MODELS = ["mobilenet_v2"]
VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
RESNET_MODELS = ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"]
DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"]
TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS

BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]


def backbone_and_num_features(model_name: str, *args, **kwargs) -> Tuple[nn.Module, int]:
    if model_name in BOLTS_MODELS:
        return bolts_backbone_and_num_features(model_name)

    if model_name in TORCHVISION_MODELS:
        return torchvision_backbone_and_num_features(model_name, *args, **kwargs)

    raise ValueError(f"{model_name} is not supported yet.")


def bolts_backbone_and_num_features(model_name: str) -> Tuple[nn.Module, int]:
    """
    >>> bolts_backbone_and_num_features('simclr-imagenet')  # doctest: +ELLIPSIS
    (Sequential(...), 2048)
    >>> bolts_backbone_and_num_features('swav-imagenet')  # doctest: +ELLIPSIS
    (Sequential(...), 3000)
    """

    # TODO: maybe we should use plain pytorch weights so we don't need to rely on bolts to load these
    # also maybe just use torch.hub for the ssl lib
    def load_simclr_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
        simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
        # remove the last two layers & turn it into a Sequential model
        backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
        return backbone, 2048

    def load_swav_imagenet(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
        swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
        # remove the last two layers & turn it into a Sequential model
        backbone = nn.Sequential(*list(swav.model.children())[:-2])
        return backbone, 3000

    models = {
        'simclr-imagenet': load_simclr_imagenet,
        'swav-imagenet': load_swav_imagenet,
    }
    if not _BOLTS_AVAILABLE:
        raise MisconfigurationException("Bolts isn't installed. Please use ``pip install lightning-bolts``.")
    if model_name in models:
        return models[model_name]()

    raise ValueError(f"{model_name} is not supported yet.")


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
@@ -31,22 +89,20 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
    if model is None:
        raise MisconfigurationException(f"{model_name} is not supported by torchvision")

    if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
    if model_name in MOBILENET_MODELS + VGG_MODELS:
        model = model(pretrained=pretrained)
        backbone = model.features
        num_features = model.classifier[-1].in_features
        return backbone, num_features

    elif model_name in [
        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
    ]:
    elif model_name in RESNET_MODELS:
        model = model(pretrained=pretrained)
        # remove the last two layers & turn it into a Sequential model
        backbone = nn.Sequential(*list(model.children())[:-2])
        num_features = model.fc.in_features
        return backbone, num_features

elif model_name in ["densenet121", "densenet169", "densenet161", "densenet161"]:
elif model_name in DENSENET_MODELS:
model = model(pretrained=pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
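For reference, the new backbone_and_num_features helper gives a single call site for both backbone families. A minimal usage sketch (assuming flash and lightning-bolts are installed; the bolts checkpoints are downloaded from the S3 bucket on first use):

from flash.vision.backbones import backbone_and_num_features

# torchvision backbone: extra arguments such as pretrained are forwarded
backbone, num_features = backbone_and_num_features("resnet18", pretrained=True)  # num_features == 512

# bolts self-supervised backbone: requires lightning-bolts
backbone, num_features = backbone_and_num_features("swav-imagenet")  # num_features == 3000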
4 changes: 2 additions & 2 deletions flash/vision/classification/model.py
@@ -19,7 +19,7 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline


@@ -57,7 +57,7 @@ def __init__(

        self.save_hyperparameters()

        self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
        self.backbone, num_features = backbone_and_num_features(backbone, pretrained)

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
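With the classifier now calling backbone_and_num_features, the bolts names can be passed straight to ImageClassifier. A rough sketch of the intended usage, assuming the flash.vision import path and the positional num_classes-first signature seen in the tests below:

from flash.vision import ImageClassifier

# self-supervised SwAV weights from bolts, pretrained on ImageNet
model = ImageClassifier(10, backbone="swav-imagenet")

# torchvision backbones keep working as before
model = ImageClassifier(10, backbone="resnet18")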
10 changes: 2 additions & 8 deletions flash/vision/embedding/image_embedder_model.py
@@ -24,9 +24,8 @@
from flash.core import Task
from flash.core.data import TaskDataPipeline
from flash.core.data.utils import _contains_any_tensor
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import backbone_and_num_features
from flash.vision.classification.data import _default_valid_transforms, _pil_loader
from flash.vision.embedding.model_map import _load_bolts_model, _models


class ImageEmbedderDataPipeline(TaskDataPipeline):
@@ -115,12 +114,7 @@ def __init__(
        assert pooling_fn in [torch.mean, torch.max]
        self.pooling_fn = pooling_fn

        if backbone in _models:
            config = _load_bolts_model(backbone)
            self.backbone = config['model']
            num_features = config['num_features']
        else:
            self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
        self.backbone, num_features = backbone_and_num_features(backbone, pretrained)

        if embedding_dim is None:
            self.head = nn.Identity()
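ImageEmbedder goes through the same helper, so the bolts names work here too and the embedding-specific model map is no longer needed. A sketch under the same assumptions (flash.vision import path; per the diff, embedding_dim=None leaves an nn.Identity() head on top of the backbone):

from flash.vision import ImageEmbedder

# SimCLR backbone from bolts; features are passed through unchanged
embedder = ImageEmbedder(backbone="simclr-imagenet", embedding_dim=None)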
49 changes: 0 additions & 49 deletions flash/vision/embedding/model_map.py

This file was deleted.
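The helpers that lived in model_map.py are superseded by bolts_backbone_and_num_features. A rough before/after for any external callers, with the old API reconstructed from the embedder diff above:

# before this commit: flash.vision.embedding.model_map (now removed)
from flash.vision.embedding.model_map import _load_bolts_model
config = _load_bolts_model("swav-imagenet")
backbone, num_features = config['model'], config['num_features']

# after this commit: flash.vision.backbones
from flash.vision.backbones import bolts_backbone_and_num_features
backbone, num_features = bolts_backbone_and_num_features("swav-imagenet")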

2 changes: 1 addition & 1 deletion tests/vision/classification/test_model.py
@@ -51,7 +51,7 @@ def test_init_train(tmpdir, backbone):


def test_non_existent_backbone():
    with pytest.raises(MisconfigurationException):
    with pytest.raises(ValueError):
        ImageClassifier(2, "i am never going to implement this lol")


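Since unsupported names now raise ValueError instead of MisconfigurationException, a companion test for the newly supported bolts names could look like the sketch below (hypothetical test name and import path; the suite may already cover this through its backbone parametrization):

import pytest

from flash.vision import ImageClassifier


@pytest.mark.parametrize("backbone", ["simclr-imagenet", "swav-imagenet"])
def test_init_with_bolts_backbone(backbone):
    # should construct without raising; note this downloads the bolts checkpoint
    ImageClassifier(2, backbone)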
4 changes: 2 additions & 2 deletions tests/vision/test_download.py
@@ -13,9 +13,9 @@
# limitations under the License.
import pytest

from flash.vision.embedding.model_map import _load_bolts_model
from flash.vision.backbones import bolts_backbone_and_num_features


@pytest.mark.parametrize("name", ['simclr-imagenet', 'swav-imagenet'])
def test_load_bolts(name):
    _load_bolts_model(name)
    bolts_backbone_and_num_features(name)
