Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add backbone API #204

Merged
merged 23 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions flash/core/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from collections import defaultdict
from functools import partial
from types import FunctionType
from typing import Callable, Dict, Mapping, Optional, Union

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.nn import Module


class FlashRegistry(Dict):
    """Registry mapping a name to a function or ``functools.partial``.

    Entries are kept in a private mapping exposed through ``registered_funcs``;
    ``len()``, ``in``, ``[]`` and ``get`` all operate on that mapping.

    Example::

        backbones = FlashRegistry("backbones")

        @backbones.register_function()
        def my_model(nc_input=5, nc_output=6):
            return nn.Linear(nc_input, nc_output), nc_input, nc_output

        mlp, nc_input, nc_output = backbones.get("my_model")(nc_output=7)

        backbones.register_function(my_model, name="cho")
        assert backbones.get("cho")

    """

    def __init__(self, registry_name: str, verbose: bool = False):
        """
        Args:
            registry_name: label reported by ``name`` and ``repr()``.
            verbose: when True, print every direct registration.
        """
        super().__init__()
        self._registry_name = registry_name
        # A plain dict: the previous ``defaultdict()`` had no default factory,
        # so it never behaved differently from one.
        self._registered_functions: Dict[str, Callable] = {}
        self._verbose = verbose

    def __len__(self):
        """Return the number of registered entries."""
        return len(self._registered_functions)

    def __contains__(self, key):
        """Return True when ``key`` is a registered name."""
        # Membership must be a real bool, not the stored callable / None.
        return key in self._registered_functions

    def __repr__(self):
        return (
            f'{self.__class__.__name__}(name={self._registry_name}, '
            f'registered_items={dict(self._registered_functions)})'
        )

    @property
    def name(self):
        """Name given to this registry at construction time."""
        return self._registry_name

    @property
    def registered_funcs(self):
        """The underlying name -> callable mapping."""
        return self._registered_functions

    def __getitem__(self, key: str) -> Optional[Callable]:
        """Same as :meth:`get`."""
        return self.get(key)

    def get(self, key: str) -> Optional[Callable]:
        """Return the callable registered under ``key``.

        Raises:
            MisconfigurationException: if ``key`` has not been registered.
        """
        if key not in self._registered_functions:
            raise MisconfigurationException(f"Key: {key} is not in {self.__repr__()}")
        return self._registered_functions[key]

    def _register_function(self, fn: Callable, name: Optional[str] = None):
        """Store ``fn`` under ``name``, defaulting to the function's own name.

        Raises:
            MisconfigurationException: if ``fn`` is neither a plain function nor
                a ``functools.partial``, or if no name can be derived for it.
        """
        if not isinstance(fn, (FunctionType, partial)):
            raise MisconfigurationException("``register_function`` should be used with a function")

        if not name:
            # ``partial`` objects carry no ``__name__``; fall back to the
            # callable they wrap instead of failing with AttributeError.
            target = fn.func if isinstance(fn, partial) else fn
            name = getattr(target, "__name__", None)
            if not name:
                raise MisconfigurationException(
                    "``register_function`` requires ``name`` when it cannot be inferred from the function"
                )

        self._registered_functions[name] = fn

    def register_function(self, fn: Optional[Callable] = None, name: Optional[str] = None) -> Callable:
        """Register a callable, either directly or as a decorator.

        Args:
            fn: the function or partial to register. When omitted, a decorator
                is returned that registers whatever it is applied to.
            name: key to register under; defaults to the function's name.

        Returns:
            ``fn`` itself when given, otherwise the registering decorator.

        Raises:
            TypeError: if ``name`` is neither ``None`` nor a ``str``.
        """
        # raise the error ahead of time, for both the direct and decorator forms
        if not (name is None or isinstance(name, str)):
            raise TypeError(f'name must be a str, but got {type(name)}')

        if fn is not None:
            if self._verbose:
                print(f"Registering: {fn} {name}")
            self._register_function(fn=fn, name=name)
            return fn

        def _register(cls):
            self._register_function(fn=cls, name=name)
            return cls

        return _register


# Shared module-level registry, named "backbones", that other modules import
# to register and look up backbone factory functions.
BACKBONES_REGISTRY = FlashRegistry("backbones")
tchaton marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions flash/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
_KORNIA_AVAILABLE = _module_available("kornia")
_COCO_AVAILABLE = _module_available("pycocotools")
_TIMM_AVAILABLE = _module_available("timm")
_TORCHVISION_AVAILABLE = _module_available("torchvision")
201 changes: 79 additions & 122 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Tuple

