Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge a6489c4 into 2d0e6f8
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikb11 authored Feb 3, 2021
2 parents 2d0e6f8 + a6489c4 commit 020a272
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
>>> torchvision_backbone_and_num_features('densenet121') # doctest: +ELLIPSIS
(Sequential(...), 1024)
"""

model = getattr(torchvision.models, model_name, None)
if model is None:
raise MisconfigurationException(f"{model_name} is not supported by torchvision")
Expand Down
2 changes: 1 addition & 1 deletion flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from torch.nn import functional as F

from flash.core.classification import ClassificationTask
from flash.vision.classification.backbones import torchvision_backbone_and_num_features
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import ImageClassificationData, ImageClassificationDataPipeline


Expand Down
21 changes: 2 additions & 19 deletions flash/vision/embedding/image_embedder_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,10 @@
from flash.core import Task
from flash.core.data import TaskDataPipeline
from flash.core.data.utils import _contains_any_tensor
from flash.vision.backbones import torchvision_backbone_and_num_features
from flash.vision.classification.data import _default_valid_transforms, _pil_loader
from flash.vision.embedding.model_map import _load_bolts_model, _models

_resnet_backbone = lambda model: nn.Sequential(*list(model.children())[:-2]) # noqa: E731
_resnet_feats = lambda model: model.fc.in_features # noqa: E731

_backbones = {
"resnet18": (torchvision.models.resnet18, _resnet_backbone, _resnet_feats),
"resnet34": (torchvision.models.resnet34, _resnet_backbone, _resnet_feats),
"resnet50": (torchvision.models.resnet50, _resnet_backbone, _resnet_feats),
"resnet101": (torchvision.models.resnet101, _resnet_backbone, _resnet_feats),
"resnet152": (torchvision.models.resnet152, _resnet_backbone, _resnet_feats),
}


class ImageEmbedderDataPipeline(TaskDataPipeline):
"""
Expand Down Expand Up @@ -129,15 +119,8 @@ def __init__(
config = _load_bolts_model(backbone)
self.backbone = config['model']
num_features = config['num_features']

elif backbone not in _backbones:
raise NotImplementedError(f"Backbone {backbone} is not yet supported")

else:
backbone_fn, split, num_feats = _backbones[backbone]
backbone = backbone_fn(pretrained=pretrained)
self.backbone = split(backbone)
num_features = num_feats(backbone)
self.backbone, num_features = torchvision_backbone_and_num_features(backbone, pretrained)

if embedding_dim is None:
self.head = nn.Identity()
Expand Down

0 comments on commit 020a272

Please sign in to comment.