From 484885138f2113bcb64a031b066564a99fbbe53b Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Wed, 3 Feb 2021 12:15:37 +0530
Subject: [PATCH 1/2] add backbones for image embedding model

---
 flash/vision/backbones.py                     | 48 +++++++++++++++++++
 flash/vision/classification/model.py          |  2 +-
 .../vision/embedding/image_embedder_model.py  | 21 +-------
 3 files changed, 51 insertions(+), 20 deletions(-)
 create mode 100644 flash/vision/backbones.py

diff --git a/flash/vision/backbones.py b/flash/vision/backbones.py
new file mode 100644
index 00000000000..aad2f4aaf61
--- /dev/null
+++ b/flash/vision/backbones.py
@@ -0,0 +1,48 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import torch.nn as nn
+import torchvision
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
+
+    model = getattr(torchvision.models, model_name, None)
+    if model is None:
+        raise MisconfigurationException(f"{model_name} is not supported by torchvision")
+
+    if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
+        model = model(pretrained=pretrained)
+        backbone = model.features
+        num_features = model.classifier[-1].in_features
+        return backbone, num_features
+
+    elif model_name in [
+        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
+    ]:
+        model = model(pretrained=pretrained)
+        # remove the last two layers (avgpool and fc) & turn the rest into a Sequential model
+        backbone = nn.Sequential(*list(model.children())[:-2])
+        num_features = model.fc.in_features
+        return backbone, num_features
+
+    elif model_name in ["densenet121", "densenet169", "densenet161", "densenet201"]:
+        model = model(pretrained=pretrained)
+        backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
+        num_features = model.classifier.in_features
+        return backbone, num_features
+
+    raise ValueError(f"{model_name} is not supported yet.")

diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 9e7858bf1ae..4a77264b82d 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -19,7 +19,7 @@
 from torch.nn import functional as F
 
 from flash.core.classification import ClassificationTask
-from flash.vision.classification.backbones import torchvision_backbone_and_num_features
+from flash.vision.backbones import torchvision_backbone_and_num_features
 from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline

diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 05e1fc8133f..6648461ad84 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -24,20 +24,10 @@
 from flash.core import Task
 from flash.core.data import TaskDataPipeline
 from flash.core.data.utils import _contains_any_tensor
+from flash.vision.classification.backbones import torchvision_backbone_and_num_features
 from flash.vision.classification.data import _default_valid_transforms, _pil_loader
 from flash.vision.embedding.model_map import _load_model, _models
 
-_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2])  # noqa: E731
-_resnet_feats = lambda model: model.fc.in_features  # noqa: E731
-
-_backbones = {
-    "resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
-    "resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
-    "resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
-    "resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
-    "resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
-}
-
 
 class ImageEmbedderDataPipeline(TaskDataPipeline):
@@ -113,15 +103,8 @@ def __init__(
             config = _load_model(backbone)
             self.backbone = config['model']
             num_features = config['num_features']
-
-        elif backbone not in _backbones:
-            raise NotImplementedError(f"Backbone {backbone} is not yet supported")
-
         else:
-            backbone_fn, split, num_feats = _backbones[backbone]
-            backbone = backbone_fn(pretrained=pretrained)
-            self.backbone = split(backbone)
-            num_features = num_feats(backbone)
+            self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)
 
         if embedding_dim is None:
             self.head = nn.Identity()

From ebfd23b4db26acee1019f8e14a934cbecebbe131 Mon Sep 17 00:00:00 2001
From: Kaushik Bokka
Date: Wed, 3 Feb 2021 12:19:21 +0530
Subject: [PATCH 2/2] move backbones to root vision folder

---
 flash/vision/classification/backbones.py      | 48 -------------------
 .../vision/embedding/image_embedder_model.py  |  2 +-
 2 files changed, 1 insertion(+), 49 deletions(-)
 delete mode 100644 flash/vision/classification/backbones.py

diff --git a/flash/vision/classification/backbones.py b/flash/vision/classification/backbones.py
deleted file mode 100644
index aad2f4aaf61..00000000000
--- a/flash/vision/classification/backbones.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from typing import Tuple
-
-import torch.nn as nn
-import torchvision
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-
-
-def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
-
-    model = getattr(torchvision.models, model_name, None)
-    if model is None:
-        raise MisconfigurationException(f"{model_name} is not supported by torchvision")
-
-    if model_name in ["mobilenet_v2", "vgg11", "vgg13", "vgg16", "vgg19"]:
-        model = model(pretrained=pretrained)
-        backbone = model.features
-        num_features = model.classifier[-1].in_features
-        return backbone, num_features
-
-    elif model_name in [
-        "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d"
-    ]:
-        model = model(pretrained=pretrained)
-        # remove the last two layers (avgpool and fc) & turn the rest into a Sequential model
-        backbone = nn.Sequential(*list(model.children())[:-2])
-        num_features = model.fc.in_features
-        return backbone, num_features
-
-    elif model_name in ["densenet121", "densenet169", "densenet161", "densenet201"]:
-        model = model(pretrained=pretrained)
-        backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
-        num_features = model.classifier.in_features
-        return backbone, num_features
-
-    raise ValueError(f"{model_name} is not supported yet.")

diff --git a/flash/vision/embedding/image_embedder_model.py b/flash/vision/embedding/image_embedder_model.py
index 6648461ad84..25c759ff70d 100644
--- a/flash/vision/embedding/image_embedder_model.py
+++ b/flash/vision/embedding/image_embedder_model.py
@@ -24,7 +24,7 @@
 from flash.core import Task
 from flash.core.data import TaskDataPipeline
 from flash.core.data.utils import _contains_any_tensor
-from flash.vision.classification.backbones import torchvision_backbone_and_num_features
+from flash.vision.backbones import torchvision_backbone_and_num_features
 from flash.vision.classification.data import _default_valid_transforms, _pil_loader
 from flash.vision.embedding.model_map import _load_model, _models
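
For reviewers, a minimal usage sketch of the relocated helper, assuming the patched flash package and torchvision are importable. The backbone name, pretrained=False, and the 224x224 input below are illustrative choices, not taken from this diff.

    import torch

    from flash.vision.backbones import torchvision_backbone_and_num_features

    # Headless ResNet-50 plus the feature width its removed fc layer consumed.
    backbone, num_features = torchvision_backbone_and_num_features("resnet50", pretrained=False)

    images = torch.rand(2, 3, 224, 224)     # dummy batch of two RGB images
    feature_map = backbone(images)          # spatial feature map, not logits
    print(num_features, feature_map.shape)  # 2048 torch.Size([2, 2048, 7, 7])

Because the ResNet branch drops only the final avgpool and fc layers, callers such as ImageEmbedder still get a spatial feature map and can attach their own pooling or head on top of num_features channels.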