import torchvision
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn as nn
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone

from flash.utils.imports import _TIMM_AVAILABLE
from flash.core.registry import BACKBONES_REGISTRY
from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE

if _TIMM_AVAILABLE:
import timm

if _TORCHVISION_AVAILABLE:
import torchvision

if _BOLTS_AVAILABLE:
from pl_bolts.models.self_supervised import SimCLR, SwAV

Expand All @@ -38,125 +42,78 @@
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]


def backbone_and_num_features(
model_name: str,
fpn: bool = False,
pretrained: bool = True,
trainable_backbone_layers: int = 3,
**kwargs
) -> Tuple[nn.Module, int]:
"""
Args:
model_name: backbone supported by `torchvision` and `bolts`
fpn: If True, creates a Feature Pyramind Network on top of Resnet based CNNs.
pretrained: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block.

>>> backbone_and_num_features('mobilenet_v2') # doctest: +ELLIPSIS
(Sequential(...), 1280)
>>> backbone_and_num_features('resnet50', fpn=True) # doctest: +ELLIPSIS
(BackboneWithFPN(...), 256)
>>> backbone_and_num_features('swav-imagenet') # doctest: +ELLIPSIS
(Sequential(...), 2048)
"""
if fpn:
if model_name in RESNET_MODELS:
@BACKBONES_REGISTRY.register_function(name="simclr-imagenet")
def load_simclr_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt", **__
) -> Tuple[nn.Module, int]:
    """Build the ``simclr-imagenet`` backbone from a SimCLR checkpoint.

    Args:
        path_or_url: checkpoint location passed to ``SimCLR.load_from_checkpoint``.
        **__: ignored; accepted so the registry entry can be called with the
            keyword arguments other backbones take (e.g. ``pretrained``).

    Returns:
        The truncated encoder as an ``nn.Sequential`` and the value ``2048``
        as its number of features.

    Raises:
        MisconfigurationException: if bolts is not installed.
    """
    # ``SimCLR`` is only imported when bolts is available; fail with a clear
    # message instead of a NameError.
    if not _BOLTS_AVAILABLE:
        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")

    simclr = SimCLR.load_from_checkpoint(path_or_url, strict=False)
    # remove the last two layers & turn it into a Sequential model
    backbone = nn.Sequential(*list(simclr.encoder.children())[:-2])
    return backbone, 2048


@BACKBONES_REGISTRY.register_function(name="swav-imagenet")
def load_swav_imagenet(
    path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar", **__
) -> Tuple[nn.Module, int]:
    """Build the ``swav-imagenet`` backbone from a SwAV checkpoint.

    Args:
        path_or_url: checkpoint location passed to ``SwAV.load_from_checkpoint``.
        **__: ignored; accepted so the registry entry can be called with the
            keyword arguments other backbones take (e.g. ``pretrained``).

    Returns:
        The truncated model as an ``nn.Sequential`` and the value ``2048``
        as its number of features.

    Raises:
        MisconfigurationException: if bolts is not installed.
    """
    # ``SwAV`` is only imported when bolts is available; fail with a clear
    # message instead of a NameError.
    if not _BOLTS_AVAILABLE:
        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")

    swav = SwAV.load_from_checkpoint(path_or_url, strict=True)
    # remove the last two layers & turn it into a Sequential model
    backbone = nn.Sequential(*list(swav.model.children())[:-2])
    return backbone, 2048


