Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

segmentation_models.pytorch integration #562

Merged
merged 19 commits into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _compare_version(package: str, op, version) -> bool:
_CYTOOLZ_AVAILABLE = _module_available("cytoolz")
_UVICORN_AVAILABLE = _module_available("uvicorn")
_PIL_AVAILABLE = _module_available("PIL")
_SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch")

if Version:
_TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0")
Expand Down
20 changes: 19 additions & 1 deletion flash/image/segmentation/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.image.backbones import catch_url_error

if _TORCHVISION_AVAILABLE:
from torchvision.models import mobilenetv3, resnet

if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved

ENCODERS = smp.encoders.get_encoder_names()

MOBILENET_MODELS = ["mobilenet_v3_large"]
RESNET_MODELS = ["resnet50", "resnet101"]

Expand Down Expand Up @@ -56,3 +61,16 @@ def _load_mobilenetv3(model_name: str, pretrained: bool = True):
namespace="image/segmentation",
package="torchvision",
)

if _SEGMENTATION_MODELS_AVAILABLE:

def _load_smp_backbone(backbone: str, **_) -> str:
return backbone

# Register one backbone entry per encoder name listed in ENCODERS. Each entry
# binds the encoder name into `_load_smp_backbone` via `partial`.
for encoder_name in ENCODERS:
    SEMANTIC_SEGMENTATION_BACKBONES(
        partial(_load_smp_backbone, backbone=encoder_name),
        backbone=encoder_name,
        name=encoder_name,
        namespace="image/segmentation",
        # Record the providing package, as the other registrations in this
        # module and the smp heads do, so registry metadata stays consistent.
        package="segmentation_models.pytorch",
    )
91 changes: 45 additions & 46 deletions flash/image/segmentation/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from functools import partial
from typing import Callable

import torch.nn as nn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _BOLTS_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.imports import _SEGMENTATION_MODELS_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
from torchvision.models import MobileNetV3, ResNet
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
from torchvision.models.segmentation.fcn import FCN, FCNHead
from torchvision.models.segmentation.lraspp import LRASPP

if _BOLTS_AVAILABLE:
if os.getenv("WARN_MISSING_PACKAGE") == "0":
with warnings.catch_warnings(record=True) as w:
from pl_bolts.models.vision import UNet
else:
from pl_bolts.models.vision import UNet
if _SEGMENTATION_MODELS_AVAILABLE:
import segmentation_models_pytorch as smp

# Architecture classes from segmentation_models_pytorch that are exposed as heads.
SMP_MODEL_CLASS = [
    smp.Unet,
    smp.UnetPlusPlus,
    smp.MAnet,
    smp.Linknet,
    smp.FPN,
    smp.PSPNet,
    smp.DeepLabV3,
    smp.DeepLabV3Plus,
    smp.PAN,
]
# Lookup table: lowercase class name -> architecture class.
SMP_MODELS = {model_cls.__name__.lower(): model_cls for model_cls in SMP_MODEL_CLASS}

# Registry of segmentation heads. It is named "heads" so that the registry's
# own name matches what it stores (it was previously labelled "backbones").
SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("heads")

if _SEGMENTATION_MODELS_AVAILABLE:

def _load_smp_head(
    head: str,
    backbone: str,
    pretrained: bool = True,
    num_classes: int = 1,
    in_channels: int = 3,
    **kwargs,
) -> nn.Module:
    """Create a ``segmentation_models_pytorch`` model for the requested architecture.

    Args:
        head: Architecture name; must be one of the keys of ``SMP_MODELS``.
        backbone: Encoder name, forwarded to ``smp.create_model`` as ``encoder_name``.
        pretrained: If truthy, the encoder is requested with ``"imagenet"`` weights;
            otherwise ``encoder_weights`` is ``None``.
        num_classes: Forwarded to ``smp.create_model`` as ``classes``.
        in_channels: Forwarded to ``smp.create_model`` as ``in_channels``.
        **kwargs: Forwarded unchanged to ``smp.create_model``.

    Returns:
        The model object produced by ``smp.create_model``.

    Raises:
        NotImplementedError: If ``head`` is not a key of ``SMP_MODELS``.
    """
    if head not in SMP_MODELS:
        # Render the supported names as a plain list rather than a
        # ``dict_keys(...)`` repr so the message is readable.
        raise NotImplementedError(f"{head} is not implemented! Supported heads -> {list(SMP_MODELS)}")

    encoder_weights = "imagenet" if pretrained else None

    return smp.create_model(
        arch=head,
        encoder_name=backbone,
        encoder_weights=encoder_weights,
        classes=num_classes,
        in_channels=in_channels,
        **kwargs,
    )

