diff --git a/.gitignore b/.gitignore
index 73b96a16dd..26ab5033dc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -153,3 +153,5 @@ wmt_en_ro
action_youtube_naudio
kinetics
movie_posters
+CameraRGB
+CameraSeg
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 92ceb5d022..49de343ea3 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -29,7 +29,7 @@ Lightning Flash
reference/translation
reference/object_detection
reference/video_classification
-
+ reference/semantic_segmentation
.. toctree::
:maxdepth: 1
diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst
new file mode 100644
index 0000000000..0a9e01d8bb
--- /dev/null
+++ b/docs/source/reference/semantic_segmentation.rst
@@ -0,0 +1,151 @@
+
+.. _semantic_segmentation:
+
+#####################
+Semantic Segmentation
+#####################
+
+********
+The task
+********
+Semantic segmentation, or image segmentation, is the task of performing classification at the pixel level, meaning each pixel will be associated with a given class. The model output shape is ``(batch_size, num_classes, height, width)``.
+
+See more: https://paperswithcode.com/task/semantic-segmentation
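+
+As a quick sketch of what that output shape means (plain PyTorch, nothing Flash-specific), the per-pixel class ids can be recovered with an argmax over the class dimension:
+
+.. code-block:: python
+
+ import torch
+
+ # fake model output for a batch of 2 images and 21 classes
+ logits = torch.rand(2, 21, 196, 196)  # (batch_size, num_classes, height, width)
+ labels = logits.argmax(dim=1)  # (batch_size, height, width) tensor of class ids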
+
+------
+
+*********
+Inference
+*********
+
+A :class:`~flash.vision.SemanticSegmentation` model with an ``fcn_resnet50`` backbone, pre-trained on data from the CARLA driving simulator, is provided for the inference example.
+
+
+Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inference on any image file paths using :func:`~flash.vision.SemanticSegmentation.predict`:
+
+.. code-block:: python
+
+ # import our libraries
+ from flash.data.utils import download_data
+ from flash.vision import SemanticSegmentation
+ from flash.vision.segmentation.serialization import SegmentationLabels
+
+ # 1. Download the data
+ download_data(
+ "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
+ "data/"
+ )
+
+ # 2. Load the model from a checkpoint
+ model = SemanticSegmentation.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
+ )
+ model.serializer = SegmentationLabels(visualize=True)
+
+ # 3. Predict what's on a few images and visualize!
+ predictions = model.predict([
+ 'data/CameraRGB/F61-1.png',
+ 'data/CameraRGB/F62-1.png',
+ 'data/CameraRGB/F63-1.png',
+ ])
+
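+Each entry in ``predictions`` is produced by the ``SegmentationLabels`` serializer, so (with the settings shown above) it is a per-pixel tensor of class ids:
+
+.. code-block:: python
+
+ # each prediction is a (height, width) tensor of predicted class ids
+ print(predictions[0].shape)
+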
+For more advanced inference options, see :ref:`predictions`.
+
+------
+
+**********
+Finetuning
+**********
+
+Say you now want to customise your model with new data; in this example we keep using the same dataset.
+Once we download the data using :func:`~flash.data.utils.download_data`, all we need are the train images and train targets folders to create the :class:`~flash.vision.SemanticSegmentationData` (a validation split is carved out via ``val_split``).
+
+.. note:: the dataset is structured in a way that each sample (an image and its corresponding labels) is stored in separate directories but keeps the same filename.
+
+.. code-block::
+
+ data
+ ├── CameraRGB
+ │ ├── F61-1.png
+ │ ├── F61-2.png
+ │ ...
+ └── CameraSeg
+ ├── F61-1.png
+ ├── F61-2.png
+ ...
+
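+Under the hood, the data source pairs images and targets by intersecting the filenames of the two folders (files present in only one folder are dropped with a warning). A minimal sketch of that matching:
+
+.. code-block:: python
+
+ import os
+
+ common = set(os.listdir("data/CameraRGB")) & set(os.listdir("data/CameraSeg"))
+ pairs = [(os.path.join("data/CameraRGB", f), os.path.join("data/CameraSeg", f)) for f in sorted(common)]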
+
+Now all we need are a few lines of code to build and train our task!
+
+.. code-block:: python
+
+ import flash
+ from flash.data.utils import download_data
+ from flash.vision import SemanticSegmentation, SemanticSegmentationData
+ from flash.vision.segmentation.serialization import SegmentationLabels
+
+ # 1. Download the data
+ download_data(
+ "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
+ "data/"
+ )
+
+ # 2.1 Load the data
+ datamodule = SemanticSegmentationData.from_folders(
+ train_folder="data/CameraRGB",
+ train_target_folder="data/CameraSeg",
+ batch_size=4,
+ val_split=0.3,
+ image_size=(200, 200), # (600, 800)
+ )
+
+ # 2.2 Visualise the samples
+ labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
+ datamodule.set_labels_map(labels_map)
+ datamodule.show_train_batch(["load_sample", "post_tensor_transform"])
+
+ # 3. Build the model
+ model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21)
+
+ # 4. Create the trainer.
+ trainer = flash.Trainer(max_epochs=1)
+
+ # 5. Train the model
+ trainer.finetune(model, datamodule=datamodule, strategy='freeze')
+
+ # 6. Save it!
+ trainer.save_checkpoint("semantic_segmentation_model.pt")
+
+------
+
+*************
+API reference
+*************
+
+.. _segmentation:
+
+SemanticSegmentation
+--------------------
+
+.. autoclass:: flash.vision.SemanticSegmentation
+ :members:
+ :exclude-members: forward
+
+.. _segmentation_data:
+
+SemanticSegmentationData
+------------------------
+
+.. autoclass:: flash.vision.SemanticSegmentationData
+
+.. automethod:: flash.vision.SemanticSegmentationData.from_folders
+
+.. autoclass:: flash.vision.SemanticSegmentationPreprocess
diff --git a/flash/core/classification.py b/flash/core/classification.py
index b85a529b3a..0965e684ef 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -55,7 +55,8 @@ def __init__(
def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
- return torch.softmax(x, -1)
+ # we'll assume that the data always comes as `(B, C, ...)`
+ return torch.softmax(x, dim=1)
class ClassificationSerializer(Serializer):
diff --git a/flash/data/batch.py b/flash/data/batch.py
index 739f4704ea..f08be37d02 100644
--- a/flash/data/batch.py
+++ b/flash/data/batch.py
@@ -138,6 +138,12 @@ def __init__(
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)
def forward(self, samples: Sequence[Any]) -> Any:
+ # we create a new dict to prevent potential memory leaks, in case the
+ # dictionary samples are kept around elsewhere and modified before the
+ # transforms are applied.
+ if isinstance(samples, dict):
+ samples = dict(samples.items())
+
with self._current_stage_context:
if self.apply_per_sample_transform:
diff --git a/flash/data/data_module.py b/flash/data/data_module.py
index f64c25284a..e36af6fa9b 100644
--- a/flash/data/data_module.py
+++ b/flash/data/data_module.py
@@ -190,6 +190,8 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
+ if reset:
+ self.data_fetcher.batches[stage] = {}
def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
diff --git a/flash/data/transforms.py b/flash/data/transforms.py
index 0a26224791..67b1229ad4 100644
--- a/flash/data/transforms.py
+++ b/flash/data/transforms.py
@@ -27,15 +27,41 @@ def __init__(self, keys: Union[str, Sequence[str]], *args):
self.keys = keys
def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
- inputs = [x[key] for key in filter(lambda key: key in x, self.keys)]
+ keys = list(filter(lambda key: key in x, self.keys))
+ inputs = [x[key] for key in keys]
if len(inputs) > 0:
- outputs = super().forward(*inputs)
- if not isinstance(outputs, tuple):
+ if len(inputs) == 1:
+ inputs = inputs[0]
+ outputs = super().forward(inputs)
+ if not isinstance(outputs, Sequence):
outputs = (outputs, )
result = {}
result.update(x)
- for i, key in enumerate(self.keys):
+ for i, key in enumerate(keys):
result[key] = outputs[i]
return result
return x
+
+
+class KorniaParallelTransforms(nn.Sequential):
+ """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
+ input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when
+ multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask)."""
+
+ def __init__(self, *args):
+ super().__init__(*[convert_to_modules(arg) for arg in args])
+
+ def forward(self, inputs: Any):
+ result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
+ for transform in self.children():
+ inputs = result
+ for i, input in enumerate(inputs):
+ if hasattr(transform, "_params") and bool(transform._params):
+ params = transform._params
+ result[i] = transform(input, params)
+ else: # case for non random transforms
+ result[i] = transform(input)
+ if hasattr(transform, "_params") and bool(transform._params):
+ transform._params = None
+ return result
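+
+
+# Example (a sketch, not part of the library API): because the sampled random state is
+# shared across inputs, an image and its mask receive the identical augmentation, e.g.
+#
+#   import torch
+#   import kornia as K
+#   transform = KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5))
+#   image, mask = torch.rand(1, 3, 32, 32), torch.rand(1, 1, 32, 32)
+#   image_out, mask_out = transform([image, mask])  # flipped (or not) together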
diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py
index 39dce803d8..346c84870a 100644
--- a/flash/vision/__init__.py
+++ b/flash/vision/__init__.py
@@ -2,3 +2,4 @@
from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
+from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 928605b244..79d0fca863 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -73,7 +73,7 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any:
for key in sample.keys():
if torch.is_tensor(sample[key]):
sample[key] = sample[key].squeeze(0)
- return default_collate(samples)
+ return super().collate(samples)
@property
def default_train_transforms(self) -> Optional[Dict[str, Callable]]:
diff --git a/flash/vision/segmentation/__init__.py b/flash/vision/segmentation/__init__.py
new file mode 100644
index 0000000000..08f9742e47
--- /dev/null
+++ b/flash/vision/segmentation/__init__.py
@@ -0,0 +1,2 @@
+from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess
+from flash.vision.segmentation.model import SemanticSegmentation
diff --git a/flash/vision/segmentation/backbones.py b/flash/vision/segmentation/backbones.py
new file mode 100644
index 0000000000..2a1661be6c
--- /dev/null
+++ b/flash/vision/segmentation/backbones.py
@@ -0,0 +1,36 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch.nn as nn
+
+from flash.core.registry import FlashRegistry
+from flash.utils.imports import _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+ import torchvision
+
+SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")
+
+
+@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
+def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
+ model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
+ model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
+ return model
+
+
+@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
+def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
+ model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
+ model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
+ return model
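+
+
+# Example (sketch): backbones are retrieved from the registry by name and called with the
+# desired number of classes, mirroring how ``SemanticSegmentation`` builds its backbone:
+#
+#   backbone_fn = SEMANTIC_SEGMENTATION_BACKBONES.get("torchvision/fcn_resnet50")
+#   backbone = backbone_fn(num_classes=21, pretrained=False)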
diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
new file mode 100644
index 0000000000..d674205786
--- /dev/null
+++ b/flash/vision/segmentation/data.py
@@ -0,0 +1,285 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torchvision
+from PIL import Image
+from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS
+
+from flash.data.base_viz import BaseVisualization # for viz
+from flash.data.callback import BaseDataFetcher
+from flash.data.data_module import DataModule
+from flash.data.data_source import DefaultDataKeys, DefaultDataSources, PathsDataSource
+from flash.data.process import Preprocess
+from flash.utils.imports import _MATPLOTLIB_AVAILABLE
+from flash.vision.segmentation.serialization import SegmentationLabels
+from flash.vision.segmentation.transforms import default_train_transforms, default_val_transforms
+
+if _MATPLOTLIB_AVAILABLE:
+ import matplotlib.pyplot as plt
+else:
+ plt = None
+
+
+class SemanticSegmentationPathsDataSource(PathsDataSource):
+
+ def __init__(self):
+ super().__init__(IMG_EXTENSIONS)
+
+ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]:
+ input_data, target_data = data
+
+ if self.isdir(input_data) and self.isdir(target_data):
+ input_files = os.listdir(input_data)
+ target_files = os.listdir(target_data)
+
+ all_files = set(input_files).intersection(set(target_files))
+
+ if len(all_files) != len(input_files) or len(all_files) != len(target_files):
+ rank_zero_warn(
+ f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. Some files"
+ " have been dropped.",
+ UserWarning,
+ )
+
+ input_data = [os.path.join(input_data, file) for file in all_files]
+ target_data = [os.path.join(target_data, file) for file in all_files]
+
+ if not isinstance(input_data, list) and not isinstance(target_data, list):
+ input_data = [input_data]
+ target_data = [target_data]
+
+ if len(input_data) != len(target_data):
+ raise MisconfigurationException(
+ f"The number of input files ({len(input_data)}) and number of target files ({len(target_data)}) must be"
+ " the same.",
+ )
+
+ data = filter(
+ lambda sample: (
+ has_file_allowed_extension(sample[0], self.extensions) and
+ has_file_allowed_extension(sample[1], self.extensions)
+ ),
+ zip(input_data, target_data),
+ )
+
+ return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data]
+
+ def predict_load_data(self, data: Union[str, List[str]]):
+ return super().predict_load_data(data)
+
+ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]:
+ # unpack data paths
+ img_path = sample[DefaultDataKeys.INPUT]
+ img_labels_path = sample[DefaultDataKeys.TARGET]
+
+ # load images directly to torch tensors
+ img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW
+ img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW
+ img_labels = img_labels[0] # HxW
+
+ return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()}
+
+ def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+ return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()}
+
+
+class SemanticSegmentationPreprocess(Preprocess):
+
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ image_size: Tuple[int, int] = (196, 196),
+ ) -> None:
+ """Preprocess pipeline for semantic segmentation tasks.
+
+ Args:
+ train_transform: Dictionary with the set of transforms to apply during training.
+ val_transform: Dictionary with the set of transforms to apply during validation.
+ test_transform: Dictionary with the set of transforms to apply during testing.
+ predict_transform: Dictionary with the set of transforms to apply during prediction.
+ image_size: A tuple with the expected output image size.
+ """
+ self.image_size = image_size
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource()},
+ default_data_source=DefaultDataSources.PATHS,
+ )
+
+ def get_state_dict(self) -> Dict[str, Any]:
+ return {
+ **self.transforms,
+ "image_size": self.image_size,
+ }
+
+ @classmethod
+ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+ return cls(**state_dict)
+
+ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any:
+ # todo: Kornia transforms add batch dimension which need to be removed
+ for sample in samples:
+ for key in sample.keys():
+ if torch.is_tensor(sample[key]):
+ sample[key] = sample[key].squeeze(0)
+ return super().collate(samples)
+
+ @property
+ def default_train_transforms(self) -> Optional[Dict[str, Callable]]:
+ return default_train_transforms(self.image_size)
+
+ @property
+ def default_val_transforms(self) -> Optional[Dict[str, Callable]]:
+ return default_val_transforms(self.image_size)
+
+ @property
+ def default_test_transforms(self) -> Optional[Dict[str, Callable]]:
+ return default_val_transforms(self.image_size)
+
+ @property
+ def default_predict_transforms(self) -> Optional[Dict[str, Callable]]:
+ return default_val_transforms(self.image_size)
+
+
+class SemanticSegmentationData(DataModule):
+ """Data module for semantic segmentation tasks."""
+
+ preprocess_cls = SemanticSegmentationPreprocess
+
+ @staticmethod
+ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
+ return SegmentationMatplotlibVisualization(*args, **kwargs)
+
+ def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]):
+ self.data_fetcher.labels_map = labels_map
+
+ def set_block_viz_window(self, value: bool) -> None:
+ """Setter method to switch on/off matplotlib to pop up windows."""
+ self.data_fetcher.block_viz_window = value
+
+ @classmethod
+ def from_folders(
+ cls,
+ train_folder: Optional[str] = None,
+ train_target_folder: Optional[str] = None,
+ val_folder: Optional[str] = None,
+ val_target_folder: Optional[str] = None,
+ test_folder: Optional[str] = None,
+ test_target_folder: Optional[str] = None,
+ predict_folder: Optional[str] = None,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ data_fetcher: BaseDataFetcher = None,
+ preprocess: Optional[Preprocess] = None,
+ val_split: Optional[float] = None,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ **preprocess_kwargs: Any,
+ ) -> 'DataModule':
+ return cls.from_data_source(
+ DefaultDataSources.PATHS,
+ (train_folder, train_target_folder),
+ (val_folder, val_target_folder),
+ (test_folder, test_target_folder),
+ predict_folder,
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_fetcher=data_fetcher,
+ preprocess=preprocess,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **preprocess_kwargs,
+ )
+
+
+class SegmentationMatplotlibVisualization(BaseVisualization):
+ """Process and show the image batch and its associated label using matplotlib.
+ """
+
+ def __init__(self):
+ super().__init__(self)
+ self.max_cols: int = 4 # maximum number of columns we accept
+ self.block_viz_window: bool = True # parameter to allow user to block visualisation windows
+ self.labels_map: Dict[int, Tuple[int, int, int]] = {}
+
+ @staticmethod
+ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
+ out: np.ndarray
+ if isinstance(img, Image.Image):
+ out = np.array(img)
+ elif isinstance(img, torch.Tensor):
+ out = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
+ else:
+ raise TypeError(f"Unknown image type. Got: {type(img)}.")
+ return out
+
+ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str):
+ # define the image grid
+ cols: int = min(num_samples, self.max_cols)
+ rows: int = num_samples // cols
+
+ if not _MATPLOTLIB_AVAILABLE:
+ raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib")
+
+ # create figure and set title
+ fig, axs = plt.subplots(rows, cols)
+ fig.suptitle(title)
+
+ for i, ax in enumerate(axs.ravel()):
+ # unpack images and labels
+ sample = data[i]
+ if isinstance(sample, dict):
+ image = sample[DefaultDataKeys.INPUT]
+ label = sample[DefaultDataKeys.TARGET]
+ elif isinstance(sample, tuple):
+ image = sample[0]
+ label = sample[1]
+ else:
+ raise TypeError(f"Unknown data type. Got: {type(sample)}.")
+ # convert images and labels to numpy and stack horizontally
+ image_vis: np.ndarray = self._to_numpy(image.byte())
+ label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map)
+ label_vis: np.ndarray = self._to_numpy(label_tmp)
+ img_vis = np.hstack((image_vis, label_vis))
+ # send to visualiser
+ ax.imshow(img_vis)
+ ax.axis('off')
+ plt.show(block=self.block_viz_window)
+
+ def show_load_sample(self, samples: List[Any], running_stage: RunningStage):
+ win_title: str = f"{running_stage} - show_load_sample"
+ self._show_images_and_labels(samples, len(samples), win_title)
+
+ def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage):
+ win_title: str = f"{running_stage} - show_post_tensor_transform"
+ self._show_images_and_labels(samples, len(samples), win_title)
diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py
new file mode 100644
index 0000000000..e543b341ed
--- /dev/null
+++ b/flash/vision/segmentation/model.py
@@ -0,0 +1,128 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchmetrics import IoU
+
+from flash.core.classification import ClassificationTask
+from flash.core.registry import FlashRegistry
+from flash.data.data_source import DefaultDataKeys
+from flash.data.process import Serializer
+from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
+from flash.vision.segmentation.serialization import SegmentationLabels
+
+
+class SemanticSegmentation(ClassificationTask):
+ """Task that performs semantic segmentation on images.
+
+ Use a built in backbone
+
+ Example::
+
+ from flash.vision import SemanticSegmentation
+
+ segmentation = SemanticSegmentation(
+ num_classes=21, backbone="torchvision/fcn_resnet50"
+ )
+
+ Args:
+ num_classes: Number of classes to classify.
+ backbone: A string or (model, num_features) tuple to use to compute image features,
+ defaults to ``"torchvision/fcn_resnet50"``.
+ backbone_kwargs: Additional arguments for the backbone configuration.
+ pretrained: Use a pretrained backbone, defaults to ``False``.
+ loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`.
+ optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`.
+ metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`.
+ learning_rate: Learning rate to use for training, defaults to ``1e-3``.
+ multi_label: Whether the targets are multi-label or not.
+ serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
+ """
+
+ backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES
+
+ def __init__(
+ self,
+ num_classes: int,
+ backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
+ backbone_kwargs: Optional[Dict] = None,
+ pretrained: bool = True,
+ loss_fn: Optional[Callable] = None,
+ optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW,
+ metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
+ learning_rate: float = 1e-3,
+ multi_label: bool = False,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
+ ) -> None:
+
+ if metrics is None:
+ metrics = IoU(num_classes=num_classes)
+
+ if loss_fn is None:
+ loss_fn = F.cross_entropy
+
+ # TODO: need to check for multi_label
+ if multi_label:
+ raise NotImplementedError("Multi-label not supported yet.")
+
+ super().__init__(
+ model=None,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ metrics=metrics,
+ learning_rate=learning_rate,
+ serializer=serializer or SegmentationLabels(),
+ )
+
+ self.save_hyperparameters()
+
+ if not backbone_kwargs:
+ backbone_kwargs = {}
+
+ # TODO: pretrained to True causes some issues
+ self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs)
+
+ def training_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
+ return super().training_step(batch, batch_idx)
+
+ def validation_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
+ return super().validation_step(batch, batch_idx)
+
+ def test_step(self, batch: Any, batch_idx: int) -> Any:
+ batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET])
+ return super().test_step(batch, batch_idx)
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ batch = batch[DefaultDataKeys.INPUT]
+ return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx)
+
+ def forward(self, x) -> torch.Tensor:
+ # infer the image to the model
+ res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x)
+
+ # some frameworks like torchvision return a dict.
+ # In particular, torchvision segmentation models return the output logits
+ # in the key `out`.
+ out: torch.Tensor
+ if isinstance(res, dict):
+ out = res['out']
+ else:
+ raise NotImplementedError(f"Unsupported output type: {type(res)}")
+
+ return out
diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py
new file mode 100644
index 0000000000..50ba5be9a9
--- /dev/null
+++ b/flash/vision/segmentation/serialization.py
@@ -0,0 +1,82 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from flash.data.process import Serializer
+from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE
+
+if _MATPLOTLIB_AVAILABLE:
+ import matplotlib.pyplot as plt
+else:
+ plt = None
+
+if _KORNIA_AVAILABLE:
+ import kornia as K
+else:
+ K = None
+
+
+class SegmentationLabels(Serializer):
+
+ def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
+ """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification
+ per pixel in the image for semantic segmentation tasks.
+
+ Args:
+ labels_map: A dictionary that maps the label ids to pixel intensities.
+ visualize: Whether to visualize the image labels.
+ """
+ super().__init__()
+ self.labels_map = labels_map
+ self.visualize = visualize
+
+ @staticmethod
+ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor:
+ """Function that, given an image with label ids and their pixel intensity mapping,
+ creates an RGB representation for visualisation purposes.
+ """
+ assert len(img_labels.shape) == 2, img_labels.shape
+ H, W = img_labels.shape
+ out = torch.empty(3, H, W, dtype=torch.uint8)
+ for label_id, label_val in labels_map.items():
+ mask = (img_labels == label_id)
+ for i in range(3):
+ out[i].masked_fill_(mask, label_val[i])
+ return out
+
+ @staticmethod
+ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]:
+ labels_map: Dict[int, Tuple[int, int, int]] = {}
+ for i in range(num_classes):
+ labels_map[i] = torch.randint(0, 255, (3, ))
+ return labels_map
+
+ def serialize(self, sample: torch.Tensor) -> torch.Tensor:
+ assert len(sample.shape) == 3, sample.shape
+ labels = torch.argmax(sample, dim=-3) # HxW
+ if self.visualize and os.getenv("FLASH_TESTING", "0") == "0":
+ if self.labels_map is None:
+ # create random colors map
+ num_classes = sample.shape[-3]
+ labels_map = self.create_random_labels_map(num_classes)
+ else:
+ labels_map = self.labels_map
+ labels_vis = self.labels_to_image(labels, labels_map)
+ labels_vis = K.utils.tensor_to_image(labels_vis)
+ plt.imshow(labels_vis)
+ plt.show()
+ return labels
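+
+
+# Example (sketch), mirroring the behaviour exercised in the unit tests:
+#
+#   serializer = SegmentationLabels()
+#   logits = torch.zeros(5, 2, 3)  # (num_classes, height, width) for a single sample
+#   logits[1, 1, 2] = 1.0  # peak for class 1 at pixel (1, 2)
+#   labels = serializer.serialize(logits)  # (height, width) tensor; labels[1, 2] == 1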
diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py
new file mode 100644
index 0000000000..1cf491793f
--- /dev/null
+++ b/flash/vision/segmentation/transforms.py
@@ -0,0 +1,58 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Dict, Tuple
+
+import kornia as K
+import torch
+import torch.nn as nn
+
+from flash.data.data_source import DefaultDataKeys
+from flash.data.transforms import ApplyToKeys, KorniaParallelTransforms
+
+
+def prepare_target(tensor: torch.Tensor) -> torch.Tensor:
+ return tensor.long().squeeze()
+
+
+def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
+ return {
+ "post_tensor_transform": nn.Sequential(
+ ApplyToKeys(
+ [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
+ KorniaParallelTransforms(
+ K.geometry.Resize(image_size, interpolation='nearest'),
+ K.augmentation.RandomHorizontalFlip(p=0.75),
+ ),
+ ),
+ ApplyToKeys(DefaultDataKeys.TARGET, prepare_target),
+ ),
+ "per_batch_transform_on_device": ApplyToKeys(
+ DefaultDataKeys.INPUT,
+ K.enhance.Normalize(0., 255.),
+ K.augmentation.ColorJitter(0.4, p=0.5),
+ ),
+ }
+
+
+def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
+ return {
+ "post_tensor_transform": nn.Sequential(
+ ApplyToKeys(
+ [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
+ KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')),
+ ),
+ ApplyToKeys(DefaultDataKeys.TARGET, prepare_target),
+ ),
+ "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)),
+ }
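+
+
+# Note (sketch): these dictionaries are what ``SemanticSegmentationPreprocess`` uses by default;
+# custom versions can be passed in the same format, e.g.
+# ``SemanticSegmentationData.from_folders(..., train_transform=default_train_transforms((256, 256)))``.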
diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py
new file mode 100644
index 0000000000..3676353ec8
--- /dev/null
+++ b/flash_examples/finetuning/semantic_segmentation.py
@@ -0,0 +1,66 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.data.utils import download_data
+from flash.vision import SemanticSegmentation, SemanticSegmentationData
+from flash.vision.segmentation.serialization import SegmentationLabels
+
+# 1. Download the data
+# This is a Dataset with Semantic Segmentation Labels generated via the CARLA self-driving simulator.
+# The data was generated as part of the Lyft Udacity Challenge.
+# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
+download_data(
+ "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/"
+)
+
+# 2.1 Load the data
+datamodule = SemanticSegmentationData.from_folders(
+ train_folder="data/CameraRGB",
+ train_target_folder="data/CameraSeg",
+ batch_size=4,
+ val_split=0.3,
+ image_size=(200, 200), # (600, 800)
+)
+
+# 2.2 Visualise the samples
+labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
+datamodule.set_labels_map(labels_map)
+datamodule.show_train_batch(["load_sample", "post_tensor_transform"])
+
+# 3. Build the model
+model = SemanticSegmentation(
+ backbone="torchvision/fcn_resnet50",
+ num_classes=21,
+)
+
+# 4. Create the trainer.
+trainer = flash.Trainer(
+ max_epochs=1,
+ fast_dev_run=1,
+)
+
+# 5. Train the model
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 6. Predict what's on a few images!
+model.serializer = SegmentationLabels(labels_map, visualize=True)
+
+predictions = model.predict([
+ "data/CameraRGB/F61-1.png",
+ "data/CameraRGB/F62-1.png",
+ "data/CameraRGB/F63-1.png",
+])
+
+# 7. Save it!
+trainer.save_checkpoint("semantic_segmentation_model.pt")
diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py
new file mode 100644
index 0000000000..f507f2a6a6
--- /dev/null
+++ b/flash_examples/predict/semantic_segmentation.py
@@ -0,0 +1,37 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.data.utils import download_data
+from flash.vision import SemanticSegmentation
+from flash.vision.segmentation.serialization import SegmentationLabels
+
+# 1. Download the data
+# This is a Dataset with Semantic Segmentation Labels generated via the CARLA self-driving simulator.
+# The data was generated as part of the Lyft Udacity Challenge.
+# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
+download_data(
+ "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/"
+)
+
+# 2. Load the model from a checkpoint
+model = SemanticSegmentation.load_from_checkpoint(
+ "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
+)
+model.serializer = SegmentationLabels(visualize=True)
+
+# 3. Predict what's on a few images and visualize!
+predictions = model.predict([
+ "data/CameraRGB/F61-1.png",
+ "data/CameraRGB/F62-1.png",
+ "data/CameraRGB/F63-1.png",
+])
diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py
index f4748a5149..b53db09b6e 100644
--- a/tests/data/test_callbacks.py
+++ b/tests/data/test_callbacks.py
@@ -157,8 +157,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization:
for _ in range(num_tests):
for fcn_name in _CALLBACK_FUNCS:
+ dm.data_fetcher.reset()
fcn = getattr(dm, f"show_{stage}_batch")
- fcn(fcn_name, reset=True)
+ fcn(fcn_name, reset=False)
is_predict = stage == "predict"
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 2fc4ee18f3..a60ecbf021 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -58,6 +58,7 @@ def run_test(filepath):
("finetuning", "image_classification.py"),
("finetuning", "image_classification_multi_label.py"),
# ("finetuning", "object_detection.py"), # TODO: takes too long.
+ ("finetuning", "semantic_segmentation.py"),
# ("finetuning", "summarization.py"), # TODO: takes too long.
("finetuning", "tabular_classification.py"),
# ("finetuning", "video_classification.py"),
@@ -65,6 +66,7 @@ def run_test(filepath):
("finetuning", "translation.py"),
("predict", "image_classification.py"),
("predict", "image_classification_multi_label.py"),
+ ("predict", "semantic_segmentation.py"),
("predict", "tabular_classification.py"),
# ("predict", "text_classification.py"),
("predict", "image_embedder.py"),
diff --git a/tests/vision/segmentation/__init__.py b/tests/vision/segmentation/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
new file mode 100644
index 0000000000..bd51f09d21
--- /dev/null
+++ b/tests/vision/segmentation/test_data.py
@@ -0,0 +1,303 @@
+import os
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash import Trainer
+from flash.data.data_source import DefaultDataKeys
+from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
+
+
+def build_checkboard(n, m, k=8):
+ x = np.zeros((n, m))
+ x[k::k * 2, ::k] = 1
+ x[::k * 2, k::k * 2] = 1
+ return x
+
+
+def _rand_image(size: Tuple[int, int]):
+ data = build_checkboard(*size).astype(np.uint8)[..., None].repeat(3, -1)
+ return Image.fromarray(data)
+
+
+# usually labels come as rgb images -> need to map to labels
+def _rand_labels(size: Tuple[int, int], num_classes: int):
+ data: np.ndarray = np.random.randint(0, num_classes, (*size, 1))
+ data = data.repeat(3, axis=-1)
+ return Image.fromarray(data.astype(np.uint8))
+
+
+def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int):
+ for img_file in image_files:
+ _rand_image(size).save(img_file)
+
+ for label_file in label_files:
+ _rand_labels(size, num_classes).save(label_file)
+
+
+class TestSemanticSegmentationPreprocess:
+
+ @pytest.mark.xfail(reason="parameters are marked as optional but it raises a Misconfiguration error.")
+ def test_smoke(self):
+ prep = SemanticSegmentationPreprocess()
+ assert prep is not None
+
+
+class TestSemanticSegmentationData:
+
+ def test_smoke(self):
+ dm = SemanticSegmentationData()
+ assert dm is not None
+
+ def test_from_folders(self, tmpdir):
+ tmp_dir = Path(tmpdir)
+
+ # create random dummy data
+
+ os.makedirs(str(tmp_dir / "images"))
+ os.makedirs(str(tmp_dir / "targets"))
+
+ images = [
+ str(tmp_dir / "images" / "img1.png"),
+ str(tmp_dir / "images" / "img2.png"),
+ str(tmp_dir / "images" / "img3.png"),
+ ]
+
+ targets = [
+ str(tmp_dir / "targets" / "img1.png"),
+ str(tmp_dir / "targets" / "img2.png"),
+ str(tmp_dir / "targets" / "img3.png"),
+ ]
+
+ num_classes: int = 2
+ img_size: Tuple[int, int] = (196, 196)
+ create_random_data(images, targets, img_size, num_classes)
+
+ # instantiate the data module
+
+ dm = SemanticSegmentationData.from_folders(
+ train_folder=str(tmp_dir / "images"),
+ train_target_folder=str(tmp_dir / "targets"),
+ val_folder=str(tmp_dir / "images"),
+ val_target_folder=str(tmp_dir / "targets"),
+ test_folder=str(tmp_dir / "images"),
+ test_target_folder=str(tmp_dir / "targets"),
+ batch_size=2,
+ num_workers=0,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+ assert dm.val_dataloader() is not None
+ assert dm.test_dataloader() is not None
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ # check val data
+ data = next(iter(dm.val_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ # check test data
+ data = next(iter(dm.test_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ def test_from_folders_warning(self, tmpdir):
+ tmp_dir = Path(tmpdir)
+
+ # create random dummy data
+
+ os.makedirs(str(tmp_dir / "images"))
+ os.makedirs(str(tmp_dir / "targets"))
+
+ images = [
+ str(tmp_dir / "images" / "img1.png"),
+ str(tmp_dir / "images" / "img3.png"),
+ ]
+
+ targets = [
+ str(tmp_dir / "targets" / "img1.png"),
+ str(tmp_dir / "targets" / "img2.png"),
+ ]
+
+ num_classes: int = 2
+ img_size: Tuple[int, int] = (196, 196)
+ create_random_data(images, targets, img_size, num_classes)
+
+ # instantiate the data module
+
+ with pytest.warns(UserWarning, match="Found inconsistent files"):
+ dm = SemanticSegmentationData.from_folders(
+ train_folder=str(tmp_dir / "images"),
+ train_target_folder=str(tmp_dir / "targets"),
+ batch_size=1,
+ num_workers=0,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (1, 3, 196, 196)
+ assert labels.shape == (1, 196, 196)
+
+ def test_from_files(self, tmpdir):
+ tmp_dir = Path(tmpdir)
+
+ # create random dummy data
+
+ images = [
+ str(tmp_dir / "img1.png"),
+ str(tmp_dir / "img2.png"),
+ str(tmp_dir / "img3.png"),
+ ]
+
+ targets = [
+ str(tmp_dir / "labels_img1.png"),
+ str(tmp_dir / "labels_img2.png"),
+ str(tmp_dir / "labels_img3.png"),
+ ]
+
+ num_classes: int = 2
+ img_size: Tuple[int, int] = (196, 196)
+ create_random_data(images, targets, img_size, num_classes)
+
+ # instantiate the data module
+
+ dm = SemanticSegmentationData.from_files(
+ train_files=images,
+ train_targets=targets,
+ val_files=images,
+ val_targets=targets,
+ test_files=images,
+ test_targets=targets,
+ batch_size=2,
+ num_workers=0,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+ assert dm.val_dataloader() is not None
+ assert dm.test_dataloader() is not None
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ # check val data
+ data = next(iter(dm.val_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ # check test data
+ data = next(iter(dm.test_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+
+ def test_from_files_warning(self, tmpdir):
+ tmp_dir = Path(tmpdir)
+
+ # create random dummy data
+
+ images = [
+ str(tmp_dir / "img1.png"),
+ str(tmp_dir / "img2.png"),
+ str(tmp_dir / "img3.png"),
+ ]
+
+ targets = [
+ str(tmp_dir / "labels_img1.png"),
+ str(tmp_dir / "labels_img2.png"),
+ str(tmp_dir / "labels_img3.png"),
+ ]
+
+ num_classes: int = 2
+ img_size: Tuple[int, int] = (196, 196)
+ create_random_data(images, targets, img_size, num_classes)
+
+ # instantiate the data module
+
+ with pytest.raises(MisconfigurationException, match="The number of input files"):
+ SemanticSegmentationData.from_files(
+ train_files=images,
+ train_targets=targets + [str(tmp_dir / "labels_img4.png")],
+ batch_size=2,
+ num_workers=0,
+ )
+
+ def test_map_labels(self, tmpdir):
+ tmp_dir = Path(tmpdir)
+
+ # create random dummy data
+
+ images = [
+ str(tmp_dir / "img1.png"),
+ str(tmp_dir / "img2.png"),
+ str(tmp_dir / "img3.png"),
+ ]
+
+ targets = [
+ str(tmp_dir / "labels_img1.png"),
+ str(tmp_dir / "labels_img2.png"),
+ str(tmp_dir / "labels_img3.png"),
+ ]
+
+ labels_map: Dict[int, Tuple[int, int, int]] = {
+ 0: [0, 0, 0],
+ 1: [255, 255, 255],
+ }
+
+ num_classes: int = len(labels_map.keys())
+ img_size: Tuple[int, int] = (196, 196)
+ create_random_data(images, targets, img_size, num_classes)
+
+ # instantiate the data module
+
+ dm = SemanticSegmentationData.from_files(
+ train_files=images,
+ train_targets=targets,
+ val_files=images,
+ val_targets=targets,
+ batch_size=2,
+ num_workers=0,
+ )
+ assert dm is not None
+ assert dm.train_dataloader() is not None
+
+ # disable visualisation for testing
+ # disable the blocking visualisation windows for testing
+ dm.set_block_viz_window(False)
+ assert dm.data_fetcher.block_viz_window is False
+
+ dm.set_labels_map(labels_map)
+ dm.show_train_batch("load_sample")
+ dm.show_train_batch("to_tensor_transform")
+
+ # check training data
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 196, 196)
+ assert labels.min().item() == 0
+ assert labels.max().item() == 1
+ assert labels.dtype == torch.int64
+
+ # now train with `fast_dev_run`
+ model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50")
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.finetune(model, dm, strategy="freeze_unfreeze")
diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py
new file mode 100644
index 0000000000..0ebaa5d956
--- /dev/null
+++ b/tests/vision/segmentation/test_model.py
@@ -0,0 +1,80 @@
+from typing import Tuple
+
+import pytest
+import torch
+
+from flash import Trainer
+from flash.data.data_source import DefaultDataKeys
+from flash.vision import SemanticSegmentation
+
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+ size: Tuple[int, int] = (224, 224)
+ num_classes: int = 8
+
+ def __getitem__(self, index):
+ return {
+ DefaultDataKeys.INPUT: torch.rand(3, *self.size),
+ DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, self.size),
+ }
+
+ def __len__(self) -> int:
+ return 10
+
+
+# ==============================
+
+
+def test_smoke():
+ model = SemanticSegmentation(num_classes=1)
+ assert model is not None
+
+
+@pytest.mark.parametrize("num_classes", [8, 256])
+@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)])
+def test_forward(num_classes, img_shape):
+ model = SemanticSegmentation(
+ num_classes=num_classes,
+ backbone='torchvision/fcn_resnet50',
+ )
+
+ B, C, H, W = img_shape
+ img = torch.rand(B, C, H, W)
+
+ out = model(img)
+ assert out.shape == (B, num_classes, H, W)
+
+
+@pytest.mark.parametrize(
+ "backbone",
+ [
+ "torchvision/fcn_resnet50",
+ "torchvision/fcn_resnet101",
+ ],
+)
+def test_init_train(tmpdir, backbone):
+ model = SemanticSegmentation(num_classes=10, backbone=backbone)
+ train_dl = torch.utils.data.DataLoader(DummyDataset())
+ trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+ trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+
+
+def test_non_existent_backbone():
+ with pytest.raises(KeyError):
+ SemanticSegmentation(2, "i am never going to implement this lol")
+
+
+def test_freeze():
+ model = SemanticSegmentation(2)
+ model.freeze()
+ for p in model.backbone.parameters():
+ assert p.requires_grad is False
+
+
+def test_unfreeze():
+ model = SemanticSegmentation(2)
+ model.unfreeze()
+ for p in model.backbone.parameters():
+ assert p.requires_grad is True
diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py
new file mode 100644
index 0000000000..a971c91fbf
--- /dev/null
+++ b/tests/vision/segmentation/test_serialization.py
@@ -0,0 +1,43 @@
+import pytest
+import torch
+
+from flash.vision.segmentation.serialization import SegmentationLabels
+
+
+class TestSemanticSegmentationLabels:
+
+ def test_smoke(self):
+ serial = SegmentationLabels()
+ assert serial is not None
+ assert serial.labels_map is None
+ assert serial.visualize is False
+
+ def test_exception(self):
+ serial = SegmentationLabels()
+
+ with pytest.raises(Exception):
+ sample = torch.zeros(1, 5, 2, 3)
+ serial.serialize(sample)
+
+ with pytest.raises(Exception):
+ sample = torch.zeros(2, 3)
+ serial.serialize(sample)
+
+ def test_serialize(self):
+ serial = SegmentationLabels()
+
+ sample = torch.zeros(5, 2, 3)
+ sample[1, 1, 2] = 1 # add peak for class 1
+ sample[3, 0, 1] = 1 # add peak for class 3
+
+ classes = serial.serialize(sample)
+ assert classes[1, 2] == 1
+ assert classes[0, 1] == 3
+
+ # TODO: implement me
+ def test_create_random_labels(self):
+ pass
+
+ # TODO: implement me
+ def test_labels_to_image(self):
+ pass