if _TORCHVISION_AVAILABLE:

    def _load_torchvision_model(model_name: str, pretrained: bool) -> nn.Module:
        """Instantiate ``torchvision.models.<model_name>``.

        Raises:
            MisconfigurationException: if torchvision has no such model.
        """
        model_fn = getattr(torchvision.models, model_name, None)
        if model_fn is None:
            # Explicit error instead of "'NoneType' object is not callable".
            raise MisconfigurationException(f"{model_name} is not supported by torchvision")
        return model_fn(pretrained)

    def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
        """Return the ``features`` module of a MobileNet/VGG model and its feature count."""
        model = _load_torchvision_model(model_name, pretrained)
        backbone = model.features
        num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
        return backbone, num_features

    def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
        """Return a ResNet without its last two children, and ``fc.in_features``."""
        model = _load_torchvision_model(model_name, pretrained)
        backbone = nn.Sequential(*list(model.children())[:-2])
        num_features = model.fc.in_features
        return backbone, num_features

    def _fn_resnet_fpn(model_name: str,
                       pretrained: bool = True,
                       trainable_layers: int = 3,
                       **kwargs) -> Tuple[nn.Module, int]:
        """Return a ResNet+FPN backbone built by ``resnet_fpn_backbone``, with 256 as feature count.

        Args:
            model_name: resnet variant forwarded to ``resnet_fpn_backbone``.
            pretrained: forwarded to ``resnet_fpn_backbone``.
            trainable_layers: forwarded to ``resnet_fpn_backbone``. This is a
                layer count, not a flag: a ``True`` default would act as ``1``.
            **kwargs: extra keyword arguments forwarded to ``resnet_fpn_backbone``.
        """
        backbone = resnet_fpn_backbone(
            model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs
        )
        return backbone, 256

    def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
        """Return DenseNet ``features`` followed by a ReLU, and ``classifier.in_features``."""
        model = _load_torchvision_model(model_name, pretrained)
        backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
        num_features = model.classifier.in_features
        return backbone, num_features

    # Each entry is a partial with the model name already bound, so registry
    # users only pass the remaining arguments (``pretrained``, ...).
    for model_name in MOBILENET_MODELS + VGG_MODELS:
        BACKBONES_REGISTRY.register_function(fn=partial(_fn_mobilenet_vgg, model_name), name=model_name)

    for model_name in RESNET_MODELS:
        BACKBONES_REGISTRY.register_function(fn=partial(_fn_resnet, model_name), name=model_name)
        # The FPN variant lives under a "-fpn" suffixed key.
        BACKBONES_REGISTRY.register_function(fn=partial(_fn_resnet_fpn, model_name), name=f"{model_name}-fpn")

    for model_name in DENSENET_MODELS:
        BACKBONES_REGISTRY.register_function(fn=partial(_fn_densenet, model_name), name=model_name)

if _TIMM_AVAILABLE:
for model_name in timm.list_models():

def _fn_timm(model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '') -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
fpn_out_channels = 256
return backbone, fpn_out_channels
else:
rank_zero_warn(f"{model_name} backbone is not supported with `fpn=True`, `fpn` won't be added.")

if model_name in BOLTS_MODELS:
return bolts_backbone_and_num_features(model_name)

if model_name in TORCHVISION_MODELS:
return torchvision_backbone_and_num_features(model_name, pretrained)

if _TIMM_AVAILABLE and model_name in timm.list_models():
return timm_backbone_and_num_features(model_name, pretrained)

raise ValueError(f"{model_name} is not supported yet.")


def bolts_backbone_and_num_features(model_name: str) -> Tuple[nn.Module, int]:
    """Load one of the bolts self-supervised backbones by name.

    >>> bolts_backbone_and_num_features('simclr-imagenet')  # doctest: +ELLIPSIS
    (Sequential(...), 2048)
    >>> bolts_backbone_and_num_features('swav-imagenet')  # doctest: +ELLIPSIS
    (Sequential(...), 2048)
    """

    # TODO: maybe we should plain pytorch weights so we don't need to rely on bolts to load these
    # also maybe just use torchhub for the ssl lib
    def _simclr(path_or_url: str = f"{ROOT_S3_BUCKET}/simclr/bolts_simclr_imagenet/simclr_imagenet.ckpt"):
        checkpoint = SimCLR.load_from_checkpoint(path_or_url, strict=False)
        # drop the final two children and wrap the rest in a Sequential
        layers = list(checkpoint.encoder.children())
        return nn.Sequential(*layers[:-2]), 2048

    def _swav(path_or_url: str = f"{ROOT_S3_BUCKET}/swav/swav_imagenet/swav_imagenet.pth.tar"):
        checkpoint = SwAV.load_from_checkpoint(path_or_url, strict=True)
        # drop the final two children and wrap the rest in a Sequential
        layers = list(checkpoint.model.children())
        return nn.Sequential(*layers[:-2]), 2048

    loaders = {
        'simclr-imagenet': _simclr,
        'swav-imagenet': _swav,
    }

    if not _BOLTS_AVAILABLE:
        raise MisconfigurationException("Bolts isn't installed. Please, use ``pip install lightning-bolts``.")

    if model_name not in loaders:
        raise ValueError(f"{model_name} is not supported yet.")

    return loaders[model_name]()


