This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add more backbones to semantic segmentation (#370)
* Add more backbones to semantic segmentation

* Update example

* Update CHANGELOG.md

* Avoid function redefinition

* Make old backbones deprecated

* Change deprecation version
ethanwharris authored Jun 7, 2021
1 parent 447c28e commit e31081c
Showing 9 changed files with 218 additions and 71 deletions.
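In practical terms, after this commit a `SemanticSegmentation` task can be built with any of the newly registered backbone names. A minimal sketch, assuming this commit is applied and torchvision is installed (`num_classes=21` is an arbitrary example value):

```python
from flash.image import SemanticSegmentation

model = SemanticSegmentation(
    backbone="deeplabv3_resnet50",  # also new: "lraspp_mobilenet_v3_large", "unet", ...
    num_classes=21,
    pretrained=False,
)
```

The old `torchvision/`-prefixed names keep working for now, but are deprecated (see the CHANGELOG entry and the registration code below).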
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.1] - YYYY-MM-DD

### Added

- Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
@@ -18,6 +22,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue with `KorniaParallelTransforms` to ensure the random state is shared between transforms ([#351](https://github.com/PyTorchLightning/lightning-flash/pull/351))
- Changed the default resize interpolation mode to nearest ([#352](https://github.com/PyTorchLightning/lightning-flash/pull/352))

### Deprecated

- Deprecated `SemanticSegmentation` backbone names `torchvision/fcn_resnet50` and `torchvision/fcn_resnet101`, use `fcn_resnet50` and `fcn_resnet101` instead ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))


## [0.3.0] - 2021-05-20

90 changes: 43 additions & 47 deletions flash/image/backbones.py
@@ -56,11 +56,11 @@
def catch_url_error(fn):

@functools.wraps(fn)
def wrapper(pretrained=False, **kwargs):
def wrapper(*args, pretrained=False, **kwargs):
try:
return fn(pretrained=pretrained, **kwargs)
return fn(*args, pretrained=pretrained, **kwargs)
except urllib.error.URLError:
result = fn(pretrained=False, **kwargs)
result = fn(*args, pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
@@ -92,14 +92,13 @@ def load_swav_imagenet(

if _TORCHVISION_AVAILABLE:

for model_name in MOBILENET_MODELS + VGG_MODELS:

def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features
def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features

for model_name in MOBILENET_MODELS + VGG_MODELS:
_type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg"

IMAGE_CLASSIFIER_BACKBONES(
@@ -110,14 +109,22 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
type=_type
)
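Two of the commit's bullets meet here: the factory functions (`_fn_mobilenet_vgg` and friends) are now defined once at module level instead of being re-defined on every loop iteration ("Avoid function redefinition"), and `catch_url_error` (changed above) forwards `*args` so that arguments passed positionally by callers still reach the factory. A minimal self-contained sketch of that registration pattern, with made-up names (`_toy_backbone`, `TOY_REGISTRY`) standing in for the real factories and registries:

```python
import functools
import urllib.error


def catch_url_error(fn):
    # Sketch of the wrapper above: retry with pretrained=False if the download fails.
    @functools.wraps(fn)
    def wrapper(*args, pretrained=False, **kwargs):
        try:
            return fn(*args, pretrained=pretrained, **kwargs)
        except urllib.error.URLError:
            # The real wrapper also emits rank_zero_warn here.
            return fn(*args, pretrained=False, **kwargs)

    return wrapper


def _toy_backbone(model_name, num_classes, pretrained=True):
    # Defined once at module level; model_name is bound per entry below via
    # functools.partial, so there is no late-binding issue.
    return f"{model_name}/{num_classes}/pretrained={pretrained}"


TOY_REGISTRY = {}
for model_name in ["mobilenet_v2", "vgg16"]:
    TOY_REGISTRY[model_name] = catch_url_error(functools.partial(_toy_backbone, model_name))

# Positional arguments (e.g. num_classes) flow through the *args added to the wrapper.
print(TOY_REGISTRY["vgg16"](10, pretrained=False))  # vgg16/10/pretrained=False
```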

for model_name in RESNET_MODELS:

def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features
def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs)
return backbone, 256

for model_name in RESNET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name)),
name=model_name,
@@ -126,32 +133,20 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
type="resnet"
)

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(
model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs
)
return backbone, 256

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
package="torchvision",
type="resnet-fpn"
)
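The `resnet-fpn` entries feed the object-detection registry: `_fn_resnet_fpn` wraps torchvision's `resnet_fpn_backbone` and always reports 256 output channels, the FPN's fixed width. A rough sketch of consuming such an entry, assuming torchvision is installed, that `"resnet18"` is one of the names in `RESNET_MODELS`, and that `OBJ_DETECTION_BACKBONES` lives in this module as the registration above suggests:

```python
import torch
from flash.image.backbones import OBJ_DETECTION_BACKBONES

backbone_fn = OBJ_DETECTION_BACKBONES.get("resnet18")
backbone, num_features = backbone_fn(pretrained=False)
assert num_features == 256  # FPN output channels

features = backbone(torch.rand(1, 3, 128, 128))  # dict of pyramid levels, e.g. '0'..'pool'
print({name: tuple(f.shape) for name, f in features.items()})
```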

for model_name in DENSENET_MODELS:

def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features
def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features

for model_name in DENSENET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
@@ -161,23 +156,24 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
)

if _TIMM_AVAILABLE:

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

for model_name in timm.list_models():

if model_name in TORCHVISION_MODELS:
continue

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)
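On the timm side, every model name (minus the torchvision duplicates) is registered through `_fn_timm`, which asks `timm.create_model` for a headless feature extractor via `num_classes=0` and `global_pool=''`. A minimal sketch of what that produces, assuming timm is installed (`"resnet18d"` is just an example name):

```python
import timm
import torch

backbone = timm.create_model("resnet18d", pretrained=False, num_classes=0, global_pool="")
features = backbone(torch.rand(1, 3, 64, 64))
print(backbone.num_features, tuple(features.shape))  # e.g. 512 (1, 512, 2, 2)
```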
100 changes: 91 additions & 9 deletions flash/image/segmentation/backbones.py
@@ -11,24 +11,106 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from deprecate import deprecated
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import segmentation

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet

FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
import torchvision

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
for model_name in FCN_MODELS + DEEPLABV3_MODELS:
_type = model_name.split("_")[0]

SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type=_type
)
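These registrations are what the new test at the bottom of this commit exercises; the sketch below mirrors it for one of the freshly added backbones (weights skipped to avoid a download):

```python
import torch
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES

backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get("deeplabv3_resnet50")
model = backbone_fn(10, pretrained=False)  # classifier head re-sized to 10 classes
model.eval()

res = model(torch.rand(1, 3, 32, 32))["out"]  # torchvision heads return a dict
print(tuple(res.shape))  # (1, 10, 32, 32)
```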

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")),
name="torchvision/fcn_resnet50",
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")),
name="torchvision/fcn_resnet101",
)
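The old `torchvision/`-prefixed names therefore stay resolvable until v0.5.0, but calling them emits a `UserWarning` that points at the plain name. A small sketch of what a caller sees, assuming torchvision is installed:

```python
import warnings

from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_fn = SEMANTIC_SEGMENTATION_BACKBONES.get("torchvision/fcn_resnet50")
    model = old_fn(10, pretrained=False)  # the warning fires on the call, not on .get()

print([str(w.message) for w in caught])  # deprecation message rendered from the template above
```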

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model

for model_name in LRASPP_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type="lraspp"
)

if _BOLTS_AVAILABLE:

def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
if pretrained:
rank_zero_warn(
"No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
"initialized with random weights!", UserWarning
)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
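Since the bolts `UNet` has no published weights, asking for `pretrained=True` only triggers the warning above and the model is randomly initialised either way. A small sketch of using the new `unet` entry, assuming pl_bolts is installed:

```python
import torch
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES

unet_fn = SEMANTIC_SEGMENTATION_BACKBONES.get("unet")
unet = unet_fn(10, pretrained=False)  # pretrained=True would only emit the warning
unet.eval()

res = unet(torch.rand(1, 3, 32, 32))  # plain tensor, no {"out": ...} dict
print(tuple(res.shape))  # (1, 10, 32, 32)
```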
4 changes: 3 additions & 1 deletion flash/image/segmentation/model.py
@@ -73,7 +73,7 @@ class SemanticSegmentation(ClassificationTask):
def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
@@ -144,6 +144,8 @@ def forward(self, x) -> torch.Tensor:
out: torch.Tensor
if isinstance(res, dict):
out = res['out']
elif torch.is_tensor(res):
out = res
else:
raise NotImplementedError(f"Unsupported output type: {type(res)}")

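The two-line change above exists because the torchvision heads return a dict with an `"out"` key while the newly supported bolts UNet returns a plain tensor. A standalone sketch of that dispatch:

```python
import torch


def _extract_output(res):
    # Mirrors SemanticSegmentation.forward: dict from torchvision heads, raw tensor from UNet.
    if isinstance(res, dict):
        return res["out"]
    if torch.is_tensor(res):
        return res
    raise NotImplementedError(f"Unsupported output type: {type(res)}")


print(_extract_output({"out": torch.zeros(1, 10, 4, 4)}).shape)  # torch.Size([1, 10, 4, 4])
print(_extract_output(torch.zeros(1, 10, 4, 4)).shape)           # torch.Size([1, 10, 4, 4])
```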
9 changes: 5 additions & 4 deletions flash_examples/finetuning/semantic_segmentation.py
@@ -37,11 +37,12 @@
# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
# 3.a List available backbones
print(SemanticSegmentation.available_backbones())

# 3.b Build the model
model = SemanticSegmentation(
backbone="torchvision/fcn_resnet50",
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True)
backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
)

# 4. Create the trainer.
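With the example now printing `SemanticSegmentation.available_backbones()`, the listing should surface the plain torchvision names plus the new additions; a hedged sketch of checking that (contents depend on which optional packages are installed):

```python
from flash.image import SemanticSegmentation

backbones = SemanticSegmentation.available_backbones()
assert "fcn_resnet50" in backbones        # plain name replaces "torchvision/fcn_resnet50"
assert "deeplabv3_resnet50" in backbones  # requires torchvision
# "unet" only shows up when pl_bolts is installed.
print(sorted(backbones))
```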
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,6 +1,7 @@
torch>=1.7
torchmetrics
pytorch-lightning>=1.3.1
pyDeprecate
PyYAML>=5.1
numpy
pandas
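The new `pyDeprecate` requirement is what powers the backbone deprecations above. A minimal standalone sketch of the same pattern; `old_name`/`new_name` are made-up functions for illustration, and the keyword names (including `template_mgs`) are copied from the registration code in this commit:

```python
import warnings
from functools import partial

from deprecate import deprecated


def new_name() -> str:
    return "fcn_resnet50"


old_name = deprecated(
    target=None,
    stream=partial(warnings.warn, category=UserWarning),
    deprecated_in="0.3.1",
    remove_in="0.5.0",
    template_mgs="'old_name' was deprecated in v%(deprecated_in)s and will be removed in v%(remove_in)s.",
)(new_name)

print(old_name())  # emits the UserWarning, then returns "fcn_resnet50"
```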
38 changes: 38 additions & 0 deletions tests/image/segmentation/test_backbones.py
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES


@pytest.mark.parametrize(["backbone"], [
pytest.param("fcn_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("deeplabv3_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param(
"lraspp_mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")
),
pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
])
def test_semantic_segmentation_backbones_registry(backbone):
img = torch.rand(1, 3, 32, 32)
backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)
backbone_model = backbone_fn(10, pretrained=False)
assert backbone_model
backbone_model.eval()
res = backbone_model(img)
if isinstance(res, dict):
res = res["out"]
assert res.shape[1] == 10