diff --git a/CHANGELOG.md b/CHANGELOG.md
index c8e1f16c16..4dae6cbb71 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Split `backbone` argument to `SemanticSegmentation` into `backbone` and `head` arguments ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))
 
 ### Deprecated
 
@@ -22,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393))
+- Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412))
 
 ## [0.3.2] - 2021-06-08
 
diff --git a/flash/core/model.py b/flash/core/model.py
index 8664195dfe..a52de73dfb 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -488,8 +488,8 @@ def available_backbones(cls) -> List[str]:
         return registry.available_keys()
 
     @classmethod
-    def available_models(cls) -> List[str]:
-        registry: Optional[FlashRegistry] = getattr(cls, "models", None)
+    def available_heads(cls) -> List[str]:
+        registry: Optional[FlashRegistry] = getattr(cls, "heads", None)
         if registry is None:
             return []
        return registry.available_keys()
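The `available_models` to `available_heads` rename gives every task with a `heads` registry the same discovery hook that `available_backbones` already provides. A minimal usage sketch (the exact keys returned depend on which optional packages, such as torchvision and bolts, are installed):

```python
from flash.image import SemanticSegmentation

# Both classmethods follow the same pattern: read a FlashRegistry class
# attribute ("backbones" or "heads") and list its registered keys.
print(SemanticSegmentation.available_backbones())  # e.g. ['resnet50', 'resnet101', 'mobilenet_v3_large']
print(SemanticSegmentation.available_heads())      # e.g. ['fcn', 'deeplabv3', 'lraspp', 'unet']
```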
diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py
index 2e65c800fc..7244caf6c5 100644
--- a/flash/image/segmentation/backbones.py
+++ b/flash/image/segmentation/backbones.py
@@ -11,106 +11,48 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import warnings
 from functools import partial
 
-import torch.nn as nn
-from deprecate import deprecated
-from pytorch_lightning.utilities import rank_zero_warn
-
 from flash.core.registry import FlashRegistry
-from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
 from flash.image.backbones import catch_url_error
 
 if _TORCHVISION_AVAILABLE:
-    from torchvision.models import segmentation
-
-if _BOLTS_AVAILABLE:
-    if os.getenv("WARN_MISSING_PACKAGE") == "0":
-        with warnings.catch_warnings(record=True) as w:
-            from pl_bolts.models.vision import UNet
-    else:
-        from pl_bolts.models.vision import UNet
+    from torchvision.models import mobilenetv3, resnet
 
-FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"]
-DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"]
-LRASPP_MODELS = ["lraspp_mobilenet_v3_large"]
+RESNET_MODELS = ["resnet50", "resnet101"]
+MOBILENET_MODELS = ["mobilenet_v3_large"]
 
 SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")
 
 if _TORCHVISION_AVAILABLE:
 
-    def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
-        model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
-        in_channels = model.classifier[-1].in_channels
-        model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1)
-        return model
-
-    for model_name in FCN_MODELS + DEEPLABV3_MODELS:
-        _type = model_name.split("_")[0]
+    def _load_resnet(model_name: str, pretrained: bool = True):
+        backbone = resnet.__dict__[model_name](
+            pretrained=pretrained,
+            replace_stride_with_dilation=[False, True, True],
+        )
+        return backbone
 
+    for model_name in RESNET_MODELS:
         SEMANTIC_SEGMENTATION_BACKBONES(
-            fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)),
+            fn=catch_url_error(partial(_load_resnet, model_name)),
             name=model_name,
             namespace="image/segmentation",
             package="torchvision",
-            type=_type
         )
 
-    SEMANTIC_SEGMENTATION_BACKBONES(
-        fn=deprecated(
-            target=None,
-            stream=partial(warnings.warn, category=UserWarning),
-            deprecated_in="0.3.1",
-            remove_in="0.5.0",
-            template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in "
-            "favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.",
-        )(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")),
-        name="torchvision/fcn_resnet50",
-    )
-
-    SEMANTIC_SEGMENTATION_BACKBONES(
-        fn=deprecated(
-            target=None,
-            stream=partial(warnings.warn, category=UserWarning),
-            deprecated_in="0.3.1",
-            remove_in="0.5.0",
-            template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in "
-            "favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.",
-        )(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")),
-        name="torchvision/fcn_resnet101",
-    )
-
-    def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module:
-        model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs)
-
-        low_channels = model.classifier.low_classifier.in_channels
-        high_channels = model.classifier.high_classifier.in_channels
-
-        model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1)
-        model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1)
-        return model
+    def _load_mobilenetv3(model_name: str, pretrained: bool = True):
+        backbone = mobilenetv3.__dict__[model_name](
+            pretrained=pretrained,
+            _dilated=True,
+        )
+        return backbone
 
-    for model_name in LRASPP_MODELS:
+    for model_name in MOBILENET_MODELS:
         SEMANTIC_SEGMENTATION_BACKBONES(
-            fn=catch_url_error(partial(_fn_lraspp, model_name)),
+            fn=catch_url_error(partial(_load_mobilenetv3, model_name)),
             name=model_name,
             namespace="image/segmentation",
             package="torchvision",
-            type="lraspp"
         )
-
-if _BOLTS_AVAILABLE:
-
-    def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module:
-        if pretrained:
-            rank_zero_warn(
-                "No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be "
-                "initialized with random weights!", UserWarning
-            )
-        return UNet(num_classes, **kwargs)
-
-    SEMANTIC_SEGMENTATION_BACKBONES(
-        fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
-    )
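The backbone loaders now return bare torchvision classification backbones with dilated strides, instead of fully assembled segmentation networks; attaching a classifier becomes the head's job. A quick sketch of what `replace_stride_with_dilation=[False, True, True]` buys, assuming torchvision is installed: dilating `layer3` and `layer4` keeps the deepest features at 1/8 of the input resolution, which is what the FCN and DeepLabV3 heads expect.

```python
import torch
from torchvision.models import resnet
from torchvision.models._utils import IntermediateLayerGetter

# Replace the strides in layer3 and layer4 with dilation so the deepest
# feature map stays at 1/8 of the input resolution (output stride 8).
backbone = resnet.resnet50(pretrained=False, replace_stride_with_dilation=[False, True, True])
features = IntermediateLayerGetter(backbone, return_layers={"layer4": "out"})

out = features(torch.rand(1, 3, 64, 64))["out"]
print(out.shape)  # torch.Size([1, 2048, 8, 8]) -- 64 / 8 = 8
```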
diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py
new file mode 100644
index 0000000000..eab47e0cf5
--- /dev/null
+++ b/flash/image/segmentation/heads.py
@@ -0,0 +1,116 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import warnings
+from functools import partial
+
+import torch.nn as nn
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    from torchvision.models import MobileNetV3, ResNet
+    from torchvision.models._utils import IntermediateLayerGetter
+    from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
+    from torchvision.models.segmentation.fcn import FCN, FCNHead
+    from torchvision.models.segmentation.lraspp import LRASPP
+
+if _BOLTS_AVAILABLE:
+    if os.getenv("WARN_MISSING_PACKAGE") == "0":
+        with warnings.catch_warnings(record=True) as w:
+            from pl_bolts.models.vision import UNet
+    else:
+        from pl_bolts.models.vision import UNet
+
+RESNET_MODELS = ["resnet50", "resnet101"]
+MOBILENET_MODELS = ["mobilenet_v3_large"]
+
+SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("heads")
+
+if _TORCHVISION_AVAILABLE:
+
+    def _get_backbone_meta(backbone):
+        """Adapted from torchvision.models.segmentation.segmentation._segm_model:
+        https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25
+        """
+        if isinstance(backbone, ResNet):
+            out_layer = 'layer4'
+            out_inplanes = 2048
+            aux_layer = 'layer3'
+            aux_inplanes = 1024
+        elif isinstance(backbone, MobileNetV3):
+            backbone = backbone.features
+            # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+            # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+            stage_indices = [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)]
+            stage_indices = [0] + stage_indices + [len(backbone) - 1]
+            out_pos = stage_indices[-1]  # use C5 which has output_stride = 16
+            out_layer = str(out_pos)
+            out_inplanes = backbone[out_pos].out_channels
+            aux_pos = stage_indices[-4]  # use C2 here which has output_stride = 8
+            aux_layer = str(aux_pos)
+            aux_inplanes = backbone[aux_pos].out_channels
+        else:
+            raise MisconfigurationException(
+                f"{type(backbone)} backbone is not currently supported for semantic segmentation."
+            )
+        return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes
+
+    def _load_fcn_deeplabv3(model_name, backbone, num_classes):
+        backbone, out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone)
+
+        return_layers = {out_layer: 'out'}
+        backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
+
+        model_map = {
+            "deeplabv3": (DeepLabHead, DeepLabV3),
+            "fcn": (FCNHead, FCN),
+        }
+        classifier = model_map[model_name][0](out_inplanes, num_classes)
+        base_model = model_map[model_name][1]
+
+        return base_model(backbone, classifier, None)
+
+    for model_name in ["fcn", "deeplabv3"]:
+        SEMANTIC_SEGMENTATION_HEADS(
+            fn=partial(_load_fcn_deeplabv3, model_name),
+            name=model_name,
+            namespace="image/segmentation",
+            package="torchvision",
+        )
+
+    def _load_lraspp(backbone, num_classes):
+        backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone)
+        backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'})
+        return LRASPP(backbone, low_channels, high_channels, num_classes)
+
+    SEMANTIC_SEGMENTATION_HEADS(
+        fn=_load_lraspp,
+        name="lraspp",
+        namespace="image/segmentation",
+        package="torchvision",
+    )
+
+if _BOLTS_AVAILABLE:
+
+    def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
+        rank_zero_warn("The UNet model does not require a backbone, so the backbone will be ignored.", UserWarning)
+        return UNet(num_classes, **kwargs)
+
+    SEMANTIC_SEGMENTATION_HEADS(
+        fn=_load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet"
+    )
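With the registries split, assembling a segmentation network becomes a two-step lookup: build the backbone, then pass it (with `num_classes`) to the head factory, which wraps it in an `IntermediateLayerGetter` and attaches the classifier. A minimal sketch of composing the two registries by hand, mirroring what `SemanticSegmentation.__init__` does in the next file:

```python
import torch

from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS

# Step 1: build the feature extractor. Step 2: hand it to the head factory.
backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False)
model = SEMANTIC_SEGMENTATION_HEADS.get("deeplabv3")(backbone, 21)

# torchvision segmentation models return a dict of logits keyed by layer name,
# upsampled back to the input resolution.
out = model(torch.rand(1, 3, 64, 64))["out"]
print(out.shape)  # torch.Size([1, 21, 64, 64])
```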
""" @@ -70,11 +72,15 @@ class SemanticSegmentation(ClassificationTask): backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES + heads: FlashRegistry = SEMANTIC_SEGMENTATION_HEADS + def __init__( self, num_classes: int, - backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50", + backbone: Union[str, nn.Module] = "resnet50", backbone_kwargs: Optional[Dict] = None, + head: str = "fcn", + head_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, @@ -113,8 +119,15 @@ def __init__( if not backbone_kwargs: backbone_kwargs = {} - # TODO: pretrained to True causes some issues - self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs) + if not head_kwargs: + head_kwargs = {} + + if isinstance(backbone, nn.Module): + self.backbone = backbone + else: + self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + + self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -134,8 +147,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def forward(self, x) -> torch.Tensor: - # infer the image to the model - res = self.backbone(x) + res = self.head(x) # some frameworks like torchvision return a dict. # In particular, torchvision segmentation models return the output logits @@ -145,7 +157,7 @@ def forward(self, x) -> torch.Tensor: elif torch.is_tensor(res): out = res else: - raise NotImplementedError(f"Unsupported output type: {type(out)}") + raise NotImplementedError(f"Unsupported output type: {type(res)}") return out diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 049bd9d318..d7f75d1e46 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -29,31 +29,33 @@ train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", batch_size=4, - val_split=0.3, - image_size=(200, 200), # (600, 800) + val_split=0.1, + image_size=(200, 200), num_classes=21, ) # 2.2 Visualise the samples datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) -# 3.a List available backbones -print(SemanticSegmentation.available_backbones()) +# 3.a List available backbones and heads +print(f"Backbones: {SemanticSegmentation.available_backbones()}") +print(f"Heads: {SemanticSegmentation.available_heads()}") # 3.b Build the model model = SemanticSegmentation( - backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True) + backbone="mobilenet_v3_large", + head="fcn", + num_classes=datamodule.num_classes, + serializer=SegmentationLabels(visualize=True), ) # 4. Create the trainer. -trainer = flash.Trainer( - max_epochs=1, - fast_dev_run=1, -) +trainer = flash.Trainer(fast_dev_run=True) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") +# 6. Segment a few images! 
diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py
index 049bd9d318..d7f75d1e46 100644
--- a/flash_examples/finetuning/semantic_segmentation.py
+++ b/flash_examples/finetuning/semantic_segmentation.py
@@ -29,31 +29,33 @@
     train_folder="data/CameraRGB",
     train_target_folder="data/CameraSeg",
     batch_size=4,
-    val_split=0.3,
-    image_size=(200, 200),  # (600, 800)
+    val_split=0.1,
+    image_size=(200, 200),
     num_classes=21,
 )
 
 # 2.2 Visualise the samples
 datamodule.show_train_batch(["load_sample", "post_tensor_transform"])
 
-# 3.a List available backbones
-print(SemanticSegmentation.available_backbones())
+# 3.a List available backbones and heads
+print(f"Backbones: {SemanticSegmentation.available_backbones()}")
+print(f"Heads: {SemanticSegmentation.available_heads()}")
 
 # 3.b Build the model
 model = SemanticSegmentation(
-    backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True)
+    backbone="mobilenet_v3_large",
+    head="fcn",
+    num_classes=datamodule.num_classes,
+    serializer=SegmentationLabels(visualize=True),
 )
 
 # 4. Create the trainer.
-trainer = flash.Trainer(
-    max_epochs=1,
-    fast_dev_run=1,
-)
+trainer = flash.Trainer(fast_dev_run=True)
 
 # 5. Train the model
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
+# 6. Segment a few images!
 predictions = model.predict([
     "data/CameraRGB/F61-1.png",
     "data/CameraRGB/F62-1.png",
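This example is also where the CHANGELOG's finetuning fix shows up. Because the encoder now lives in `model.backbone` with the classifier in `model.head`, a `"freeze"` strategy can freeze the feature extractor without freezing the whole network, which is what happened when `self.backbone` held the complete segmentation model. A hand-rolled equivalent of the freeze behaviour (a sketch, not the actual `flash.core.finetuning` implementation):

```python
# Freeze the pretrained encoder; the randomly initialised head keeps training.
# (The head shares the backbone's parameter tensors via IntermediateLayerGetter,
# so freezing model.backbone freezes those weights everywhere they appear.)
for param in model.backbone.parameters():
    param.requires_grad = False
```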
diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py
index cf09425c87..a4c850b0bc 100644
--- a/tests/image/segmentation/test_backbones.py
+++ b/tests/image/segmentation/test_backbones.py
@@ -20,20 +20,12 @@
 
 @pytest.mark.parametrize(["backbone"], [
-    pytest.param("fcn_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
-    pytest.param("deeplabv3_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
-    pytest.param(
-        "lraspp_mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")
-    ),
-    pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
+    pytest.param("resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
+    pytest.param("mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
 ])
-def test_image_classifier_backbones_registry(backbone):
+def test_semantic_segmentation_backbones_registry(backbone):
     img = torch.rand(1, 3, 32, 32)
-    backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)
-    backbone_model = backbone_fn(10, pretrained=False)
-    assert backbone_model
-    backbone_model.eval()
-    res = backbone_model(img)
-    if isinstance(res, dict):
-        res = res["out"]
-    assert res.shape[1] == 10
+    backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)(pretrained=False)
+    assert backbone
+    backbone.eval()
+    assert backbone(img) is not None
diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py
index 089871fedb..b464c35ad5 100644
--- a/tests/image/segmentation/test_data.py
+++ b/tests/image/segmentation/test_data.py
@@ -377,6 +377,6 @@ def test_map_labels(self, tmpdir):
         assert labels.dtype == torch.int64
 
         # now train with `fast_dev_run`
-        model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50")
+        model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fcn")
         trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
         trainer.finetune(model, dm, strategy="freeze_unfreeze")
diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py
new file mode 100644
index 0000000000..ec90b03670
--- /dev/null
+++ b/tests/image/segmentation/test_heads.py
@@ -0,0 +1,40 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
+from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
+from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
+
+
+@pytest.mark.parametrize(
+    "head", [
+        pytest.param("fcn", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
+        pytest.param("deeplabv3", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
+        pytest.param("lraspp", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")),
+        pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")),
+    ]
+)
+def test_semantic_segmentation_heads_registry(head):
+    img = torch.rand(1, 3, 32, 32)
+    backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False)
+    head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone, 10)
+    assert backbone
+    assert head
+    head.eval()
+    res = head(img)
+    if isinstance(res, dict):
+        res = res["out"]
+    assert res.shape[1] == 10
diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py
index e85ec85ebe..302cb99cfc 100644
--- a/tests/image/segmentation/test_model.py
+++ b/tests/image/segmentation/test_model.py
@@ -57,7 +57,8 @@ def test_smoke():
 def test_forward(num_classes, img_shape):
     model = SemanticSegmentation(
         num_classes=num_classes,
-        backbone='fcn_resnet50',
+        backbone="resnet50",
+        head="fcn",
     )
 
     B, C, H, W = img_shape