Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into feature/576_add_field_parameter_to_seq2se2…
Browse files Browse the repository at this point in the history
…_tasks_for_json_datasets
  • Loading branch information
ethanwharris authored Jul 14, 2021
2 parents dc57792 + f6e0d20 commit 1fdf594
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))

- Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))

### Changed
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/semantic_segmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Here's the structure:
Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`.
We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data.
We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference.
We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. You can check the available pretrained weights for the backbones like this `SemanticSegmentation.available_pretrained_weights("resnet18")`.
Finally, we save the model.
Here's the full example:

Expand Down
7 changes: 6 additions & 1 deletion flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ def _load_smp_backbone(backbone: str, **_) -> str:
short_name = encoder_name
if short_name.startswith("timm-"):
short_name = encoder_name[5:]

available_weights = smp.encoders.encoders[encoder_name]["pretrained_settings"].keys()
SEMANTIC_SEGMENTATION_BACKBONES(
partial(_load_smp_backbone, backbone=encoder_name), name=short_name, namespace="image/segmentation"
partial(_load_smp_backbone, backbone=encoder_name),
name=short_name,
namespace="image/segmentation",
weights_paths=available_weights,
)
12 changes: 8 additions & 4 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable
from typing import Union

from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
Expand All @@ -33,17 +35,19 @@
def _load_smp_head(
head: str,
backbone: str,
pretrained: bool = True,
pretrained: Union[bool, str] = True,
num_classes: int = 1,
in_channels: int = 3,
**kwargs,
) -> Callable:
) -> nn.Module:

if head not in SMP_MODELS:
raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}")

encoder_weights = None
if pretrained:
if isinstance(pretrained, str):
encoder_weights = pretrained
elif pretrained:
encoder_weights = "imagenet"

return smp.create_model(
Expand Down
12 changes: 11 additions & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
backbone_kwargs: Optional[Dict] = None,
head: str = "fpn",
head_kwargs: Optional[Dict] = None,
pretrained: bool = True,
pretrained: Union[bool, str] = True,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
Expand Down Expand Up @@ -156,6 +156,16 @@ def forward(self, x) -> torch.Tensor:

return out

@classmethod
def available_pretrained_weights(cls, backbone: str):
result = cls.backbones.get(backbone, with_metadata=True)
pretrained_weights = None

if "weights_paths" in result["metadata"]:
pretrained_weights = list(result["metadata"]["weights_paths"])

return pretrained_weights

@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
"""
Expand Down
27 changes: 27 additions & 0 deletions tests/image/segmentation/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest.mock

import pytest
import torch

from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
from flash.image.segmentation import SemanticSegmentation
from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS
from tests.helpers.utils import _IMAGE_TESTING


@pytest.mark.parametrize(
Expand All @@ -37,3 +41,26 @@ def test_semantic_segmentation_heads_registry(head):
if isinstance(res, dict):
res = res["out"]
assert res.shape[1] == 10


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
@unittest.mock.patch("flash.image.segmentation.heads.smp")
def test_pretrained_weights(mock_smp):
mock_smp.create_model = unittest.mock.MagicMock()
available_weights = SemanticSegmentation.available_pretrained_weights("resnet18")
backbone = SEMANTIC_SEGMENTATION_BACKBONES.get("resnet18")()
SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=True)

kwargs = {
'arch': 'unet',
'classes': 10,
'encoder_name': 'resnet18',
'in_channels': 3,
"encoder_weights": "imagenet"
}
mock_smp.create_model.assert_called_with(**kwargs)

for weight in available_weights:
SEMANTIC_SEGMENTATION_HEADS.get("unet")(backbone=backbone, num_classes=10, pretrained=weight)
kwargs["encoder_weights"] = weight
mock_smp.create_model.assert_called_with(**kwargs)
5 changes: 5 additions & 0 deletions tests/image/segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,8 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_available_pretrained_weights():
assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl']

0 comments on commit 1fdf594

Please sign in to comment.