Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Enable semantic segmentation backbone and head #412

Merged
merged 11 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Split `backbone` argument to `SemanticSegmentation` into `backbone` and `head` arguments ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))

### Deprecated


### Fixed

- Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393))
- Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))

## [0.3.2] - 2021-06-08

Expand Down
4 changes: 2 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ def available_backbones(cls) -> List[str]:
return registry.available_keys()

@classmethod
def available_models(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "models", None)
def available_heads(cls) -> List[str]:
registry: Optional[FlashRegistry] = getattr(cls, "heads", None)
if registry is None:
return []
return registry.available_keys()
Expand Down
98 changes: 20 additions & 78 deletions flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,106 +11,48 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from deprecate import deprecated
from pytorch_lightning.utilities import rank_zero_warn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import segmentation

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet
from torchvision.models import mobilenetv3, resnet

FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]
RESNET_MODELS = ["resnet50", "resnet101"]
MOBILENET_MODELS = ["mobilenet_v3_large"]

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:

def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
in_channels = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
return model

for model_name in FCN_MODELS + DEEPLABV3_MODELS:
_type = model_name.split("_")[0]
def _load_resnet(model_name: str, pretrained: bool = True):
backbone = resnet.__dict__[model_name](
pretrained=pretrained,
replace_stride_with_dilation=[False, True, True],
)
return backbone

for model_name in RESNET_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
fn=catch_url_error(partial(_load_resnet, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type=_type
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")),
name="torchvision/fcn_resnet50",
)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=deprecated(
target=None,
stream=partial(warnings.warn, category=UserWarning),
deprecated_in="0.3.1",
remove_in="0.5.0",
template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in "
"favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.",
)(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")),
name="torchvision/fcn_resnet101",
)

def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)

low_channels = model.classifier.low_classifier.in_channels
high_channels = model.classifier.high_classifier.in_channels

model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
return model
def _load_mobilenetv3(model_name: str, pretrained: bool = True):
backbone = mobilenetv3.__dict__[model_name](
pretrained=pretrained,
_dilated=True,
)
return backbone

for model_name in LRASPP_MODELS:
for model_name in MOBILENET_MODELS:
SEMANTIC_SEGMENTATION_BACKBONES(
fn=catch_url_error(partial(_fn_lraspp, model_name)),
fn=catch_url_error(partial(_load_mobilenetv3, model_name)),
name=model_name,
namespace="image/segmentation",
package="torchvision",
type="lraspp"
)

if _BOLTS_AVAILABLE:

def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
if pretrained:
rank_zero_warn(
"No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
"initialized with random weights!", UserWarning
)
return UNet(num_classes, **kwargs)

SEMANTIC_SEGMENTATION_BACKBONES(
fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
)
116 changes: 116 additions & 0 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial

import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.models import MobileNetV3, ResNet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
from torchvision.models.segmentation.fcn import FCN, FCNHead
from torchvision.models.segmentation.lraspp import LRASPP

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet

RESNET_MODELS = ["resnet50", "resnet101"]
MOBILENET_MODELS = ["mobilenet_v3_large"]

# Registry of semantic segmentation heads, keyed by head name ("fcn", "deeplabv3", "lraspp", "unet").
# Named "heads" so it is not confused with the separate backbones registry.
SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("heads")

if _TORCHVISION_AVAILABLE:

