From 536032715898b7adfc93dcb2b5d03c1b68fdc96d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Jun 2021 15:20:24 +0100 Subject: [PATCH 01/10] Enable semantic segmentation backbone and head --- flash/core/model.py | 4 +- flash/image/segmentation/backbones.py | 96 +++------------ flash/image/segmentation/heads.py | 112 ++++++++++++++++++ flash/image/segmentation/model.py | 38 ++++-- .../finetuning/semantic_segmentation.py | 11 +- tests/image/segmentation/test_backbones.py | 24 ++-- tests/image/segmentation/test_data.py | 2 +- tests/image/segmentation/test_heads.py | 40 +++++++ tests/image/segmentation/test_model.py | 3 +- 9 files changed, 217 insertions(+), 113 deletions(-) create mode 100644 flash/image/segmentation/heads.py create mode 100644 tests/image/segmentation/test_heads.py diff --git a/flash/core/model.py b/flash/core/model.py index 8664195dfe..a52de73dfb 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -488,8 +488,8 @@ def available_backbones(cls) -> List[str]: return registry.available_keys() @classmethod - def available_models(cls) -> List[str]: - registry: Optional[FlashRegistry] = getattr(cls, "models", None) + def available_heads(cls) -> List[str]: + registry: Optional[FlashRegistry] = getattr(cls, "heads", None) if registry is None: return [] return registry.available_keys() diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index fa26f8e76e..8c79663c5a 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -11,106 +11,48 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import warnings from functools import partial -import torch.nn as nn -from deprecate import deprecated -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn - from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _TORCHVISION_AVAILABLE from flash.image.backbones import catch_url_error if _TORCHVISION_AVAILABLE: - from torchvision.models import segmentation - -if _BOLTS_AVAILABLE: - if os.getenv("WARN_MISSING_PACKAGE") == "0": - with warnings.catch_warnings(record=True) as w: - from pl_bolts.models.vision import UNet - else: - from pl_bolts.models.vision import UNet + from torchvision.models import mobilenetv3, resnet -FCN_MODELS = ["fcn_resnet50", "fcn_resnet101"] -DEEPLABV3_MODELS = ["deeplabv3_resnet50", "deeplabv3_resnet101", "deeplabv3_mobilenet_v3_large"] -LRASPP_MODELS = ["lraspp_mobilenet_v3_large"] +RESNET_MODELS = ["resnet50", "resnet101"] +MOBILENET_MODELS = ["mobilenet_v3_large"] SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") if _TORCHVISION_AVAILABLE: - def _fn_fcn_deeplabv3(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: - model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) - in_channels = model.classifier[-1].in_channels - model.classifier[-1] = nn.Conv2d(in_channels, num_classes, 1) - return model - - for model_name in FCN_MODELS + DEEPLABV3_MODELS: - _type = model_name.split("_")[0] + def _load_resnet(model_name: str, pretrained: bool = True): + backbone = resnet.__dict__[model_name]( + pretrained=pretrained, + replace_stride_with_dilation=[False, True, True], + ) + return backbone + for model_name in RESNET_MODELS: SEMANTIC_SEGMENTATION_BACKBONES( - fn=catch_url_error(partial(_fn_fcn_deeplabv3, model_name)), + fn=catch_url_error(partial(_load_resnet, model_name)), name=model_name, namespace="image/segmentation", package="torchvision", - type=_type ) - SEMANTIC_SEGMENTATION_BACKBONES( - fn=deprecated( - target=None, - stream=partial(warnings.warn, category=UserWarning), - deprecated_in="0.3.1", - remove_in="0.5.0", - template_mgs="The 'torchvision/fcn_resnet50' backbone has been deprecated since v%(deprecated_in)s in " - "favor of 'fcn_resnet50'. It will be removed in v%(remove_in)s.", - )(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet50")), - name="torchvision/fcn_resnet50", - ) - - SEMANTIC_SEGMENTATION_BACKBONES( - fn=deprecated( - target=None, - stream=partial(warnings.warn, category=UserWarning), - deprecated_in="0.3.1", - remove_in="0.5.0", - template_mgs="The 'torchvision/fcn_resnet101' backbone has been deprecated since v%(deprecated_in)s in " - "favor of 'fcn_resnet101'. It will be removed in v%(remove_in)s.", - )(SEMANTIC_SEGMENTATION_BACKBONES.get("fcn_resnet101")), - name="torchvision/fcn_resnet101", - ) - - def _fn_lraspp(model_name: str, num_classes: int, pretrained: bool = True, **kwargs) -> nn.Module: - model: nn.Module = getattr(segmentation, model_name, None)(pretrained, **kwargs) + def _load_mobilenetv3(model_name: str, pretrained: bool = True): + backbone = mobilenetv3.__dict__[model_name]( + pretrained=pretrained, + _dilated=True, + ).features + return backbone - low_channels = model.classifier.low_classifier.in_channels - high_channels = model.classifier.high_classifier.in_channels - - model.classifier.low_classifier = nn.Conv2d(low_channels, num_classes, 1) - model.classifier.high_classifier = nn.Conv2d(high_channels, num_classes, 1) - return model - - for model_name in LRASPP_MODELS: + for model_name in MOBILENET_MODELS: SEMANTIC_SEGMENTATION_BACKBONES( - fn=catch_url_error(partial(_fn_lraspp, model_name)), + fn=catch_url_error(partial(_load_mobilenetv3, model_name)), name=model_name, namespace="image/segmentation", package="torchvision", - type="lraspp" ) - -if _BOLTS_AVAILABLE: - - def load_bolts_unet(num_classes: int, pretrained: bool = False, **kwargs) -> nn.Module: - if pretrained: - rank_zero_warn( - "No pretrained weights are available for the pl_bolts.models.vision.UNet model. This backbone will be " - "initialized with random weights!", UserWarning - ) - return UNet(num_classes, **kwargs) - - SEMANTIC_SEGMENTATION_BACKBONES( - fn=load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet" - ) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py new file mode 100644 index 0000000000..2529298f46 --- /dev/null +++ b/flash/image/segmentation/heads.py @@ -0,0 +1,112 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import warnings +from functools import partial + +import torch.nn as nn +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + from torchvision.models import MobileNetV3, ResNet + from torchvision.models._utils import IntermediateLayerGetter + from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3 + from torchvision.models.segmentation.fcn import FCN, FCNHead + from torchvision.models.segmentation.lraspp import LRASPP + +if _BOLTS_AVAILABLE: + if os.getenv("WARN_MISSING_PACKAGE") == "0": + with warnings.catch_warnings(record=True) as w: + from pl_bolts.models.vision import UNet + else: + from pl_bolts.models.vision import UNet + +RESNET_MODELS = ["resnet50", "resnet101"] +MOBILENET_MODELS = ["mobilenet_v3_large"] + +SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("backbones") + +if _TORCHVISION_AVAILABLE: + + def _get_backbone_meta(backbone): + if isinstance(backbone, ResNet): + out_layer = 'layer4' + out_inplanes = 2048 + aux_layer = 'layer3' + aux_inplanes = 1024 + elif isinstance(backbone, MobileNetV3): + # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. + # The first and last blocks are always included because they are the C0 (conv1) and Cn. + stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False) + ] + [len(backbone) - 1] + out_pos = stage_indices[-1] # use C5 which has output_stride = 16 + out_layer = str(out_pos) + out_inplanes = backbone[out_pos].out_channels + aux_pos = stage_indices[-4] # use C2 here which has output_stride = 8 + aux_layer = str(aux_pos) + aux_inplanes = backbone[aux_pos].out_channels + else: + raise MisconfigurationException( + f"{type(backbone)} backbone is not currently supported for semantic segmentation." + ) + return out_layer, out_inplanes, aux_layer, aux_inplanes + + def _load_fcn_deeplabv3(model_name, backbone, num_classes): + out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone) + + return_layers = {out_layer: 'out'} + backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) + + model_map = { + "deeplabv3": (DeepLabHead, DeepLabV3), + "fcn": (FCNHead, FCN), + } + classifier = model_map[model_name][0](out_inplanes, num_classes) + base_model = model_map[model_name][1] + + return base_model(backbone, classifier, None) + + for model_name in ["fcn", "deeplabv3"]: + SEMANTIC_SEGMENTATION_HEADS( + fn=partial(_load_fcn_deeplabv3, model_name), + name=model_name, + namespace="image/segmentation", + package="torchvision", + ) + + def _load_lraspp(backbone, num_classes): + high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone) + backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'}) + return LRASPP(backbone, low_channels, high_channels, num_classes) + + SEMANTIC_SEGMENTATION_HEADS( + fn=_load_lraspp, + name="lraspp", + namespace="image/segmentation", + package="torchvision", + ) + +if _BOLTS_AVAILABLE: + + def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module: + rank_zero_warn("The UNet model does not require a backbone, so the backbone will be ignored.", UserWarning) + return UNet(num_classes, **kwargs) + + SEMANTIC_SEGMENTATION_HEADS( + fn=_load_bolts_unet, name="unet", namespace="image/segmentation", package="bolts", type="unet" + ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 1a810afa7f..fa0bae0cff 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -24,6 +24,7 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES +from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS from flash.image.segmentation.serialization import SegmentationLabels if _KORNIA_AVAILABLE: @@ -54,14 +55,15 @@ class SemanticSegmentation(ClassificationTask): Args: num_classes: Number of classes to classify. - backbone: A string or (model, num_features) tuple to use to compute image features, - defaults to ``"torchvision/fcn_resnet50"``. + backbone: A string or model to use to compute image features. backbone_kwargs: Additional arguments for the backbone configuration. - pretrained: Use a pretrained backbone, defaults to ``False``. - loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`. - metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`. - learning_rate: Learning rate to use for training, defaults to ``1e-3``. + head: A string or (model, num_features) tuple to use to compute image features. + head_kwargs: Additional arguments for the head configuration. + pretrained: Use a pretrained backbone. + loss_fn: Loss function for training. + optimizer: Optimizer to use for training. + metrics: Metrics to compute for training and evaluation. + learning_rate: Learning rate to use for training. multi_label: Whether the targets are multi-label or not. serializer: The :class:`~flash.core.data.process.Serializer` to use when serializing prediction outputs. """ @@ -70,11 +72,15 @@ class SemanticSegmentation(ClassificationTask): backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES + heads: FlashRegistry = SEMANTIC_SEGMENTATION_HEADS + def __init__( self, num_classes: int, - backbone: Union[str, Tuple[nn.Module, int]] = "fcn_resnet50", + backbone: Union[str, nn.Module] = "resnet50", backbone_kwargs: Optional[Dict] = None, + head: str = "fcn", + head_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, @@ -113,8 +119,15 @@ def __init__( if not backbone_kwargs: backbone_kwargs = {} - # TODO: pretrained to True causes some issues - self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs) + if not head_kwargs: + head_kwargs = {} + + if isinstance(backbone, nn.Module): + self.backbone = backbone + else: + self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs) + + self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs) def training_step(self, batch: Any, batch_idx: int) -> Any: batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) @@ -134,8 +147,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A return batch def forward(self, x) -> torch.Tensor: - # infer the image to the model - res = self.backbone(x) + res = self.head(x) # some frameworks like torchvision return a dict. # In particular, torchvision segmentation models return the output logits @@ -145,7 +157,7 @@ def forward(self, x) -> torch.Tensor: elif torch.is_tensor(res): out = res else: - raise NotImplementedError(f"Unsupported output type: {type(out)}") + raise NotImplementedError(f"Unsupported output type: {type(res)}") return out diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 049bd9d318..94c99244b0 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -37,12 +37,16 @@ # 2.2 Visualise the samples datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) -# 3.a List available backbones -print(SemanticSegmentation.available_backbones()) +# 3.a List available backbones and heads +print(f"Backbones: {SemanticSegmentation.available_backbones()}") +print(f"Heads: {SemanticSegmentation.available_heads()}") # 3.b Build the model model = SemanticSegmentation( - backbone="fcn_resnet50", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True) + backbone="mobilenet_v3_large", + head="deeplabv3", + num_classes=datamodule.num_classes, + serializer=SegmentationLabels(visualize=True), ) # 4. Create the trainer. @@ -54,6 +58,7 @@ # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") +# 6. Segment a few images! predictions = model.predict([ "data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png", diff --git a/tests/image/segmentation/test_backbones.py b/tests/image/segmentation/test_backbones.py index 230ca0bc14..0b2b452e17 100644 --- a/tests/image/segmentation/test_backbones.py +++ b/tests/image/segmentation/test_backbones.py @@ -13,26 +13,18 @@ # limitations under the License. import pytest import torch -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from pytorch_lightning.utilities import _TORCHVISION_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES @pytest.mark.parametrize(["backbone"], [ - pytest.param("fcn_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param("deeplabv3_resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), - pytest.param( - "lraspp_mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision") - ), - pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), + pytest.param("resnet50", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), + pytest.param("mobilenet_v3_large", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), ]) -def test_image_classifier_backbones_registry(backbone): +def test_semantic_segmentation_backbones_registry(backbone): img = torch.rand(1, 3, 32, 32) - backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone) - backbone_model = backbone_fn(10, pretrained=False) - assert backbone_model - backbone_model.eval() - res = backbone_model(img) - if isinstance(res, dict): - res = res["out"] - assert res.shape[1] == 10 + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get(backbone)(pretrained=False) + assert backbone + backbone.eval() + assert backbone(img) is not None diff --git a/tests/image/segmentation/test_data.py b/tests/image/segmentation/test_data.py index a45f0a947a..2c14503541 100644 --- a/tests/image/segmentation/test_data.py +++ b/tests/image/segmentation/test_data.py @@ -306,6 +306,6 @@ def test_map_labels(self, tmpdir): assert labels.dtype == torch.int64 # now train with `fast_dev_run` - model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50") + model = SemanticSegmentation(num_classes=2, backbone="resnet50", head="fcn") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.finetune(model, dm, strategy="freeze_unfreeze") diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py new file mode 100644 index 0000000000..218dbbf441 --- /dev/null +++ b/tests/image/segmentation/test_heads.py @@ -0,0 +1,40 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch +from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE + +from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES +from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS + + +@pytest.mark.parametrize( + "head", [ + pytest.param("fcn", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), + pytest.param("deeplabv3", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), + pytest.param("lraspp", marks=pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="No torchvision")), + pytest.param("unet", marks=pytest.mark.skipif(not _BOLTS_AVAILABLE, reason="No bolts")), + ] +) +def test_semantic_segmentation_heads_registry(head): + img = torch.rand(1, 3, 32, 32) + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet50")(pretrained=False) + head = SEMANTIC_SEGMENTATION_HEADS.get(head)(backbone, 10) + assert backbone + assert head + head.eval() + res = head(img) + if isinstance(res, dict): + res = res["out"] + assert res.shape[1] == 10 diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index e85ec85ebe..302cb99cfc 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -57,7 +57,8 @@ def test_smoke(): def test_forward(num_classes, img_shape): model = SemanticSegmentation( num_classes=num_classes, - backbone='fcn_resnet50', + backbone="resnet50", + head="fcn", ) B, C, H, W = img_shape From 6091c25bfe2a4e0b2e7c63fe9e34151858c79c62 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Jun 2021 15:27:40 +0100 Subject: [PATCH 02/10] Fix --- flash/image/segmentation/backbones.py | 2 +- flash/image/segmentation/heads.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 8c79663c5a..7244caf6c5 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -46,7 +46,7 @@ def _load_mobilenetv3(model_name: str, pretrained: bool = True): backbone = mobilenetv3.__dict__[model_name]( pretrained=pretrained, _dilated=True, - ).features + ) return backbone for model_name in MOBILENET_MODELS: diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 2529298f46..190574e2ca 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -50,6 +50,7 @@ def _get_backbone_meta(backbone): aux_layer = 'layer3' aux_inplanes = 1024 elif isinstance(backbone, MobileNetV3): + backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False) From b72b03b3fa89116dd6b6004e4bfd436bbc63ff79 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Jun 2021 15:53:45 +0100 Subject: [PATCH 03/10] Fixes --- flash/image/segmentation/heads.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 190574e2ca..ef0d684de2 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -65,10 +65,10 @@ def _get_backbone_meta(backbone): raise MisconfigurationException( f"{type(backbone)} backbone is not currently supported for semantic segmentation." ) - return out_layer, out_inplanes, aux_layer, aux_inplanes + return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes def _load_fcn_deeplabv3(model_name, backbone, num_classes): - out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone) + backbone, out_layer, out_inplanes, aux_layer, aux_inplanes = _get_backbone_meta(backbone) return_layers = {out_layer: 'out'} backbone = IntermediateLayerGetter(backbone, return_layers=return_layers) @@ -91,7 +91,7 @@ def _load_fcn_deeplabv3(model_name, backbone, num_classes): ) def _load_lraspp(backbone, num_classes): - high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone) + backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone) backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'}) return LRASPP(backbone, low_channels, high_channels, num_classes) From e5624ac5f01f4ef7bce486e53ff7cbefb8d3bdad Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 15 Jun 2021 16:59:39 +0100 Subject: [PATCH 04/10] Updates --- flash_examples/finetuning/semantic_segmentation.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 94c99244b0..9efbe0be6d 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -44,16 +44,13 @@ # 3.b Build the model model = SemanticSegmentation( backbone="mobilenet_v3_large", - head="deeplabv3", + head="fcn", num_classes=datamodule.num_classes, serializer=SegmentationLabels(visualize=True), ) # 4. Create the trainer. -trainer = flash.Trainer( - max_epochs=1, - fast_dev_run=1, -) +trainer = flash.Trainer(fast_dev_run=True) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") From b7a536af7974134e2a4c13bcf73865aec4d179c9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 10:29:24 +0100 Subject: [PATCH 05/10] Fixes --- flash/image/segmentation/heads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index ef0d684de2..66c72fe669 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -16,11 +16,11 @@ from functools import partial import torch.nn as nn -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.registry import FlashRegistry -from flash.core.utilities.imports import _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: from torchvision.models import MobileNetV3, ResNet From ae093ba673825ef86f8c10498ff18558284e693a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 10:54:53 +0100 Subject: [PATCH 06/10] Updates --- flash_examples/finetuning/semantic_segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 9efbe0be6d..d7f75d1e46 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -29,8 +29,8 @@ train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", batch_size=4, - val_split=0.3, - image_size=(200, 200), # (600, 800) + val_split=0.1, + image_size=(200, 200), num_classes=21, ) From bd5892fdfe935b5afbab5dcc0ccc3f908cc2b2db Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 10:56:35 +0100 Subject: [PATCH 07/10] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8e1f16c16..4dae6cbb71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Split `backbone` argument to `SemanticSegmentation` into `backbone` and `head` arguments ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412)) ### Deprecated @@ -22,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed a bug where the `DefaultDataKeys.METADATA` couldn't be a dict ([#393](https://github.com/PyTorchLightning/lightning-flash/pull/393)) +- Fixed a bug where the `SemanticSegmentation` task would not work as expected with finetuning callbacks ([#412](https://github.com/PyTorchLightning/lightning-flash/pull/412)) ## [0.3.2] - 2021-06-08 From c867c92364b50e96323e2ee87efd67aef9b684c8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 11:21:49 +0100 Subject: [PATCH 08/10] Updates --- flash/image/segmentation/heads.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 66c72fe669..e793b53fed 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -44,6 +44,9 @@ if _TORCHVISION_AVAILABLE: def _get_backbone_meta(backbone): + """Adapted from torchvision.models.segmentation.segmentation._segm_model: + https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25 + """ if isinstance(backbone, ResNet): out_layer = 'layer4' out_inplanes = 2048 @@ -53,8 +56,11 @@ def _get_backbone_meta(backbone): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False) - ] + [len(backbone) - 1] + stage_indices = sum([ + [0], + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)], + [len(backbone) - 1], + ]) out_pos = stage_indices[-1] # use C5 which has output_stride = 16 out_layer = str(out_pos) out_inplanes = backbone[out_pos].out_channels From f8a1432ac5be63bd1517c4ea912c1612ce3bc51a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 12:48:50 +0100 Subject: [PATCH 09/10] Fixes --- flash/image/segmentation/heads.py | 7 ++----- tests/image/segmentation/test_heads.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index e793b53fed..4444b5c3ab 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -56,11 +56,8 @@ def _get_backbone_meta(backbone): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = sum([ - [0], - [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)], - [len(backbone) - 1], - ]) + stage_indices = ([0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + + [len(backbone) - 1]) out_pos = stage_indices[-1] # use C5 which has output_stride = 16 out_layer = str(out_pos) out_inplanes = backbone[out_pos].out_channels diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index 218dbbf441..ec90b03670 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -13,8 +13,8 @@ # limitations under the License. import pytest import torch -from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS From e256ef1c5b96714fb16fc3aa409a2a7ccd19017a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 16 Jun 2021 12:50:58 +0100 Subject: [PATCH 10/10] Fixes --- flash/image/segmentation/heads.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 4444b5c3ab..eab47e0cf5 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -56,8 +56,8 @@ def _get_backbone_meta(backbone): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. - stage_indices = ([0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + - [len(backbone) - 1]) + stage_indices = [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + stage_indices = [0] + stage_indices + [len(backbone) - 1] out_pos = stage_indices[-1] # use C5 which has output_stride = 16 out_layer = str(out_pos) out_inplanes = backbone[out_pos].out_channels