From 341595e683805bfd7cd40fcc6afaae4d9931b42d Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 20:41:57 +0200 Subject: [PATCH] create segmentation keys enum --- flash/core/classification.py | 55 ------------ flash/vision/segmentation/data.py | 48 +++++----- flash/vision/segmentation/model.py | 3 +- flash/vision/segmentation/serialization.py | 87 +++++++++++++++++++ flash/vision/segmentation/transforms.py | 16 ++-- .../finetuning/semantic_segmentation.py | 8 +- 6 files changed, 124 insertions(+), 93 deletions(-) create mode 100644 flash/vision/segmentation/serialization.py diff --git a/flash/core/classification.py b/flash/core/classification.py index a08157d5523..2df94ebb227 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -14,9 +14,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Mapping, Optional, Tuple, Union -# for visualisation -import kornia as K -import matplotlib.pyplot as plt import torch import torch.nn.functional as F from pytorch_lightning.utilities import rank_zero_warn @@ -136,55 +133,3 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning ) return classes - - -class SegmentationLabels(Serializer): - - def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualise: bool = False): - """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification - per pixel in the image for semantic segmentation tasks. - - Args: - labels_map: A dictionary that map the labels ids to pixel intensities. - visualise: Wether to visualise the image labels. - """ - super().__init__() - self.labels_map = labels_map - self.visualise = visualise - - @staticmethod - def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: - """Function that given an image with labels ids and their pixels intrensity mapping, - creates a RGB representation for visualisation purposes. 
-        """
-        assert len(img_labels.shape) == 2, img_labels.shape
-        H, W = img_labels.shape
-        out = torch.empty(3, H, W, dtype=torch.uint8)
-        for label_id, label_val in labels_map.items():
-            mask = (img_labels == label_id)
-            for i in range(3):
-                out[i].masked_fill_(mask, label_val[i])
-        return out
-
-    @staticmethod
-    def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]:
-        labels_map: Dict[int, Tuple[int, int, int]] = {}
-        for i in range(num_classes):
-            labels_map[i] = torch.randint(0, 255, (3, ))
-        return labels_map
-
-    def serialize(self, sample: torch.Tensor) -> torch.Tensor:
-        assert len(sample.shape) == 3, sample.shape
-        labels = torch.argmax(sample, dim=-3)  # HxW
-        if self.visualise:
-            if self.labels_map is None:
-                # create random colors map
-                num_classes = sample.shape[-3]
-                labels_map = self.create_random_labels_map(num_classes)
-            else:
-                labels_map = self.labels_map
-            labels_vis = self.labels_to_image(labels, labels_map)
-            labels_vis = K.utils.tensor_to_image(labels_vis)
-            plt.imshow(labels_vis)
-            plt.show()
-        return labels
diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index 9092b9b449d..8584e537ee8 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -23,13 +23,13 @@
 from torch.utils.data import Dataset
 
 import flash.vision.segmentation.transforms as T
-from flash.core.classification import SegmentationLabels
 from flash.data.auto_dataset import AutoDataset
 from flash.data.base_viz import BaseVisualization  # for viz
 from flash.data.callback import BaseDataFetcher
 from flash.data.data_module import DataModule
 from flash.data.process import Preprocess
 from flash.utils.imports import _MATPLOTLIB_AVAILABLE
+from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels
 
 if _MATPLOTLIB_AVAILABLE:
     import matplotlib.pyplot as plt
@@ -89,8 +89,7 @@ def _resolve_transforms(
             predict_transform,
         )
 
-    def load_sample(self, sample: Union[str, Tuple[str,
-                                                   str]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         if not isinstance(sample, (
                 str,
                 tuple,
@@ -108,9 +107,13 @@ def load_sample(self, sample: Union[str, Tuple[str,
         img: torch.Tensor = torchvision.io.read_image(img_path)  # CxHxW
         img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path)  # CxHxW
 
-        return {'images': img, 'masks': img_labels}
+        return {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: img_labels}
 
-    def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+    # TODO: this routine should be moved to `per_batch_transform` once we have a way to
+    # forward the labels to the loss function.
+    def post_tensor_transform(
+        self, sample: Union[torch.Tensor, Dict[str, torch.Tensor]]
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if isinstance(sample, torch.Tensor):  # case for predict
             out = sample.float() / 255.  # TODO: define predict transforms
             return out
@@ -119,27 +122,15 @@ def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tu
             raise TypeError(f"Invalid type, expected `dict`. 
Got: {sample}.") # arrange data as floating point and batch before the augmentations - sample['images'] = sample['images'][None].float().contiguous() # 1xCxHxW - sample['masks'] = sample['masks'][None, :1].float().contiguous() # 1x1xHxW + sample[SegmentationKeys.IMAGES] = sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW + sample[SegmentationKeys.MASKS] = sample[SegmentationKeys.MASKS][None, :1].float().contiguous() # 1x1xHxW out: Dict[str, torch.Tensor] = self.current_transform(sample) - return out['images'][0], out['masks'][0, 0].long() + return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() - # TODO: the labels are not clear how to forward to the loss once are transform from this point - '''def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - if not isinstance(sample, list): - raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") - img, img_labels = sample - # THIS IS CRASHING - # out1 = self.current_transform(img) # images - # out2 = self.current_transform(img_labels) # labels - # return out1, out2 - return img, img_labels - - # TODO: the labels are not clear how to forward to the loss once are transform from this point - def per_batch_transform_on_device(self, sample: Any) -> Any: - pass''' + # TODO: implement `per_batch_transform` and `per_batch_transform_on_device` + ## class SemanticSegmentationData(DataModule): @@ -154,7 +145,7 @@ def _check_valid_filepaths(filepaths: List[str]): @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - return _MatplotlibVisualization(*args, **kwargs) + return SegmentationMatplotlibVisualization(*args, **kwargs) def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): self.data_fetcher.labels_map = labels_map @@ -249,12 +240,15 @@ def from_filepaths( ) -class _MatplotlibVisualization(BaseVisualization): +class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. 
""" - max_cols: int = 4 # maximum number of columns we accept - block_viz_window: bool = True # parameter to allow user to block visualisation windows - labels_map: Dict[int, Tuple[int, int, int]] = {} + + def __init__(self): + super().__init__(self) + self.max_cols: int = 4 # maximum number of columns we accept + self.block_viz_window: bool = True # parameter to allow user to block visualisation windows + self.labels_map: Dict[int, Tuple[int, int, int]] = {} @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index 2785eb65a3c..acc5a8d609c 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -19,10 +19,11 @@ from torch.nn import functional as F from torchmetrics import Accuracy, IoU -from flash.core.classification import ClassificationTask, SegmentationLabels +from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry from flash.data.process import Preprocess, Serializer from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.vision.segmentation.serialization import SegmentationLabels if _TORCHVISION_AVAILABLE: import torchvision diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py new file mode 100644 index 00000000000..907ed937c01 --- /dev/null +++ b/flash/vision/segmentation/serialization.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from typing import Dict, Optional, Tuple + +import torch + +from flash.data.process import ProcessState, Serializer +from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + plt = None + +if _KORNIA_AVAILABLE: + import kornia as K +else: + K = None + + +class SegmentationKeys(Enum): + IMAGES = 'images' + MASKS = 'masks' + + +class SegmentationLabels(Serializer): + + def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): + """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification + per pixel in the image for semantic segmentation tasks. + + Args: + labels_map: A dictionary that map the labels ids to pixel intensities. + visualise: Wether to visualise the image labels. + """ + super().__init__() + self.labels_map = labels_map + self.visualize = visualize + + @staticmethod + def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: + """Function that given an image with labels ids and their pixels intrensity mapping, + creates a RGB representation for visualisation purposes. 
+ """ + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in labels_map.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out + + @staticmethod + def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + + def serialize(self, sample: torch.Tensor) -> torch.Tensor: + assert len(sample.shape) == 3, sample.shape + labels = torch.argmax(sample, dim=-3) # HxW + if self.visualize: + if self.labels_map is None: + # create random colors map + num_classes = sample.shape[-3] + labels_map = self.create_random_labels_map(num_classes) + else: + labels_map = self.labels_map + labels_vis = self.labels_to_image(labels, labels_map) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + return labels diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index b02778dd91b..9c56898dcfb 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from flash.vision.segmentation.serialization import SegmentationKeys + class ApplyTransformToKeys(nn.Sequential): @@ -29,8 +31,10 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: for seq in self.children(): for aug in seq: for key in self.keys: - # check wether the transform was applied - # and apply transform + # kornia caches the random parameters in `_params` after every + # forward call in the augmentation object. We check whether the + # random parameters have been generated so that we can apply the + # same parameters to images and masks. 
if hasattr(aug, "_params") and bool(aug._params): params = aug._params x[key] = aug(x[key], params) @@ -42,13 +46,13 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys(['images', 'masks'], + ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS], nn.Sequential( K.geometry.Resize(image_size, interpolation='nearest'), K.augmentation.RandomHorizontalFlip(p=0.75), )), ApplyTransformToKeys( - ['images'], + [SegmentationKeys.IMAGES], nn.Sequential( K.enhance.Normalize(0., 255.), K.augmentation.ColorJitter(0.4, p=0.5), @@ -63,8 +67,8 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys(['images', 'masks'], + ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS], nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), )), - ApplyTransformToKeys(['images'], nn.Sequential(K.enhance.Normalize(0., 255.), )), + ApplyTransformToKeys([SegmentationKeys.IMAGES], nn.Sequential(K.enhance.Normalize(0., 255.), )), ), } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index fb91be70398..00e1378cf46 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -18,9 +18,9 @@ import torch import flash -from flash.core.classification import SegmentationLabels from flash.data.utils import download_data from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess +from flash.vision.segmentation.serialization import SegmentationLabels # 1. Download the data # This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. @@ -50,7 +50,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: datamodule = SemanticSegmentationData.from_filepaths( train_filepaths=images_filepaths, train_labels=labels_filepaths, - batch_size=4, + batch_size=2, val_split=0.3, # TODO: this needs to be implemented image_size=(300, 400), # (600, 800) num_workers=0, @@ -70,7 +70,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=1, + max_epochs=10, gpus=1, #precision=16, # why slower ? :) ) @@ -79,7 +79,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: trainer.finetune(model, datamodule=datamodule, strategy='freeze') # 6. Predict what's on a few images! -model.serializer = SegmentationLabels(labels_map, visualise=True) +model.serializer = SegmentationLabels(labels_map, visualize=True) predictions = model.predict([ 'data/CameraRGB/F61-1.png',