From e62fc9b08ac00639cbf957a9d3a502d883a5156b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 7 Jun 2021 11:41:50 +0100 Subject: [PATCH] Avoid function redefinition --- flash/image/backbones.py | 84 +++++++++++++-------------- flash/image/segmentation/backbones.py | 29 +++++---- 2 files changed, 54 insertions(+), 59 deletions(-) diff --git a/flash/image/backbones.py b/flash/image/backbones.py index 790a1650de..47b2252f94 100644 --- a/flash/image/backbones.py +++ b/flash/image/backbones.py @@ -92,14 +92,13 @@ def load_swav_imagenet( if _TORCHVISION_AVAILABLE: - for model_name in MOBILENET_MODELS + VGG_MODELS: - - def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = model.features - num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features - return backbone, num_features + def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = model.features + num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features + return backbone, num_features + for model_name in MOBILENET_MODELS + VGG_MODELS: _type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg" IMAGE_CLASSIFIER_BACKBONES( @@ -110,14 +109,22 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu type=_type ) - for model_name in RESNET_MODELS: - - def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = nn.Sequential(*list(model.children())[:-2]) - num_features = model.fc.in_features - return backbone, num_features + def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = nn.Sequential(*list(model.children())[:-2]) + num_features = model.fc.in_features + return backbone, num_features + + def _fn_resnet_fpn( + model_name: str, + pretrained: bool = True, + trainable_layers: bool = True, + **kwargs, + ) -> Tuple[nn.Module, int]: + backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs) + return backbone, 256 + for model_name in RESNET_MODELS: IMAGE_CLASSIFIER_BACKBONES( fn=catch_url_error(partial(_fn_resnet, model_name)), name=model_name, @@ -126,17 +133,6 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int type="resnet" ) - def _fn_resnet_fpn( - model_name: str, - pretrained: bool = True, - trainable_layers: bool = True, - **kwargs, - ) -> Tuple[nn.Module, int]: - backbone = resnet_fpn_backbone( - model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs - ) - return backbone, 256 - OBJ_DETECTION_BACKBONES( fn=catch_url_error(partial(_fn_resnet_fpn, model_name)), name=model_name, @@ -144,14 +140,13 @@ def _fn_resnet_fpn( type="resnet-fpn" ) - for model_name in DENSENET_MODELS: - - def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: - model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) - backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) - num_features = model.classifier.in_features - return backbone, num_features + def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]: + model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained) + backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True)) + num_features = model.classifier.in_features + return backbone, num_features + for model_name in DENSENET_MODELS: IMAGE_CLASSIFIER_BACKBONES( fn=catch_url_error(partial(_fn_densenet, model_name)), name=model_name, @@ -161,23 +156,24 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i ) if _TIMM_AVAILABLE: + + def _fn_timm( + model_name: str, + pretrained: bool = True, + num_classes: int = 0, + global_pool: str = '', + ) -> Tuple[nn.Module, int]: + backbone = timm.create_model( + model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool + ) + num_features = backbone.num_features + return backbone, num_features + for model_name in timm.list_models(): if model_name in TORCHVISION_MODELS: continue - def _fn_timm( - model_name: str, - pretrained: bool = True, - num_classes: int = 0, - global_pool: str = '', - ) -> Tuple[nn.Module, int]: - backbone = timm.create_model( - model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool - ) - num_features = backbone.num_features - return backbone, num_features - IMAGE_CLASSIFIER_BACKBONES( fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm" ) diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 85fffe0cd1..dca68c5bc8 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -39,14 +39,14 @@ SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") if _TORCHVISION_AVAILABLE: - for model_name in FCN_MODELS + DEEPLABV3_MODELS: - def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: - model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) - in_channels = model.classifier[-1].in_channels - model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1) - return model + def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: + model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) + in_channels = model.classifier[-1].in_channels + model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1) + return model + for model_name in FCN_MODELS + DEEPLABV3_MODELS: _type = model_name.split("_")[0] SEMANTIC_SEGMENTATION_BACKBONES( @@ -57,18 +57,17 @@ def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True type=_type ) - for model_name in LRASPP_MODELS: + def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: + model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) - def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: - model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) + low_channels = model.classifier.low_classifier.in_channels + high_channels = model.classifier.high_classifier.in_channels - low_channels = model.classifier.low_classifier.in_channels - high_channels = model.classifier.high_classifier.in_channels - - model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1) - model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1) - return model + model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1) + model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1) + return model + for model_name in LRASPP_MODELS: SEMANTIC_SEGMENTATION_BACKBONES( fn=catch_url_error(partial(_fn_lraspp, model_name)), name=model_name,