# Register one head per architecture name in SMP_MODELS; `partial` binds the
# architecture name into `_load_smp_head` as its `head` argument.
for model_name in SMP_MODELS:
    SEMANTIC_SEGMENTATION_HEADS(
        fn=partial(_load_smp_head, head=model_name),
        name=model_name,
        namespace="image/segmentation",
        package="segmentation_models.pytorch",
    )

if _TORCHVISION_AVAILABLE:

def _get_backbone_meta(backbone):
Expand Down Expand Up @@ -67,29 +99,6 @@ def _get_backbone_meta(backbone):
)
return backbone, out_layer, out_inplanes, aux_layer, aux_inplanes

def _load_fcn_deeplabv3(model_name, backbone, num_classes):
    """Assemble an FCN or DeepLabV3 model around ``backbone``.

    Args:
        model_name: Either ``"fcn"`` or ``"deeplabv3"``; selects the classifier
            head class and the wrapping model class.
        backbone: Backbone passed to ``_get_backbone_meta`` and then wrapped in
            an ``IntermediateLayerGetter``.
        num_classes: Second argument given to the classifier head constructor.

    Returns:
        The selected model class instantiated with the wrapped backbone, the
        classifier head, and ``None`` as its third argument.

    Raises:
        KeyError: If ``model_name`` is not one of the two supported names.
    """
    # Only the main output layer and its channel count are used below; the
    # auxiliary values returned by the helper are discarded.
    backbone, out_layer, out_inplanes, _aux_layer, _aux_inplanes = _get_backbone_meta(backbone)

    # Expose the chosen layer's activations under the key 'out'.
    backbone = IntermediateLayerGetter(backbone, return_layers={out_layer: 'out'})

    # model name -> (classifier head class, full model class)
    model_map = {
        "deeplabv3": (DeepLabHead, DeepLabV3),
        "fcn": (FCNHead, FCN),
    }
    head_cls, model_cls = model_map[model_name]
    classifier = head_cls(out_inplanes, num_classes)

    return model_cls(backbone, classifier, None)

# Register the two torchvision heads; `partial` pre-binds the model name as the
# first positional argument of `_load_fcn_deeplabv3`.
for model_name in ("fcn", "deeplabv3"):
    SEMANTIC_SEGMENTATION_HEADS(
        fn=partial(_load_fcn_deeplabv3, model_name),
        name=model_name,
        namespace="image/segmentation",
        package="torchvision",
    )

def _load_lraspp(backbone, num_classes):
backbone, high_pos, high_channels, low_pos, low_channels = _get_backbone_meta(backbone)
backbone = IntermediateLayerGetter(backbone, return_layers={low_pos: 'low', high_pos: 'high'})
Expand All @@ -101,13 +110,3 @@ def _load_lraspp(backbone, num_classes):
namespace="image/segmentation",
package="torchvision",
)

if _BOLTS_AVAILABLE:

def _load_bolts_unet(_, num_classes: int, **kwargs) -> nn.Module:
    """Build a ``UNet``, ignoring the backbone argument.

    Args:
        _: Placeholder for the backbone; unused.
        num_classes: First positional argument passed to ``UNet``.
        **kwargs: Forwarded unchanged to ``UNet``.

    Returns:
        The constructed ``UNet`` instance.
    """
    # Tell the user that the supplied backbone has no effect for this head.
    rank_zero_warn(
        "The UNet model does not require a backbone, so the backbone will be ignored.",
        UserWarning,
    )
    unet = UNet(num_classes, **kwargs)
    return unet

# Register the UNet loader under the name "unet" with its package/type metadata.
SEMANTIC_SEGMENTATION_HEADS(
    fn=_load_bolts_unet,
    name="unet",
    namespace="image/segmentation",
    package="bolts",
    type="unet",
)
4 changes: 3 additions & 1 deletion flash/image/segmentation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __init__(
else:
self.backbone = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

self.head = self.heads.get(head)(self.backbone, num_classes, **head_kwargs)
self.head: nn.Module = self.heads.get(head)(
backbone=self.backbone, num_classes=num_classes, pretrained=pretrained, **head_kwargs
)

def training_step(self, batch: Any, batch_idx: int) -> Any:
batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
Expand Down
1 change: 1 addition & 0 deletions requirements/datatype_image.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ matplotlib
pycocotools>=2.0.2 ; python_version >= "3.7"
fiftyone
pystiche>=0.7.2
segmentation-models-pytorch