Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add more backbones to semantic segmentation #370

Merged
merged 6 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.1] - YYYY-MM-DD

### Added

- Added `deeplabv3`, `lraspp`, and `unet` backbones for the `SemanticSegmentation` task ([#370](https://github.com/PyTorchLightning/lightning-flash/pull/370))

### Fixed

- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))
Expand Down
6 changes: 3 additions & 3 deletions flash/image/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@
def catch_url_error(fn):

@functools.wraps(fn)
def wrapper(pretrained=False, **kwargs):
def wrapper(*args, pretrained=False, **kwargs):
try:
return fn(pretrained=pretrained, **kwargs)
return fn(*args, pretrained=pretrained, **kwargs)
except urllib.error.URLError:
result = fn(pretrained=False, **kwargs)
result = fn(*args, pretrained=False, **kwargs)
rank_zero_warn(
"Failed to download pretrained weights for the selected backbone. The backbone has been created with"
" `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
Expand Down
84 changes: 71 additions & 13 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,82 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import segmentation

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet

FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
import torchvision

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model

@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model
for model_name in FCN_MODELS + DEEPLABV3_MODELS:

def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

_type = model_name.split("_")[0]

SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type=_type
)

for model_name in LRASPP_MODELS:

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model

SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type="lraspp"
)

if _BOLTS_AVAILABLE:

def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
if pretrained:
rank_zero_warn(
"No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
"initialized with random weights!", UserWarning
)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
4 changes: 3 additions & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class SemanticSegmentation(ClassificationTask):
def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50",
backbone_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
Expand Down Expand Up @@ -144,6 +144,8 @@ def forward(self, x) -> torch.Tensor:
out: torch.Tensor
if isinstance(res, dict):
out = res['out']
elif torch.is_tensor(res):
out = res
else:
raise NotImplementedError(f"Unsupported output type: {type(out)}")

Expand Down
9 changes: 5 additions & 4 deletions flash_examples/finetuning/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@
# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3. Build the model
# 3.a List available backbones
print(SemanticSegmentation.available_backbones())

# 3.b Build the model
model = SemanticSegmentation(
backbone="torchvision/fcn_resnet50",
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True)
backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
)

# 4. Create the trainer.
Expand Down
38 changes: 38 additions & 0 deletions tests/image/segmentation/test_backbones.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES


@pytest.mark.parametrize(["backbone"], [
pytest.param("fcn_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param("deeplabv3_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
pytest.param(
"lraspp_mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")
),
pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
])
def test_image_classifier_backbones_registry(backbone):
img = torch.rand(1, 3, 32, 32)
backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)
backbone_model = backbone_fn(10, pretrained=False)
assert backbone_model
backbone_model.eval()
res = backbone_model(img)
if isinstance(res, dict):
res = res["out"]
assert res.shape[1] == 10
26 changes: 16 additions & 10 deletions tests/image/segmentation/test_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import numpy as np
Expand Down Expand Up @@ -43,7 +56,7 @@ def test_smoke():
def test_forward(num_classes, img_shape):
model = SemanticSegmentation(
num_classes=num_classes,
backbone='torchvision/fcn_resnet50',
backbone='fcn_resnet50',
)

B, C, H, W = img_shape
Expand All @@ -54,15 +67,8 @@ def test_forward(num_classes, img_shape):


@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.parametrize(
"backbone",
[
"torchvision/fcn_resnet50",
"torchvision/fcn_resnet101",
],
)
def test_init_train(tmpdir, backbone):
model = SemanticSegmentation(num_classes=10, backbone=backbone)
def test_init_train(tmpdir):
model = SemanticSegmentation(num_classes=10)
train_dl = torch.utils.data.DataLoader(DummyDataset())
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
Expand Down
13 changes: 13 additions & 0 deletions tests/image/test_backbones.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import urllib.error

import pytest
Expand Down