Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add more backbones to semantic segmentation #370

Merged
merged 6 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.1] - YYYY-MM-DD

### Added

- Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
Expand All @@ -18,6 +22,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue with `KorniaParallelTransforms` to assure to share the random state between transforms ([#351](https://github.com/PyTorchLightning/lightning-flash/pull/351))
- Change resize interpolation default mode to nearest ([#352](https://github.com/PyTorchLightning/lightning-flash/pull/352))

### Deprecated

- Deprecated `SemanticSegmentation` backbone names `torchvision/fcn_resnet50` and `torchvision/fcn_resnet101`, use `fc_resnet50` and `fcn_resnet101` instead ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))


## [0.3.0] - 2021-05-20

Expand Down
90 changes: 43 additions & 47 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@
def catch_url_error(fn):

@functools.wraps(fn)
def wrapper(pretrained=False, **kwargs):
def wrapper(*args, pretrained=False, **kwargs):
try:
return fn(pretrained=pretrained, **kwargs)
return fn(*args, pretrained=pretrained, **kwargs)
except urllib.error.URLError:
result = fn(pretrained=False, **kwargs)
result = fn(*args, pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
Expand Down Expand Up @@ -92,14 +92,13 @@ def load_swav_imagenet(

if _TORCHVISION_AVAILABLE:

for model_name in MOBILENET_MODELS + VGG_MODELS:

def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features
def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = model.features
num_features = 512 if model_name in VGG_MODELS else model.classifier[-1].in_features
return backbone, num_features

for model_name in MOBILENET_MODELS + VGG_MODELS:
_type = "mobilenet" if model_name in MOBILENET_MODELS else "vgg"

IMAGE_CLASSIFIER_BACKBONES(
Expand All @@ -110,14 +109,22 @@ def _fn_mobilenet_vgg(model_name: str, pretrained: bool = True) -> Tuple[nn.Modu
type=_type
)

for model_name in RESNET_MODELS:

def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features
def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*list(model.children())[:-2])
num_features = model.fc.in_features
return backbone, num_features

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs)
return backbone, 256

for model_name in RESNET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_resnet, model_name)),
name=model_name,
Expand All @@ -126,32 +133,20 @@ def _fn_resnet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int
type="resnet"
)

def _fn_resnet_fpn(
model_name: str,
pretrained: bool = True,
trainable_layers: bool = True,
**kwargs,
) -> Tuple[nn.Module, int]:
backbone = resnet_fpn_backbone(
model_name, pretrained=pretrained, trainable_layers=trainable_layers, **kwargs
)
return backbone, 256

OBJ_DETECTION_BACKBONES(
fn=catch_url_error(partial(_fn_resnet_fpn, model_name)),
name=model_name,
package="torchvision",
type="resnet-fpn"
)

for model_name in DENSENET_MODELS:

def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features
def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, int]:
model: nn.Module = getattr(torchvision.models, model_name, None)(pretrained)
backbone = nn.Sequential(*model.features, nn.ReLU(inplace=True))
num_features = model.classifier.in_features
return backbone, num_features

for model_name in DENSENET_MODELS:
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_densenet, model_name)),
name=model_name,
Expand All @@ -161,23 +156,24 @@ def _fn_densenet(model_name: str, pretrained: bool = True) -> Tuple[nn.Module, i
)

if _TIMM_AVAILABLE:

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

for model_name in timm.list_models():

if model_name in TORCHVISION_MODELS:
continue

def _fn_timm(
model_name: str,
pretrained: bool = True,
num_classes: int = 0,
global_pool: str = '',
) -> Tuple[nn.Module, int]:
backbone = timm.create_model(
model_name, pretrained=pretrained, num_classes=num_classes, global_pool=global_pool
)
num_features = backbone.num_features
return backbone, num_features

IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)
Expand Down
100 changes: 91 additions & 9 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,106 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from deprecate import deprecated
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import segmentation

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet

FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
import torchvision

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
for model_name in FCN_MODELS + DEEPLABV3_MODELS:
_type = model_name.split("_")[0]

SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type=_type
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")),
name="torchvision/fcn_resnet50",
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")),
name="torchvision/fcn_resnet101",
)

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model

for model_name in LRASPP_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type="lraspp"
)

if _BOLTS_AVAILABLE:

def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
if pretrained:
rank_zero_warn(
"No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
"initialized with random weights!", UserWarning
)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
4 changes: 3 additions & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class SemanticSegmentation(ClassificationTask):
def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
Expand Down Expand Up @@ -144,6 +144,8 @@ def forward(self, x) -> torch.Tensor:
out: torch.Tensor
if isinstance(res, dict):
out = res['out']
elif torch.is_tensor(res):
out = res
else:
raise NotImplementedError(f"Unsupported output type: {type(out)}")

Expand Down
9 changes: 5 additions & 4 deletions flash_examples/finetuning/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@
# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
# 3.a List available backbones
print(SemanticSegmentation.available_backbones())

# 3.b Build the model
model = SemanticSegmentation(
backbone="torchvision/fcn_resnet50",
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True)
backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
)

# 4. Create the trainer.
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
torch>=1.7
torchmetrics
pytorch-lightning>=1.3.1
pyDeprecate
PyYAML>=5.1
numpy
pandas
Expand Down
38 changes: 38 additions & 0 deletions tests/image/segmentation/test_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES


@pytest.mark.parametrize(["backbone"], [
pytest.param("fcn_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("deeplabv3_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param(
"lraspp_mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")
),
pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
])
def test_image_classifier_backbones_registry(backbone):
img = torch.rand(1, 3, 32, 32)
backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)
backbone_model = backbone_fn(10, pretrained=False)
assert backbone_model
backbone_model.eval()
res = backbone_model(img)
if isinstance(res, dict):
res = res["out"]
assert res.shape[1] == 10
Loading