def _get_backbone_meta(backbone):
    """Work out which backbone layers a segmentation head should read from.

    Adapted from torchvision.models.segmentation.segmentation._segm_model:
    https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25

    Args:
        backbone: A torchvision ``ResNet`` or ``MobileNetV3`` instance.

    Returns:
        A tuple ``(backbone, out_layer, out_inplanes, aux_layer, aux_inplanes)``. A ``ResNet`` is returned
        unchanged; for a ``MobileNetV3`` the returned backbone is its ``features`` sub-module. Layer names are
        strings and the ``*_inplanes`` values are the channel counts of those layers.

    Raises:
        MisconfigurationException: If ``backbone`` is neither a ``ResNet`` nor a ``MobileNetV3``.
    """
    if isinstance(backbone, ResNet):
        # NOTE(review): channel counts are hard-coded; presumably they match the registered
        # resnet50 / resnet101 backbones -- confirm before passing any other ResNet variant.
        return backbone, 'layer4', 2048, 'layer3', 1024

    if isinstance(backbone, MobileNetV3):
        features = backbone.features
        # Blocks flagged ``_is_cn`` are the strided ones, i.e. the C1, ..., Cn-1 stages.
        # C0 (conv1, index 0) and Cn (the final block) are always part of the stage list.
        strided = [idx for idx, block in enumerate(features) if getattr(block, "_is_cn", False)]
        stages = [0, *strided, len(features) - 1]

        last_stage = stages[-1]  # use C5 which has output_stride = 16
        last_channels = features[last_stage].out_channels
        early_stage = stages[-4]  # use C2 here which has output_stride = 8
        early_channels = features[early_stage].out_channels
        return features, str(last_stage), last_channels, str(early_stage), early_channels

    raise MisconfigurationException(
        f"{type(backbone)} backbone is not currently supported for semantic segmentation."
    )

def _load_fcn_deeplabv3(model_name, backbone, num_classes):
    """Assemble a torchvision FCN or DeepLabV3 segmentation model around ``backbone``.

    Args:
        model_name: Either ``"fcn"`` or ``"deeplabv3"``; any other value raises ``KeyError``.
        backbone: A backbone accepted by ``_get_backbone_meta``.
        num_classes: Number of output classes passed to the classifier head.

    Returns:
        An ``FCN`` or ``DeepLabV3`` instance wrapping the backbone and a freshly built classifier.
    """
    # Only the main output layer is used here; the auxiliary layer metadata is discarded.
    features, out_layer, out_inplanes, _aux_layer, _aux_inplanes = _get_backbone_meta(backbone)
    feature_extractor = IntermediateLayerGetter(features, return_layers={out_layer: 'out'})

    head_cls, model_cls = {
        "deeplabv3": (DeepLabHead, DeepLabV3),
        "fcn": (FCNHead, FCN),
    }[model_name]
    classifier = head_cls(out_inplanes, num_classes)

    # Third positional argument is left as ``None`` (presumably the auxiliary classifier slot -- confirm).
    return model_cls(feature_extractor, classifier, None)

# Register the FCN and DeepLabV3 heads, each bound to its own model name via ``partial``.
for model_name in ("fcn", "deeplabv3"):
    SEMANTIC_SEGMENTATION_HEADS(
        fn=partial(_load_fcn_deeplabv3, model_name),
        name=model_name,
        namespace="image/segmentation",
        package="torchvision",
    )

def _load_lraspp(backbone, num_classes):
    """Assemble a torchvision ``LRASPP`` segmentation model around ``backbone``.

    Args:
        backbone: A backbone accepted by ``_get_backbone_meta``.
        num_classes: Number of output classes passed to ``LRASPP``.

    Returns:
        An ``LRASPP`` instance fed by the backbone's "low" and "high" feature maps.
    """
    # The "out" layer of the backbone feeds the high branch, the "aux" layer the low branch.
    features, high_layer, high_channels, low_layer, low_channels = _get_backbone_meta(backbone)
    return_layers = {low_layer: 'low', high_layer: 'high'}
    feature_extractor = IntermediateLayerGetter(features, return_layers=return_layers)
    return LRASPP(feature_extractor, low_channels, high_channels, num_classes)

# Register the LRASPP head under the name "lraspp".
SEMANTIC_SEGMENTATION_HEADS(fn=_load_lraspp, name="lraspp", namespace="image/segmentation", package="torchvision")

if _BOLTS_AVAILABLE:

def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
    """Build the ``pl_bolts`` ``UNet`` model, ignoring the supplied backbone.

    Args:
        _: The backbone passed by the caller; unused, and a warning is emitted to say so.
        num_classes: Number of output classes passed to ``UNet``.
        **kwargs: Extra keyword arguments forwarded to ``UNet``.

    Returns:
        The constructed ``UNet`` module.
    """
    message = "The UNet model does not require a backbone, so the backbone will be ignored."
    rank_zero_warn(message, UserWarning)
    unet = UNet(num_classes, **kwargs)
    return unet

