Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Avoid function redefinition
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris committed Jun 7, 2021
1 parent 413f0ce commit e62fc9b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 59 deletions.
84 changes: 40 additions & 44 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ def load_swav_imagenet(

if _TORCHVISION_AVAILABLE:

for model_name in MOBILENET_MODELS + VGG_MODELS:

def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features
def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features

for model_name in MOBILENET_MODELS + VGG_MODELS:
_type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg"

IMAGE_CLASSIFIER_BACKBONES(
Expand All @@ -110,14 +109,22 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
type=_type
)

for model_name in RESNET_MODELS:

def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features
def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs)
return backbone, 256

for model_name in RESNET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name)),
name=model_name,
Expand All @@ -126,32 +133,20 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int
type="resnet"
)

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(
model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs
)
return backbone, 256

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
package="torchvision",
type="resnet-fpn"
)

for model_name in DENSENET_MODELS:

def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features
def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features

for model_name in DENSENET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
Expand All @@ -161,23 +156,24 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i
)

if _TIMM_AVAILABLE:

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

for model_name in timm.list_models():

if model_name in TORCHVISION_MODELS:
continue

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)
Expand Down
29 changes: 14 additions & 15 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
for model_name in FCN_MODELS + DEEPLABV3_MODELS:

def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model
def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

for model_name in FCN_MODELS + DEEPLABV3_MODELS:
_type = model_name.split("_")[0]

SEMANTIC_SEGMENTATION_BACKBONES(
Expand All @@ -57,18 +57,17 @@ def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True
type=_type
)

for model_name in LRASPP_MODELS:
def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model
model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model

for model_name in LRASPP_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
name=model_name,
Expand Down

0 comments on commit e62fc9b

Please sign in to comment.