def torchvision_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
    """Return a torchvision backbone and its number of output features.

    >>> torchvision_backbone_and_num_features('mobilenet_v2')  # doctest: +ELLIPSIS
    (Sequential(...), 1280)
    >>> torchvision_backbone_and_num_features('resnet18')  # doctest: +ELLIPSIS
    (Sequential(...), 512)
    >>> torchvision_backbone_and_num_features('densenet121')  # doctest: +ELLIPSIS
    (Sequential(...), 1024)
    """
    factory = getattr(torchvision.models, model_name, None)
    if factory is None:
        raise MisconfigurationException(f"{model_name} is not supported by torchvision")

    if model_name in MOBILENET_MODELS + VGG_MODELS:
        net = factory(pretrained=pretrained)
        features = net.features
        # VGG uses a fixed 512; MobileNet reads it from the last classifier layer
        if model_name in VGG_MODELS:
            width = 512
        else:
            width = net.classifier[-1].in_features
        return features, width

    if model_name in RESNET_MODELS:
        net = factory(pretrained=pretrained)
        # everything except the last two children, wrapped in a Sequential
        trunk = nn.Sequential(*list(net.children())[:-2])
        return trunk, net.fc.in_features

    if model_name in DENSENET_MODELS:
        net = factory(pretrained=pretrained)
        trunk = nn.Sequential(*net.features, nn.ReLU(inplace=True))
        return trunk, net.classifier.in_features

    raise ValueError(f"{model_name} is not supported yet.")


def timm_backbone_and_num_features(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
    """Create a timm model with ``num_classes=0`` and ``global_pool=''`` and return it with its ``num_features``.

    Raises:
        ValueError: if ``model_name`` is not listed by ``timm.list_models()``.
    """
    if model_name not in timm.list_models():
        raise ValueError(
            f"{model_name} is not supported in timm yet. https://rwightman.github.io/pytorch-image-models/models/"
        )

    net = timm.create_model(model_name, pretrained=pretrained, num_classes=0, global_pool='')
    return net, net.num_features
num_features = backbone.num_features
return backbone, num_features

BACKBONES_REGISTRY.register_function(fn=partial(_fn_timm, model_name), name=model_name)
15 changes: 7 additions & 8 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Sequence, Tuple, Type, Union
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union

import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.vision.backbones import backbone_and_num_features
from flash.vision.backbones import BACKBONES_REGISTRY
from flash.vision.classification.data import ImageClassificationData, ImageClassificationPreprocess


Expand Down Expand Up @@ -58,14 +58,12 @@ class ImageClassifier(ClassificationTask):
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
"""

@property
def preprocess(self):
return ImageClassificationPreprocess(predict_transform=ImageClassificationData.default_val_transforms())

def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
backbone_kwargs: Dict = {},
head: Optional[Union[Callable, nn.Module]] = None,
pretrained: bool = True,
tchaton marked this conversation as resolved.
Show resolved Hide resolved
loss_fn: Callable = F.cross_entropy,
optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
Expand All @@ -85,9 +83,10 @@ def __init__(
if isinstance(backbone, tuple):
self.backbone, num_features = backbone
else:
self.backbone, num_features = backbone_and_num_features(backbone, pretrained=pretrained)
self.backbone, num_features = BACKBONES_REGISTRY.get(backbone)(pretrained=pretrained, **backbone_kwargs)

self.head = nn.Sequential(
head = head(num_features, num_classes) if isinstance(head, Callable) else head
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.head = head or nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(num_features, num_classes),
Expand Down
8 changes: 4 additions & 4 deletions flash/vision/detection/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torchvision.ops import box_iou

from flash.core import Task
from flash.vision.backbones import backbone_and_num_features
from flash.vision.backbones import BACKBONES_REGISTRY
from flash.vision.detection.finetuning import ObjectDetectionFineTuning

_models = {
Expand Down Expand Up @@ -133,9 +133,9 @@ def get_model(
**kwargs
)
else:
backbone_model, num_features = backbone_and_num_features(
backbone,
fpn,
_backbone = f"{backbone}-fpn"
backbone = _backbone if _backbone in BACKBONES_REGISTRY else backbone
backbone_model, num_features = BACKBONES_REGISTRY[backbone](
pretrained_backbone,
trainable_backbone_layers,
**kwargs,
Expand Down
Loading