diff --git a/CHANGELOG.md b/CHANGELOG.md index c85b31c1a0..97085839cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73)) +- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587)) + - Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585)) ### Changed diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 3f95662c75..863dff2550 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -36,7 +36,7 @@ Here's the structure: Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`. We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data. -We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. +We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. You can check the available pretrained weights for the backbones like this ``SemanticSegmentation.available_pretrained_weights("resnet18")``. Finally, we save the model. 
Here's the full example: diff --git a/flash/image/segmentation/backbones.py b/flash/image/segmentation/backbones.py index 15047477f4..30690cfaf1 100644 --- a/flash/image/segmentation/backbones.py +++ b/flash/image/segmentation/backbones.py @@ -32,6 +32,11 @@ def _load_smp_backbone(backbone: str, **_) -> str: short_name = encoder_name if short_name.startswith("timm-"): short_name = encoder_name[5:] + + available_weights = smp.encoders.encoders[encoder_name]["pretrained_settings"].keys() SEMANTIC_SEGMENTATION_BACKBONES( - partial(_load_smp_backbone, backbone=encoder_name), name=short_name, namespace="image/segmentation" + partial(_load_smp_backbone, backbone=encoder_name), + name=short_name, + namespace="image/segmentation", + weights_paths=available_weights, ) diff --git a/flash/image/segmentation/heads.py b/flash/image/segmentation/heads.py index e870f3e1c3..294c7f36d9 100644 --- a/flash/image/segmentation/heads.py +++ b/flash/image/segmentation/heads.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable +from typing import Union + +from torch import nn from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE @@ -33,17 +35,19 @@ def _load_smp_head( head: str, backbone: str, - pretrained: bool = True, + pretrained: Union[bool, str] = True, num_classes: int = 1, in_channels: int = 3, **kwargs, - ) -> Callable: + ) -> nn.Module: if head not in SMP_MODELS: raise NotImplementedError(f"{head} is not implemented! 
Supported heads -> {SMP_MODELS.keys()}") encoder_weights = None - if pretrained: + if isinstance(pretrained, str): + encoder_weights = pretrained + elif pretrained: encoder_weights = "imagenet" return smp.create_model( diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index 59c5b4cc77..ddb50fdd47 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -77,7 +77,7 @@ def __init__( backbone_kwargs: Optional[Dict] = None, head: str = "fpn", head_kwargs: Optional[Dict] = None, - pretrained: bool = True, + pretrained: Union[bool, str] = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, metrics: Union[Metric, Callable, Mapping, Sequence, None] = None, @@ -156,6 +156,16 @@ def forward(self, x) -> torch.Tensor: return out + @classmethod + def available_pretrained_weights(cls, backbone: str): + result = cls.backbones.get(backbone, with_metadata=True) + pretrained_weights = None + + if "weights_paths" in result["metadata"]: + pretrained_weights = list(result["metadata"]["weights_paths"]) + + return pretrained_weights + @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): """ diff --git a/tests/image/segmentation/test_heads.py b/tests/image/segmentation/test_heads.py index cf50ed5de5..f6bfb6fb24 100644 --- a/tests/image/segmentation/test_heads.py +++ b/tests/image/segmentation/test_heads.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest.mock + import pytest import torch from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE +from flash.image.segmentation import SemanticSegmentation from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS +from tests.helpers.utils import _IMAGE_TESTING @pytest.mark.parametrize( @@ -37,3 +41,26 @@ def test_semantic_segmentation_heads_registry(head): if isinstance(res, dict): res = res["out"] assert res.shape[1] == 10 + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@unittest.mock.patch("flash.image.segmentation.heads.smp") +def test_pretrained_weights(mock_smp): + mock_smp.create_model = unittest.mock.MagicMock() + available_weights = SemanticSegmentation.available_pretrained_weights("resnet18") + backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet18")() + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True) + + kwargs = { + 'arch': 'unet', + 'classes': 10, + 'encoder_name': 'resnet18', + 'in_channels': 3, + "encoder_weights": "imagenet" + } + mock_smp.create_model.assert_called_with(**kwargs) + + for weight in available_weights: + SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=weight) + kwargs["encoder_weights"] = weight + mock_smp.create_model.assert_called_with(**kwargs) diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index 68fece463f..5a45226641 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -155,3 +155,8 @@ def test_serve(): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")): SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def 
test_available_pretrained_weights(): + assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl']