Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add backbones for image embedding model #63

Merged
merged 4 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
>>> torchvision_backbone_and_num_features('densenet121') # doctest: +ELLIPSIS
(Sequential(...), 1024)
"""

model = getattr(torchvision.models, model_name, None)
if model is None:
raise MisconfigurationException(f"{model_name} is not supported by torchvision")
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.classification.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline


Expand Down
21 changes: 2 additions & 19 deletions flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,10 @@
from flash.core import Task
from flash.core.data import TaskDataPipeline
from flash.core.data.utils import _contains_any_tensor
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import _default_valid_transforms, _pil_loader
from flash.vision.embedding.model_map import _load_bolts_model, _models

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731
_resnet_feats = lambda model: model.fc.in_features # noqa: E731

_backbones = {
"resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
"resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
"resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
"resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
"resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}


class ImageEmbedderDataPipeline(TaskDataPipeline):
"""
Expand Down Expand Up @@ -129,15 +119,8 @@ def __init__(
config = _load_bolts_model(backbone)
self.backbone = config['model']
num_features = config['num_features']

elif backbone not in _backbones:
raise NotImplementedError(f"Backbone {backbone} is not yet supported")

else:
backbone_fn, split, num_feats = _backbones[backbone]
backbone = backbone_fn(pretrained=pretrained)
self.backbone = split(backbone)
num_features = num_feats(backbone)
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)

if embedding_dim is None:
self.head = nn.Identity()
Expand Down