Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

timm integration #196

Merged
merged 14 commits into from
Apr 6, 2021
25 changes: 25 additions & 0 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@

import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

if _module_available('timm'):
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved
import timm

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

Expand Down Expand Up @@ -70,6 +74,9 @@ def backbone_and_num_features(
if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, pretrained)

if model_name in timm.list_models():
return timm_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")


Expand Down Expand Up @@ -140,3 +147,21 @@ def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = Tr
return backbone, num_features

raise ValueError(f"{model_name} is not supported yet.")


def timm_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
"""
>>> timm_backbone_and_num_features('resnet18') # doctest: +ELLIPSIS
(ResNet(...), 512)
>>> timm_backbone_and_num_features('mobilenetv3_large_100') # doctest: +ELLIPSIS
(MobileNetV3(...), 1280)
"""

if model_name in timm.list_models():
backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why num_classes = 0 ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this method is suggested to create a model without a classification head in the timm docs. https://rwightman.github.io/pytorch-image-models/feature_extraction/#create-with-no-classifier-and-pooling

global_pool='')
num_features = backbone.num_features
return backbone, num_features

raise ValueError(
f"{model_name} is not supported in timm yet. https://rwightman.github.io/pytorch-image-models/models/")
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ sentencepiece>=0.1.95
lightning-bolts==0.3.2 # todo: we shall align with proper release
filelock # comes with 3rd-party dependency
pycocotools>=2.0.2 ; python_version >= "3.7"
timm~=0.4.5
kaushikb11 marked this conversation as resolved.
Show resolved Hide resolved