diff --git a/.gitignore b/.gitignore index 73b96a16dd..26ab5033dc 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,5 @@ wmt_en_ro action_youtube_naudio kinetics movie_posters +CameraRGB +CameraSeg diff --git a/docs/source/index.rst b/docs/source/index.rst index 92ceb5d022..49de343ea3 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -29,7 +29,7 @@ Lightning Flash reference/translation reference/object_detection reference/video_classification - + reference/semantic_segmentation .. toctree:: :maxdepth: 1 diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst new file mode 100644 index 0000000000..0a9e01d8bb --- /dev/null +++ b/docs/source/reference/semantic_segmentation.rst @@ -0,0 +1,151 @@ + +.. _semantic_segmentation: + +###################### +Semantic Segmentation +###################### + +******** +The task +******** +Semantic segmentation, or image segmentation, is the task of performing classification at a pixel level, meaning each pixel will be associated with a given class. The model output shape is ``(batch_size, num_classes, height, width)``. + +See more: https://paperswithcode.com/task/semantic-segmentation
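To make that output format concrete, here is a minimal illustrative sketch (the shapes are arbitrary) of how per-pixel class predictions can be recovered from such an output; the ``SegmentationLabels`` serializer added in this PR applies the same ``argmax`` over the class dimension:

.. code-block:: python

    import torch

    # hypothetical logits for a batch of 2 images, 21 classes and 196x196 pixels
    logits = torch.rand(2, 21, 196, 196)  # (batch_size, num_classes, height, width)

    # per-pixel class ids come from an argmax over the class dimension
    labels = logits.argmax(dim=1)  # (batch_size, height, width)
    assert labels.shape == (2, 196, 196)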


+ +------ + +********* +Inference +********* + +A :class:`~flash.vision.SemanticSegmentation` `fcn_resnet50` pre-trained on the `CARLA `_ simulator is provided for the inference example. + + +Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inference on any image file paths using :func:`~flash.vision.SemanticSegmentation.predict`: + +.. code-block:: python + +    # import our libraries +    from flash.data.utils import download_data +    from flash.vision import SemanticSegmentation +    from flash.vision.segmentation.serialization import SegmentationLabels + +    # 1. Download the data +    download_data( +        "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", +        "data/" +    ) + +    # 2. Load the model from a checkpoint +    model = SemanticSegmentation.load_from_checkpoint( +        "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" +    ) +    model.serializer = SegmentationLabels(visualize=True) + +    # 3. Predict what's on a few images and visualize! +    predictions = model.predict([ +        'data/CameraRGB/F61-1.png', +        'data/CameraRGB/F62-1.png', +        'data/CameraRGB/F63-1.png', +    ]) + +For more advanced inference options, see :ref:`predictions`. + +------ + +********** +Finetuning +********** + +You now want to customise your model with new data from the same dataset. +Once we download the data using :func:`~flash.data.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.SemanticSegmentationData`. + +.. note:: The dataset is structured so that each sample (an image and its corresponding labels) is stored in separate directories but keeps the same filename. + +.. code-block:: + +    data +    ├── CameraRGB +    │   ├── F61-1.png +    │   ├── F61-2.png +    │   ... +    └── CameraSeg +        ├── F61-1.png +        ├── F61-2.png +        ... + + +Now all we need is a few lines of code to build and train our task! + +.. code-block:: python + +    import flash +    from flash.data.utils import download_data +    from flash.vision import SemanticSegmentation, SemanticSegmentationData +    from flash.vision.segmentation.serialization import SegmentationLabels + +    # 1. Download the data +    download_data( +        "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", +        "data/" +    ) + +    # 2.1 Load the data +    datamodule = SemanticSegmentationData.from_folders( +        train_folder="data/CameraRGB", +        train_target_folder="data/CameraSeg", +        batch_size=4, +        val_split=0.3, +        image_size=(200, 200),  # (600, 800) +    ) + +    # 2.2 Visualise the samples +    labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) +    datamodule.set_labels_map(labels_map) +    datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) + +    # 3. Build the model +    model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21) + +    # 4. Create the trainer. +    trainer = flash.Trainer(max_epochs=1) + +    # 5. Train the model +    trainer.finetune(model, datamodule=datamodule, strategy='freeze') + +    # 6. Save it! +    trainer.save_checkpoint("semantic_segmentation_model.pt") + +------ + +************* +API reference +************* + +.. _segmentation: + +SemanticSegmentation +-------------------- + +.. autoclass:: flash.vision.SemanticSegmentation +    :members: +    :exclude-members: forward + +.. _segmentation_data: + +SemanticSegmentationData +------------------------ + +.. autoclass:: flash.vision.SemanticSegmentationData + +.. automethod:: flash.vision.SemanticSegmentationData.from_folders + +.. 
autoclass:: flash.vision.SemanticSegmentationPreprocess diff --git a/flash/core/classification.py b/flash/core/classification.py index b85a529b3a..0965e684ef 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -55,7 +55,8 @@ def __init__( def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: if getattr(self.hparams, "multi_label", False): return torch.sigmoid(x) - return torch.softmax(x, -1) + # we'll assume that the data always comes as `(B, C, ...)` + return torch.softmax(x, dim=1) class ClassificationSerializer(Serializer): diff --git a/flash/data/batch.py b/flash/data/batch.py index 739f4704ea..f08be37d02 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -138,6 +138,12 @@ def __init__( self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) def forward(self, samples: Sequence[Any]) -> Any: + # we create a new dict to prevent from potential memory leaks + # assuming that the dictionary samples are stored in between and + # potentially modified before the transforms are applied. + if isinstance(samples, dict): + samples = dict(samples.items()) + with self._current_stage_context: if self.apply_per_sample_transform: diff --git a/flash/data/data_module.py b/flash/data/data_module.py index f64c25284a..e36af6fa9b 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -190,6 +190,8 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool _ = next(iter_dataloader) data_fetcher: BaseVisualization = self.data_fetcher data_fetcher._show(stage, func_names) + if reset: + self.data_fetcher.batches[stage] = {} def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" diff --git a/flash/data/transforms.py b/flash/data/transforms.py index 0a26224791..67b1229ad4 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -27,15 +27,41 @@ def __init__(self, keys: Union[str, Sequence[str]], *args): self.keys = keys def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: - inputs = [x[key] for key in filter(lambda key: key in x, self.keys)] + keys = list(filter(lambda key: key in x, self.keys)) + inputs = [x[key] for key in keys] if len(inputs) > 0: - outputs = super().forward(*inputs) - if not isinstance(outputs, tuple): + if len(inputs) == 1: + inputs = inputs[0] + outputs = super().forward(inputs) + if not isinstance(outputs, Sequence): outputs = (outputs, ) result = {} result.update(x) - for i, key in enumerate(self.keys): + for i, key in enumerate(keys): result[key] = outputs[i] return result return x + + +class KorniaParallelTransforms(nn.Sequential): + """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each + input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when + multiple elements need to be augmented in the same way (e.g. 
an image and corresponding segmentation mask).""" + + def __init__(self, *args): + super().__init__(*[convert_to_modules(arg) for arg in args]) + + def forward(self, inputs: Any): + result = list(inputs) if isinstance(inputs, Sequence) else [inputs] + for transform in self.children(): + inputs = result + for i, input in enumerate(inputs): + if hasattr(transform, "_params") and bool(transform._params): + params = transform._params + result[i] = transform(input, params) + else: # case for non random transforms + result[i] = transform(input) + if hasattr(transform, "_params") and bool(transform._params): + transform._params = None + return result diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 39dce803d8..346c84870a 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -2,3 +2,4 @@ from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier from flash.vision.detection import ObjectDetectionData, ObjectDetector from flash.vision.embedding import ImageEmbedder +from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 928605b244..79d0fca863 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -73,7 +73,7 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: for key in sample.keys(): if torch.is_tensor(sample[key]): sample[key] = sample[key].squeeze(0) - return default_collate(samples) + return super().collate(samples) @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: diff --git a/flash/vision/segmentation/__init__.py b/flash/vision/segmentation/__init__.py new file mode 100644 index 0000000000..08f9742e47 --- /dev/null +++ b/flash/vision/segmentation/__init__.py @@ -0,0 +1,2 @@ +from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess +from flash.vision.segmentation.model import SemanticSegmentation diff --git a/flash/vision/segmentation/backbones.py b/flash/vision/segmentation/backbones.py new file mode 100644 index 0000000000..2a1661be6c --- /dev/null +++ b/flash/vision/segmentation/backbones.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.utils.imports import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + import torchvision + +SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") + + +@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50") +def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + return model + + +@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101") +def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + return model diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py new file mode 100644 index 0000000000..d674205786 --- /dev/null +++ b/flash/vision/segmentation/data.py @@ -0,0 +1,285 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torchvision +from PIL import Image +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS + +from flash.data.base_viz import BaseVisualization # for viz +from flash.data.callback import BaseDataFetcher +from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources, PathsDataSource +from flash.data.process import Preprocess +from flash.utils.imports import _MATPLOTLIB_AVAILABLE +from flash.vision.segmentation.serialization import SegmentationLabels +from flash.vision.segmentation.transforms import default_train_transforms, default_val_transforms + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + plt = None + + +class SemanticSegmentationPathsDataSource(PathsDataSource): + + def __init__(self): + super().__init__(IMG_EXTENSIONS) + + def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]: + input_data, target_data = data + + if self.isdir(input_data) and self.isdir(target_data): + input_files = os.listdir(input_data) + target_files = os.listdir(target_data) + + all_files = set(input_files).intersection(set(target_files)) + + if len(all_files) != len(input_files) or len(all_files) != len(target_files): + rank_zero_warn( + f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. 
Some files" + " have been dropped.", + UserWarning, + ) + + input_data = [os.path.join(input_data, file) for file in all_files] + target_data = [os.path.join(target_data, file) for file in all_files] + + if not isinstance(input_data, list) and not isinstance(target_data, list): + input_data = [input_data] + target_data = [target_data] + + if len(input_data) != len(target_data): + raise MisconfigurationException( + f"The number of input files ({len(input_data)}) and number of target files ({len(target_data)}) must be" + " the same.", + ) + + data = filter( + lambda sample: ( + has_file_allowed_extension(sample[0], self.extensions) and + has_file_allowed_extension(sample[1], self.extensions) + ), + zip(input_data, target_data), + ) + + return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] + + def predict_load_data(self, data: Union[str, List[str]]): + return super().predict_load_data(data) + + def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: + # unpack data paths + img_path = sample[DefaultDataKeys.INPUT] + img_labels_path = sample[DefaultDataKeys.TARGET] + + # load images directly to torch tensors + img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + img_labels = img_labels[0] # HxW + + return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()} + + def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()} + + +class SemanticSegmentationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (196, 196), + ) -> None: + """Preprocess pipeline for semantic segmentation tasks. + + Args: + train_transform: Dictionary with the set of transforms to apply during training. + val_transform: Dictionary with the set of transforms to apply during validation. + test_transform: Dictionary with the set of transforms to apply during testing. + predict_transform: Dictionary with the set of transforms to apply during prediction. + image_size: A tuple with the expected output image size. 
+ """ + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource()}, + default_data_source=DefaultDataSources.PATHS, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "image_size": self.image_size, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: + # todo: Kornia transforms add batch dimension which need to be removed + for sample in samples: + for key in sample.keys(): + if torch.is_tensor(sample[key]): + sample[key] = sample[key].squeeze(0) + return super().collate(samples) + + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_train_transforms(self.image_size) + + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + +class SemanticSegmentationData(DataModule): + """Data module for semantic segmentation tasks.""" + + preprocess_cls = SemanticSegmentationPreprocess + + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + return SegmentationMatplotlibVisualization(*args, **kwargs) + + def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): + self.data_fetcher.labels_map = labels_map + + def set_block_viz_window(self, value: bool) -> None: + """Setter method to switch on/off matplotlib to pop up windows.""" + self.data_fetcher.block_viz_window = value + + @classmethod + def from_folders( + cls, + train_folder: Optional[str] = None, + train_target_folder: Optional[str] = None, + val_folder: Optional[str] = None, + val_target_folder: Optional[str] = None, + test_folder: Optional[str] = None, + test_target_folder: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + return cls.from_data_source( + DefaultDataSources.PATHS, + (train_folder, train_target_folder), + (val_folder, val_target_folder), + (test_folder, test_target_folder), + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + + +class SegmentationMatplotlibVisualization(BaseVisualization): + """Process and show the image batch and its associated label using matplotlib. 
+ """ + + def __init__(self): + super().__init__(self) + self.max_cols: int = 4 # maximum number of columns we accept + self.block_viz_window: bool = True # parameter to allow user to block visualisation windows + self.labels_map: Dict[int, Tuple[int, int, int]] = {} + + @staticmethod + def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: + out: np.ndarray + if isinstance(img, Image.Image): + out = np.array(img) + elif isinstance(img, torch.Tensor): + out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() + else: + raise TypeError(f"Unknown image type. Got: {type(img)}.") + return out + + def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): + # define the image grid + cols: int = min(num_samples, self.max_cols) + rows: int = num_samples // cols + + if not _MATPLOTLIB_AVAILABLE: + raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib") + + # create figure and set title + fig, axs = plt.subplots(rows, cols) + fig.suptitle(title) + + for i, ax in enumerate(axs.ravel()): + # unpack images and labels + sample = data[i] + if isinstance(sample, dict): + image = sample[DefaultDataKeys.INPUT] + label = sample[DefaultDataKeys.TARGET] + elif isinstance(sample, tuple): + image = sample[0] + label = sample[1] + else: + raise TypeError(f"Unknown data type. Got: {type(data)}.") + # convert images and labels to numpy and stack horizontally + image_vis: np.ndarray = self._to_numpy(image.byte()) + label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map) + label_vis: np.ndarray = self._to_numpy(label_tmp) + img_vis = np.hstack((image_vis, label_vis)) + # send to visualiser + ax.imshow(img_vis) + ax.axis('off') + plt.show(block=self.block_viz_window) + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_load_sample" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_post_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py new file mode 100644 index 0000000000..e543b341ed --- /dev/null +++ b/flash/vision/segmentation/model.py @@ -0,0 +1,128 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torchmetrics import IoU + +from flash.core.classification import ClassificationTask +from flash.core.registry import FlashRegistry +from flash.data.data_source import DefaultDataKeys +from flash.data.process import Serializer +from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES +from flash.vision.segmentation.serialization import SegmentationLabels + + +class SemanticSegmentation(ClassificationTask): + """Task that performs semantic segmentation on images. + + Use a built in backbone + + Example:: + + from flash.vision import SemanticSegmentation + + segmentation = SemanticSegmentation( + num_classes=21, backbone="torchvision/fcn_resnet50" + ) + + Args: + num_classes: Number of classes to classify. + backbone: A string or (model, num_features) tuple to use to compute image features, + defaults to ``"torchvision/fcn_resnet50"``. + backbone_kwargs: Additional arguments for the backbone configuration. + pretrained: Use a pretrained backbone, defaults to ``False``. + loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. + optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`. + metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`. + learning_rate: Learning rate to use for training, defaults to ``1e-3``. + multi_label: Whether the targets are multi-label or not. + serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs. + """ + + backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50", + backbone_kwargs: Optional[Dict] = None, + pretrained: bool = True, + loss_fn: Optional[Callable] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, + metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None, + learning_rate: float = 1e-3, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + ) -> None: + + if metrics is None: + metrics = IoU(num_classes=num_classes) + + if loss_fn is None: + loss_fn = F.cross_entropy + + # TODO: need to check for multi_label + if multi_label: + raise NotImplementedError("Multi-label not supported yet.") + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer or SegmentationLabels(), + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {} + + # TODO: pretrained to True causes some issues + self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> 
Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + + def forward(self, x) -> torch.Tensor: + # infer the image to the model + res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x) + + # some frameworks like torchvision return a dict. + # In particular, torchvision segmentation models return the output logits + # in the key `out`. + out: torch.Tensor + if isinstance(res, dict): + out = res['out'] + else: + raise NotImplementedError(f"Unsupported output type: {type(out)}") + + return out diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py new file mode 100644 index 0000000000..50ba5be9a9 --- /dev/null +++ b/flash/vision/segmentation/serialization.py @@ -0,0 +1,82 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Dict, Optional, Tuple + +import torch + +from flash.data.process import Serializer +from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + plt = None + +if _KORNIA_AVAILABLE: + import kornia as K +else: + K = None + + +class SegmentationLabels(Serializer): + + def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): + """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification + per pixel in the image for semantic segmentation tasks. + + Args: + labels_map: A dictionary that map the labels ids to pixel intensities. + visualise: Wether to visualise the image labels. + """ + super().__init__() + self.labels_map = labels_map + self.visualize = visualize + + @staticmethod + def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: + """Function that given an image with labels ids and their pixels intrensity mapping, + creates a RGB representation for visualisation purposes. 
+ """ + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in labels_map.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out + + @staticmethod + def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + + def serialize(self, sample: torch.Tensor) -> torch.Tensor: + assert len(sample.shape) == 3, sample.shape + labels = torch.argmax(sample, dim=-3) # HxW + if self.visualize and os.getenv("FLASH_TESTING", "0") == "0": + if self.labels_map is None: + # create random colors map + num_classes = sample.shape[-3] + labels_map = self.create_random_labels_map(num_classes) + else: + labels_map = self.labels_map + labels_vis = self.labels_to_image(labels, labels_map) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + return labels diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py new file mode 100644 index 0000000000..1cf491793f --- /dev/null +++ b/flash/vision/segmentation/transforms.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Callable, Dict, Tuple + +import kornia as K +import torch +import torch.nn as nn + +from flash.data.data_source import DefaultDataKeys +from flash.data.transforms import ApplyToKeys, KorniaParallelTransforms + + +def prepare_target(tensor: torch.Tensor) -> torch.Tensor: + return tensor.long().squeeze() + + +def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + return { + "post_tensor_transform": nn.Sequential( + ApplyToKeys( + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + KorniaParallelTransforms( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), + ), + ), + ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), + ), + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + K.enhance.Normalize(0., 255.), + K.augmentation.ColorJitter(0.4, p=0.5), + ), + } + + +def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + return { + "post_tensor_transform": nn.Sequential( + ApplyToKeys( + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')), + ), + ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), + ), + "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), + } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py new file mode 100644 index 0000000000..3676353ec8 --- /dev/null +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -0,0 +1,66 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.data.utils import download_data +from flash.vision import SemanticSegmentation, SemanticSegmentationData +from flash.vision.segmentation.serialization import SegmentationLabels + +# 1. Download the data +# This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. +# The data was generated as part of the Lyft Udacity Challenge. +# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge +download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" +) + +# 2.1 Load the data +datamodule = SemanticSegmentationData.from_folders( + train_folder="data/CameraRGB", + train_target_folder="data/CameraSeg", + batch_size=4, + val_split=0.3, + image_size=(200, 200), # (600, 800) +) + +# 2.2 Visualise the samples +labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) +datamodule.set_labels_map(labels_map) +datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) + +# 3. Build the model +model = SemanticSegmentation( + backbone="torchvision/fcn_resnet50", + num_classes=21, +) + +# 4. Create the trainer. +trainer = flash.Trainer( + max_epochs=1, + fast_dev_run=1, +) + +# 5. 
Train the model +trainer.finetune(model, datamodule=datamodule, strategy="freeze") + +# 6. Predict what's on a few images! +model.serializer = SegmentationLabels(labels_map, visualize=True) + +predictions = model.predict([ + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", +]) + +# 7. Save it! +trainer.save_checkpoint("semantic_segmentation_model.pt") diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py new file mode 100644 index 0000000000..f507f2a6a6 --- /dev/null +++ b/flash_examples/predict/semantic_segmentation.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from flash.data.utils import download_data +from flash.vision import SemanticSegmentation +from flash.vision.segmentation.serialization import SegmentationLabels + +# 1. Download the data +# This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. +# The data was generated as part of the Lyft Udacity Challenge. +# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge +download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" +) + +# 2. Load the model from a checkpoint +model = SemanticSegmentation.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" +) +model.serializer = SegmentationLabels(visualize=True) + +# 3. Predict what's on a few images and visualize! +predictions = model.predict([ + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", +]) diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index f4748a5149..b53db09b6e 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -157,8 +157,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for _ in range(num_tests): for fcn_name in _CALLBACK_FUNCS: + dm.data_fetcher.reset() fcn = getattr(dm, f"show_{stage}_batch") - fcn(fcn_name, reset=True) + fcn(fcn_name, reset=False) is_predict = stage == "predict" diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 2fc4ee18f3..a60ecbf021 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -58,6 +58,7 @@ def run_test(filepath): ("finetuning", "image_classification.py"), ("finetuning", "image_classification_multi_label.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. + ("finetuning", "semantic_segmentation.py"), # ("finetuning", "summarization.py"), # TODO: takes too long. 
("finetuning", "tabular_classification.py"), # ("finetuning", "video_classification.py"), @@ -65,6 +66,7 @@ def run_test(filepath): ("finetuning", "translation.py"), ("predict", "image_classification.py"), ("predict", "image_classification_multi_label.py"), + ("predict", "semantic_segmentation.py"), ("predict", "tabular_classification.py"), # ("predict", "text_classification.py"), ("predict", "image_embedder.py"), diff --git a/tests/vision/segmentation/__init__.py b/tests/vision/segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py new file mode 100644 index 0000000000..bd51f09d21 --- /dev/null +++ b/tests/vision/segmentation/test_data.py @@ -0,0 +1,303 @@ +import os +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import pytest +import torch +from PIL import Image +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +from flash import Trainer +from flash.data.data_source import DefaultDataKeys +from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess + + +def build_checkboard(n, m, k=8): + x = np.zeros((n, m)) + x[k::k * 2, ::k] = 1 + x[::k * 2, k::k * 2] = 1 + return x + + +def _rand_image(size: Tuple[int, int]): + data = build_checkboard(*size).astype(np.uint8)[..., None].repeat(3, -1) + return Image.fromarray(data) + + +# usually labels come as rgb images -> need to map to labels +def _rand_labels(size: Tuple[int, int], num_classes: int): + data: np.ndarray = np.random.randint(0, num_classes, (*size, 1)) + data = data.repeat(3, axis=-1) + return Image.fromarray(data.astype(np.uint8)) + + +def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int): + for img_file in image_files: + _rand_image(size).save(img_file) + + for label_file in label_files: + _rand_labels(size, num_classes).save(label_file) + + +class TestSemanticSegmentationPreprocess: + + @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") + def test_smoke(self): + prep = SemanticSegmentationPreprocess() + assert prep is not None + + +class TestSemanticSegmentationData: + + def test_smoke(self): + dm = SemanticSegmentationData() + assert dm is not None + + def test_from_folders(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [ + str(tmp_dir / "images" / "img1.png"), + str(tmp_dir / "images" / "img2.png"), + str(tmp_dir / "images" / "img3.png"), + ] + + targets = [ + str(tmp_dir / "targets" / "img1.png"), + str(tmp_dir / "targets" / "img2.png"), + str(tmp_dir / "targets" / "img3.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_folders( + train_folder=str(tmp_dir / "images"), + train_target_folder=str(tmp_dir / "targets"), + val_folder=str(tmp_dir / "images"), + val_target_folder=str(tmp_dir / "targets"), + test_folder=str(tmp_dir / "images"), + test_target_folder=str(tmp_dir / "targets"), + batch_size=2, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = 
next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check val data + data = next(iter(dm.val_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + def test_from_folders_warning(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [ + str(tmp_dir / "images" / "img1.png"), + str(tmp_dir / "images" / "img3.png"), + ] + + targets = [ + str(tmp_dir / "targets" / "img1.png"), + str(tmp_dir / "targets" / "img2.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + with pytest.warns(UserWarning, match="Found inconsistent files"): + dm = SemanticSegmentationData.from_folders( + train_folder=str(tmp_dir / "images"), + train_target_folder=str(tmp_dir / "targets"), + batch_size=1, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, 196, 196) + + def test_from_files(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + images = [ + str(tmp_dir / "img1.png"), + str(tmp_dir / "img2.png"), + str(tmp_dir / "img3.png"), + ] + + targets = [ + str(tmp_dir / "labels_img1.png"), + str(tmp_dir / "labels_img2.png"), + str(tmp_dir / "labels_img3.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets, + val_files=images, + val_targets=targets, + test_files=images, + test_targets=targets, + batch_size=2, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check val data + data = next(iter(dm.val_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + def test_from_files_warning(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + images = [ + str(tmp_dir / "img1.png"), + str(tmp_dir / "img2.png"), + str(tmp_dir / "img3.png"), + ] + + targets = [ + str(tmp_dir / "labels_img1.png"), + str(tmp_dir / "labels_img2.png"), + str(tmp_dir / "labels_img3.png"), + ] + 
+ num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + with pytest.raises(MisconfigurationException, match="The number of input files"): + SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets + [str(tmp_dir / "labels_img4.png")], + batch_size=2, + num_workers=0, + ) + + def test_map_labels(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + images = [ + str(tmp_dir / "img1.png"), + str(tmp_dir / "img2.png"), + str(tmp_dir / "img3.png"), + ] + + targets = [ + str(tmp_dir / "labels_img1.png"), + str(tmp_dir / "labels_img2.png"), + str(tmp_dir / "labels_img3.png"), + ] + + labels_map: Dict[int, Tuple[int, int, int]] = { + 0: [0, 0, 0], + 1: [255, 255, 255], + } + + num_classes: int = len(labels_map.keys()) + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets, + val_files=images, + val_targets=targets, + batch_size=2, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + dm.set_labels_map(labels_map) + dm.show_train_batch("load_sample") + dm.show_train_batch("to_tensor_transform") + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + assert labels.min().item() == 0 + assert labels.max().item() == 1 + assert labels.dtype == torch.int64 + + # now train with `fast_dev_run` + model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.finetune(model, dm, strategy="freeze_unfreeze") diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py new file mode 100644 index 0000000000..0ebaa5d956 --- /dev/null +++ b/tests/vision/segmentation/test_model.py @@ -0,0 +1,80 @@ +from typing import Tuple + +import pytest +import torch + +from flash import Trainer +from flash.data.data_source import DefaultDataKeys +from flash.vision import SemanticSegmentation + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + size: Tuple[int, int] = (224, 224) + num_classes: int = 8 + + def __getitem__(self, index): + return { + DefaultDataKeys.INPUT: torch.rand(3, *self.size), + DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, self.size), + } + + def __len__(self) -> int: + return 10 + + +# ============================== + + +def test_smoke(): + model = SemanticSegmentation(num_classes=1) + assert model is not None + + +@pytest.mark.parametrize("num_classes", [8, 256]) +@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)]) +def test_forward(num_classes, img_shape): + model = SemanticSegmentation( + num_classes=num_classes, + backbone='torchvision/fcn_resnet50', + ) + + B, C, H, W = img_shape + img = torch.rand(B, C, H, W) + + out = model(img) + assert out.shape == (B, num_classes, H, W) + + +@pytest.mark.parametrize( + "backbone", + [ + "torchvision/fcn_resnet50", + "torchvision/fcn_resnet101", + ], +) +def 
test_init_train(tmpdir, backbone): + model = SemanticSegmentation(num_classes=10, backbone=backbone) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.finetune(model, train_dl, strategy="freeze_unfreeze") + + +def test_non_existent_backbone(): + with pytest.raises(KeyError): + SemanticSegmentation(2, "i am never going to implement this lol") + + +def test_freeze(): + model = SemanticSegmentation(2) + model.freeze() + for p in model.backbone.parameters(): + assert p.requires_grad is False + + +def test_unfreeze(): + model = SemanticSegmentation(2) + model.unfreeze() + for p in model.backbone.parameters(): + assert p.requires_grad is True diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py new file mode 100644 index 0000000000..a971c91fbf --- /dev/null +++ b/tests/vision/segmentation/test_serialization.py @@ -0,0 +1,43 @@ +import pytest +import torch + +from flash.vision.segmentation.serialization import SegmentationLabels + + +class TestSemanticSegmentationLabels: + + def test_smoke(self): + serial = SegmentationLabels() + assert serial is not None + assert serial.labels_map is None + assert serial.visualize is False + + def test_exception(self): + serial = SegmentationLabels() + + with pytest.raises(Exception): + sample = torch.zeros(1, 5, 2, 3) + serial.serialize(sample) + + with pytest.raises(Exception): + sample = torch.zeros(2, 3) + serial.serialize(sample) + + def test_serialize(self): + serial = SegmentationLabels() + + sample = torch.zeros(5, 2, 3) + sample[1, 1, 2] = 1 # add peak in class 2 + sample[3, 0, 1] = 1 # add peak in class 4 + + classes = serial.serialize(sample) + assert classes[1, 2] == 1 + assert classes[0, 1] == 3 + + # TODO: implement me + def test_create_random_labels(self): + pass + + # TODO: implement me + def test_labels_to_image(self): + pass
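# A possible sketch for the two TODO tests above (class counts and shapes here
# are arbitrary, chosen only to exercise the two SegmentationLabels helpers):
import torch

from flash.vision.segmentation.serialization import SegmentationLabels


def test_create_random_labels_map_and_labels_to_image():
    # one random RGB triplet per class id
    labels_map = SegmentationLabels.create_random_labels_map(num_classes=3)
    assert len(labels_map) == 3

    # dummy (H, W) image of class ids in [0, 3)
    img_labels = torch.randint(0, 3, (4, 6))

    # RGB tensor of shape (3, H, W) for visualisation
    rgb = SegmentationLabels.labels_to_image(img_labels, labels_map)
    assert rgb.shape == (3, 4, 6)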