# Register the bolts UNet head under the name "unet".
SEMANTIC_SEGMENTATION_HEADS(
    fn=_load_bolts_unet,
    name="unet",
    namespace="image/segmentation",
    package="bolts",
    type="unet",
)
38 changes: 25 additions & 13 deletions flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from flash.image.segmentation.serialization import SegmentationLabels

if _KORNIA_AVAILABLE:
Expand Down Expand Up @@ -54,14 +55,15 @@ class SemanticSegmentation(ClassificationTask):

Args:
num_classes: Number of classes to classify.
backbone: A string or (model, num_features) tuple to use to compute image features,
defaults to ``"torchvision/fcn_resnet50"``.
backbone: A string or model to use to compute image features.
backbone_kwargs: Additional arguments for the backbone configuration.
pretrained: Use a pretrained backbone, defaults to ``False``.
loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`.
metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`.
learning_rate: Learning rate to use for training, defaults to ``1e-3``.
head: A string naming the head (see ``available_heads()``) to build on top of the backbone.
head_kwargs: Additional arguments for the head configuration.
pretrained: Use a pretrained backbone.
loss_fn: Loss function for training.
optimizer: Optimizer to use for training.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training.
multi_label: Whether the targets are multi-label or not.
serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs.
"""
Expand All @@ -70,11 +72,15 @@ class SemanticSegmentation(ClassificationTask):

backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES

heads: FlashRegistry = SEMANTIC_SEGMENTATION_HEADS

def __init__(
self,
num_classes: int,
backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50",
backbone: Union[str, nn.Module] = "resnet50",
backbone_kwargs: Optional[Dict] = None,
head: str = "fcn",
head_kwargs: Optional[Dict] = None,
pretrained: bool = True,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
Expand Down Expand Up @@ -113,8 +119,15 @@ def __init__(
if not backbone_kwargs:
backbone_kwargs = {}

# TODO: pretrained to True causes some issues
self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs)
if not head_kwargs:
head_kwargs = {}

if isinstance(backbone, nn.Module):
self.backbone = backbone
else:
self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
Expand All @@ -134,8 +147,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
return batch

def forward(self, x) -> torch.Tensor:
# infer the image to the model
res = self.backbone(x)
res = self.head(x)

# some frameworks like torchvision return a dict.
# In particular, torchvision segmentation models return the output logits
Expand All @@ -145,7 +157,7 @@ def forward(self, x) -> torch.Tensor:
elif torch.is_tensor(res):
out = res
else:
raise NotImplementedError(f"Unsupported output type: {type(out)}")
raise NotImplementedError(f"Unsupported output type: {type(res)}")

return out

Expand Down
20 changes: 11 additions & 9 deletions flash_examples/finetuning/semantic_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,33 @@
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
batch_size=4,
val_split=0.3,
image_size=(200, 200), # (600, 800)
val_split=0.1,
image_size=(200, 200),
num_classes=21,
)

# 2.2 Visualise the samples
datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

# 3.a List available backbones
print(SemanticSegmentation.available_backbones())
# 3.a List available backbones and heads
print(f"Backbones: {SemanticSegmentation.available_backbones()}")
print(f"Heads: {SemanticSegmentation.available_heads()}")

# 3.b Build the model
model = SemanticSegmentation(
backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
backbone="mobilenet_v3_large",
head="fcn",
num_classes=datamodule.num_classes,
serializer=SegmentationLabels(visualize=True),
)

# 4. Create the trainer.
trainer = flash.Trainer(
max_epochs=1,
fast_dev_run=1,
)
trainer = flash.Trainer(fast_dev_run=True)

# 5. Train the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Segment a few images!
predictions = model.predict([
"data/CameraRGB/F61-1.png",
"data/CameraRGB/F62-1.png",
Expand Down
Loading