From 568ad4a7c49bf98676f17530937c4e93f9502910 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 13 Jul 2021 20:50:02 +0530 Subject: [PATCH 01/11] add weights path --- flash/image/segmentation/heads.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index e870f3e1c3..174698422b 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable +from typing import Optional + +from torch import nn from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE @@ -36,8 +38,9 @@ def _load_smp_head( pretrained: bool = True, num_classes: int = 1, in_channels: int = 3, + weights_path: Optional[str] = None, **kwargs, - ) -> Callable: + ) -> nn.Module: if head not in SMP_MODELS: raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}") @@ -45,6 +48,7 @@ def _load_smp_head( encoder_weights = None if pretrained: encoder_weights = "imagenet" + encoder_weights = encoder_weights or weights_path return smp.create_model( arch=head, From 729f607d9936f6cde1673f1f1a167d2dc3d6ba79 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 20:11:14 +0530 Subject: [PATCH 02/11] add available weights --- flash/image/segmentation/backbones.py | 7 ++++++- flash/image/segmentation/model.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 15047477f4..30690cfaf1 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -32,6 +32,11 @@ def _load_smp_backbone(backbone: str, **_) -> str: short_name = encoder_name if short_name.startswith("timm-"): short_name = encoder_name[5:] + + available_weights = smp.encoders.encoders[encoder_name]["pretrained_settings"].keys() SEMANTIC_SEGMENTATION_BACKBONES( - partial(_load_smp_backbone, backbone=encoder_name), name=short_name, namespace="image/segmentation" + partial(_load_smp_backbone, backbone=encoder_name), + name=short_name, + namespace="image/segmentation", + weights_paths=available_weights, ) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 59c5b4cc77..4a1c9fc59f 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -156,6 +156,16 @@ def forward(self, x) -> torch.Tensor: return out + @classmethod + def available_pretrained_weights(cls, backbone: str): + result = cls.backbones.get(backbone, with_metadata=True) + pretrained_weights = None + + if "weights_paths" in result["metadata"]: + pretrained_weights = list(result["metadata"]["weights_paths"]) + + return pretrained_weights + @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): """ From 974049ce41c8981cc46199687218a8c4698f8cc6 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 20:15:28 +0530 Subject: [PATCH 03/11] remove weight path --- flash/image/segmentation/heads.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 174698422b..326454a367 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Optional from torch import nn @@ -38,7 +37,6 @@ def _load_smp_head( pretrained: bool = True, num_classes: int = 1, in_channels: int = 3, - weights_path: Optional[str] = None, **kwargs, ) -> nn.Module: @@ -48,7 +46,6 @@ def _load_smp_head( encoder_weights = None if pretrained: encoder_weights = "imagenet" - encoder_weights = encoder_weights or weights_path return smp.create_model( arch=head, From 16e91fc27bbe01f6a85b1dbd778c9f7d9cb15432 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 20:20:54 +0530 Subject: [PATCH 04/11] add tests :white_check_mark: --- tests/image/segmentation/test_model.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 68fece463f..332fc52cd0 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -155,3 +155,8 @@ def test_serve(): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_available_pretrained_weights(backbone): + assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl'] From f018b4a6b1081a0a8f0aac21824385e15f0a482a Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 20:45:37 +0530 Subject: [PATCH 05/11] fix --- tests/image/segmentation/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 332fc52cd0..5a45226641 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -158,5 +158,5 @@ def test_load_from_checkpoint_dependency_error(): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") -def test_available_pretrained_weights(backbone): +def test_available_pretrained_weights(): assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl'] From 9428c29295aadd5d6a99fc3f29fb8dc545b5239f Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 21:03:33 +0530 Subject: [PATCH 06/11] update --- docs/source/reference/semantic_segmentation.rst | 2 +- flash/image/segmentation/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 3f95662c75..863dff2550 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -36,7 +36,7 @@ Here's the structure: Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`. We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data. -We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. +We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. You can check the available pretrained weights for the backbones like this `SemanticSegmentation.available_pretrained_weights("resnet18")`. Finally, we save the model. Here's the full example: diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 4a1c9fc59f..ddb50fdd47 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -77,7 +77,7 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: str = "fpn", head_kwargs: Optional[Dict] = None, - pretrained: bool = True, + pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, From 2070ba7bdc5bae0bae6e382f29e2075851b213b7 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 21:46:42 +0530 Subject: [PATCH 07/11] add str pretrained --- flash/image/segmentation/heads.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index 326454a367..e897d21543 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial +from typing import Union from torch import nn @@ -34,7 +35,7 @@ def _load_smp_head( head: str, backbone: str, - pretrained: bool = True, + pretrained: Union[bool, str] = True, num_classes: int = 1, in_channels: int = 3, **kwargs, @@ -44,8 +45,10 @@ def _load_smp_head( raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}") encoder_weights = None - if pretrained: + if pretrained is True: encoder_weights = "imagenet" + elif isinstance(pretrained, str): + encoder_weights = pretrained return smp.create_model( arch=head, From 9aad7e688af46f936cf1d447722958ffa3c1785c Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 22:03:49 +0530 Subject: [PATCH 08/11] add test :white_check_mark: --- tests/image/segmentation/test_heads.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index cf50ed5de5..8313a04e09 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import unittest.mock + import pytest import torch from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS @@ -37,3 +40,25 @@ def test_semantic_segmentation_heads_registry(head): if isinstance(res, dict): res = res["out"] assert res.shape[1] == 10 + + +@unittest.mock.patch("flash.image.segmentation.heads.smp") +def test_pretrained_weights(mock_smp): + mock_smp.create_model = unittest.mock.MagicMock() + available_weights = SemanticSegmentation.available_pretrained_weights("resnet18") + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet18")() + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True) + + kwargs = { + 'arch': 'unet', + 'classes': 10, + 'encoder_name': 'resnet18', + 'in_channels': 3, + "encoder_weights": "imagenet" + } + mock_smp.create_model.assert_called_with(**kwargs) + + for weight in available_weights: + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=weight) + kwargs["encoder_weights"] = weight + mock_smp.create_model.assert_called_with(**kwargs) From 55e1102d95c4f267b1bbb5ea831b3559c1ab90eb Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Wed, 14 Jul 2021 22:08:03 +0530 Subject: [PATCH 09/11] fix --- tests/image/segmentation/test_heads.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index 8313a04e09..f6bfb6fb24 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -20,6 +20,7 @@ from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS +from tests.helpers.utils import _IMAGE_TESTING @pytest.mark.parametrize( @@ -42,6 +43,7 @@ def test_semantic_segmentation_heads_registry(head): assert res.shape[1] == 10 +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @unittest.mock.patch("flash.image.segmentation.heads.smp") def test_pretrained_weights(mock_smp): mock_smp.create_model = unittest.mock.MagicMock() From c283bb6793e21c8ee48aa86889d8641e997bfb3e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Jul 2021 18:38:13 +0100 Subject: [PATCH 10/11] Update flash/image/segmentation/heads.py --- flash/image/segmentation/heads.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index e897d21543..294c7f36d9 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -45,10 +45,10 @@ def _load_smp_head( raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}") encoder_weights = None - if pretrained is True: - encoder_weights = "imagenet" - elif isinstance(pretrained, str): + if isinstance(pretrained, str): encoder_weights = pretrained + elif pretrained: + encoder_weights = "imagenet" return smp.create_model( arch=head, From d996f329fa8b13334448e13f60dd308e43187ba5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Jul 2021 20:14:21 +0100 Subject: [PATCH 11/11] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index afdf24e5da..ff5acd0bdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73)) +- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))