Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

add available weights to SMP #587

Merged
merged 14 commits into from
Jul 14, 2021
2 changes: 1 addition & 1 deletion docs/source/reference/semantic_segmentation.rst
Original file line number Diff line number Diff line change
@@ -36,7 +36,7 @@ Here's the structure:

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`.
We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data.
We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference.
We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference. You can check the available pretrained weights for the backbones like this `SemanticSegmentation.available_pretrained_weights("resnet18")`.
Finally, we save the model.
Here's the full example:

7 changes: 6 additions & 1 deletion flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
@@ -32,6 +32,11 @@ def _load_smp_backbone(backbone: str, **_) -> str:
short_name = encoder_name
if short_name.startswith("timm-"):
short_name = encoder_name[5:]

available_weights = smp.encoders.encoders[encoder_name]["pretrained_settings"].keys()
SEMANTIC_SEGMENTATION_BACKBONES(
partial(_load_smp_backbone, backbone=encoder_name), name=short_name, namespace="image/segmentation"
partial(_load_smp_backbone, backbone=encoder_name),
name=short_name,
namespace="image/segmentation",
weights_paths=available_weights,
)
5 changes: 3 additions & 2 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable

from torch import nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE
@@ -37,7 +38,7 @@ def _load_smp_head(
num_classes: int = 1,
in_channels: int = 3,
**kwargs,
) -> Callable:
) -> nn.Module:

if head not in SMP_MODELS:
raise NotImplementedError(f"{head} is not implemented! Supported heads -> {SMP_MODELS.keys()}")
12 changes: 11 additions & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
@@ -77,7 +77,7 @@ def __init__(
backbone_kwargs: Optional[Dict] = None,
head: str = "fpn",
head_kwargs: Optional[Dict] = None,
pretrained: bool = True,
pretrained: Union[bool, str] = True,
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
metrics: Union[Metric, Callable, Mapping, Sequence, None] = None,
@@ -156,6 +156,16 @@ def forward(self, x) -> torch.Tensor:

return out

@classmethod
def available_pretrained_weights(cls, backbone: str):
result = cls.backbones.get(backbone, with_metadata=True)
pretrained_weights = None

if "weights_paths" in result["metadata"]:
pretrained_weights = list(result["metadata"]["weights_paths"])

return pretrained_weights

@staticmethod
def _ci_benchmark_fn(history: List[Dict[str, Any]]):
"""
5 changes: 5 additions & 0 deletions tests/image/segmentation/test_model.py
Original file line number Diff line number Diff line change
@@ -155,3 +155,8 @@ def test_serve():
def test_load_from_checkpoint_dependency_error():
with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[image]'")):
SemanticSegmentation.load_from_checkpoint("not_a_real_checkpoint.pt")


@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
def test_available_pretrained_weights():
assert SemanticSegmentation.available_pretrained_weights("resnet18") == ['imagenet', 'ssl', 'swsl']