create segmentation keys enum
edgarriba committed Apr 28, 2021
1 parent 7f8ebab commit 341595e
Showing 6 changed files with 124 additions and 93 deletions.
55 changes: 0 additions & 55 deletions flash/core/classification.py
@@ -14,9 +14,6 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

# for visualisation
import kornia as K
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from pytorch_lightning.utilities import rank_zero_warn
@@ -136,55 +133,3 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]:
"No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning
)
return classes


class SegmentationLabels(Serializer):

    def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualise: bool = False):
        """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification
        per pixel in the image, for semantic segmentation tasks.

        Args:
            labels_map: A dictionary that maps label ids to pixel intensities.
            visualise: Whether to visualise the image labels.
        """
        super().__init__()
        self.labels_map = labels_map
        self.visualise = visualise

    @staticmethod
    def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor:
        """Creates an RGB representation of a label image for visualisation purposes,
        given a mapping from label ids to pixel intensities.
        """
        assert len(img_labels.shape) == 2, img_labels.shape
        H, W = img_labels.shape
        out = torch.empty(3, H, W, dtype=torch.uint8)
        for label_id, label_val in labels_map.items():
            mask = (img_labels == label_id)
            for i in range(3):
                out[i].masked_fill_(mask, label_val[i])
        return out

    @staticmethod
    def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]:
        labels_map: Dict[int, Tuple[int, int, int]] = {}
        for i in range(num_classes):
            labels_map[i] = torch.randint(0, 255, (3, ))
        return labels_map

    def serialize(self, sample: torch.Tensor) -> torch.Tensor:
        assert len(sample.shape) == 3, sample.shape
        labels = torch.argmax(sample, dim=-3)  # HxW
        if self.visualise:
            if self.labels_map is None:
                # create a random colors map
                num_classes = sample.shape[-3]
                labels_map = self.create_random_labels_map(num_classes)
            else:
                labels_map = self.labels_map
            labels_vis = self.labels_to_image(labels, labels_map)
            labels_vis = K.utils.tensor_to_image(labels_vis)
            plt.imshow(labels_vis)
            plt.show()
        return labels
48 changes: 21 additions & 27 deletions flash/vision/segmentation/data.py
@@ -23,13 +23,13 @@
from torch.utils.data import Dataset

import flash.vision.segmentation.transforms as T
from flash.core.classification import SegmentationLabels
from flash.data.auto_dataset import AutoDataset
from flash.data.base_viz import BaseVisualization # for viz
from flash.data.callback import BaseDataFetcher
from flash.data.data_module import DataModule
from flash.data.process import Preprocess
from flash.utils.imports import _MATPLOTLIB_AVAILABLE
from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels

if _MATPLOTLIB_AVAILABLE:
import matplotlib.pyplot as plt
@@ -89,8 +89,7 @@ def _resolve_transforms(
predict_transform,
)

def load_sample(self, sample: Union[str, Tuple[str,
str]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if not isinstance(sample, (
str,
tuple,
@@ -108,9 +107,13 @@ def load_sample(self, sample: Union[str, Tuple[str,
img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW
img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW

return {'images': img, 'masks': img_labels}
return {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: img_labels}

def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# TODO: this routine should be moved to `per_batch_transform` once we have a way to
# forward the labels to the loss function.
def post_tensor_transform(
self, sample: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if isinstance(sample, torch.Tensor): # case for predict
out = sample.float() / 255. # TODO: define predict transforms
return out
@@ -119,27 +122,15 @@ def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
raise TypeError(f"Invalid type, expected `dict`. Got: {sample}.")

# arrange data as floating point and batch before the augmentations
sample['images'] = sample['images'][None].float().contiguous() # 1xCxHxW
sample['masks'] = sample['masks'][None, :1].float().contiguous() # 1x1xHxW
sample[SegmentationKeys.IMAGES] = sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW
sample[SegmentationKeys.MASKS] = sample[SegmentationKeys.MASKS][None, :1].float().contiguous() # 1x1xHxW

out: Dict[str, torch.Tensor] = self.current_transform(sample)

return out['images'][0], out['masks'][0, 0].long()
return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long()

# TODO: the labels are not clear how to forward to the loss once are transform from this point
'''def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
if not isinstance(sample, list):
raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.")
img, img_labels = sample
# THIS IS CRASHING
# out1 = self.current_transform(img) # images
# out2 = self.current_transform(img_labels) # labels
# return out1, out2
return img, img_labels
# TODO: the labels are not clear how to forward to the loss once are transform from this point
def per_batch_transform_on_device(self, sample: Any) -> Any:
pass'''
# TODO: implement `per_batch_transform` and `per_batch_transform_on_device`
##
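
For reference, a minimal sketch of what the missing batch-level hook could look like once masks can be forwarded alongside images. This is entirely hypothetical and not part of this commit; it assumes the batch arrives as an `(images, masks)` tuple and is written as a free function for brevity:

```python
# Hypothetical sketch only -- this commit leaves the hook unimplemented.
import kornia as K
import torch


def per_batch_transform_on_device(batch):
    images, masks = batch  # assumed shapes: BxCxHxW and BxHxW
    flip = K.augmentation.RandomHorizontalFlip(p=0.5)
    images = flip(images)
    # re-apply the cached random parameters so the masks stay aligned
    masks = flip(masks[:, None].float(), flip._params)[:, 0].long()
    return images, masks
```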


class SemanticSegmentationData(DataModule):
@@ -154,7 +145,7 @@ def _check_valid_filepaths(filepaths: List[str]):

@staticmethod
def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
return _MatplotlibVisualization(*args, **kwargs)
return SegmentationMatplotlibVisualization(*args, **kwargs)

def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]):
self.data_fetcher.labels_map = labels_map
@@ -249,12 +240,15 @@ def from_filepaths(
)


class _MatplotlibVisualization(BaseVisualization):
class SegmentationMatplotlibVisualization(BaseVisualization):
"""Process and show the image batch and its associated label using matplotlib.
"""
max_cols: int = 4 # maximum number of columns we accept
block_viz_window: bool = True # parameter to allow user to block visualisation windows
labels_map: Dict[int, Tuple[int, int, int]] = {}

def __init__(self):
super().__init__()
self.max_cols: int = 4 # maximum number of columns we accept
self.block_viz_window: bool = True # parameter to allow user to block visualisation windows
self.labels_map: Dict[int, Tuple[int, int, int]] = {}

@staticmethod
def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray:
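
With this change, samples flow through the preprocess as dictionaries keyed by the new enum instead of bare tuples. A minimal sketch of the new sample structure (the mask path is hypothetical, for illustration only):

```python
import torchvision

from flash.vision.segmentation.serialization import SegmentationKeys

# Hypothetical file paths, for illustration only.
img = torchvision.io.read_image('data/CameraRGB/F61-1.png')   # CxHxW
mask = torchvision.io.read_image('data/CameraSeg/F61-1.png')  # CxHxW

# The structure returned by `load_sample` above:
sample = {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: mask}
```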
3 changes: 2 additions & 1 deletion flash/vision/segmentation/model.py
@@ -19,10 +19,11 @@
from torch.nn import functional as F
from torchmetrics import Accuracy, IoU

from flash.core.classification import ClassificationTask, SegmentationLabels
from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.data.process import Preprocess, Serializer
from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.vision.segmentation.serialization import SegmentationLabels

if _TORCHVISION_AVAILABLE:
import torchvision
87 changes: 87 additions & 0 deletions flash/vision/segmentation/serialization.py
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Dict, Optional, Tuple

import torch

from flash.data.process import ProcessState, Serializer
from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE

if _MATPLOTLIB_AVAILABLE:
    import matplotlib.pyplot as plt
else:
    plt = None

if _KORNIA_AVAILABLE:
    import kornia as K
else:
    K = None


class SegmentationKeys(Enum):
    IMAGES = 'images'
    MASKS = 'masks'


class SegmentationLabels(Serializer):

    def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
        """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification
        per pixel in the image, for semantic segmentation tasks.

        Args:
            labels_map: A dictionary that maps label ids to pixel intensities.
            visualize: Whether to visualize the image labels.
        """
        super().__init__()
        self.labels_map = labels_map
        self.visualize = visualize

    @staticmethod
    def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor:
        """Creates an RGB representation of a label image for visualization purposes,
        given a mapping from label ids to pixel intensities.
        """
        assert len(img_labels.shape) == 2, img_labels.shape
        H, W = img_labels.shape
        out = torch.empty(3, H, W, dtype=torch.uint8)
        for label_id, label_val in labels_map.items():
            mask = (img_labels == label_id)
            for i in range(3):
                out[i].masked_fill_(mask, label_val[i])
        return out

    @staticmethod
    def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]:
        labels_map: Dict[int, Tuple[int, int, int]] = {}
        for i in range(num_classes):
            # cast to a plain tuple so the value matches the annotated return type
            labels_map[i] = tuple(torch.randint(0, 255, (3, )).tolist())
        return labels_map

    def serialize(self, sample: torch.Tensor) -> torch.Tensor:
        assert len(sample.shape) == 3, sample.shape
        labels = torch.argmax(sample, dim=-3)  # HxW
        if self.visualize:
            if self.labels_map is None:
                # create a random colors map
                num_classes = sample.shape[-3]
                labels_map = self.create_random_labels_map(num_classes)
            else:
                labels_map = self.labels_map
            labels_vis = self.labels_to_image(labels, labels_map)
            labels_vis = K.utils.tensor_to_image(labels_vis)
            plt.imshow(labels_vis)
            plt.show()
        return labels
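
As a quick illustration of the new serializer, a sketch with a fake model output (the shapes and class count are made up):

```python
import torch

from flash.vision.segmentation.serialization import SegmentationLabels

# Fake logits for illustration: 5 classes over a 4x4 image (CxHxW).
logits = torch.randn(5, 4, 4)

serializer = SegmentationLabels()      # visualize=False: no plotting
labels = serializer.serialize(logits)  # HxW tensor of per-pixel class ids
assert labels.shape == (4, 4)
```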
16 changes: 10 additions & 6 deletions flash/vision/segmentation/transforms.py
@@ -17,6 +17,8 @@
import torch
import torch.nn as nn

from flash.vision.segmentation.serialization import SegmentationKeys


class ApplyTransformToKeys(nn.Sequential):

@@ -29,8 +31,10 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
for seq in self.children():
for aug in seq:
for key in self.keys:
# check wether the transform was applied
# and apply transform
# kornia caches the random parameters in `_params` after every
# forward call in the augmentation object. We check whether the
# random parameters have been generated so that we can apply the
# same parameters to images and masks.
if hasattr(aug, "_params") and bool(aug._params):
params = aug._params
x[key] = aug(x[key], params)
@@ -42,13 +46,13 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
return {
"post_tensor_transform": nn.Sequential(
ApplyTransformToKeys(['images', 'masks'],
ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS],
nn.Sequential(
K.geometry.Resize(image_size, interpolation='nearest'),
K.augmentation.RandomHorizontalFlip(p=0.75),
)),
ApplyTransformToKeys(
['images'],
[SegmentationKeys.IMAGES],
nn.Sequential(
K.enhance.Normalize(0., 255.),
K.augmentation.ColorJitter(0.4, p=0.5),
@@ -63,8 +67,8 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
return {
"post_tensor_transform": nn.Sequential(
ApplyTransformToKeys(['images', 'masks'],
ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS],
nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), )),
ApplyTransformToKeys(['images'], nn.Sequential(K.enhance.Normalize(0., 255.), )),
ApplyTransformToKeys([SegmentationKeys.IMAGES], nn.Sequential(K.enhance.Normalize(0., 255.), )),
),
}
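
To make the `_params` caching behavior described in the comment above concrete, here is a small sketch of applying a shared random flip to both keys (tensor shapes are illustrative; assumes the kornia version used by the repo):

```python
import kornia as K
import torch
import torch.nn as nn

from flash.vision.segmentation.serialization import SegmentationKeys
from flash.vision.segmentation.transforms import ApplyTransformToKeys

# Illustrative batched sample: one RGB image and its mask.
x = {
    SegmentationKeys.IMAGES: torch.rand(1, 3, 32, 32),
    SegmentationKeys.MASKS: torch.randint(0, 5, (1, 1, 32, 32)).float(),
}

# The flip parameters are sampled once for the images; the cached `_params`
# are then re-applied to the masks, keeping both spatially aligned.
aug = ApplyTransformToKeys(
    [SegmentationKeys.IMAGES, SegmentationKeys.MASKS],
    nn.Sequential(K.augmentation.RandomHorizontalFlip(p=1.0)),
)
out = aug(x)
```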
8 changes: 4 additions & 4 deletions flash_examples/finetuning/semantic_segmentation.py
@@ -18,9 +18,9 @@
import torch

import flash
from flash.core.classification import SegmentationLabels
from flash.data.utils import download_data
from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
from flash.vision.segmentation.serialization import SegmentationLabels

# 1. Download the data
# This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator.
@@ -50,7 +50,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]:
datamodule = SemanticSegmentationData.from_filepaths(
train_filepaths=images_filepaths,
train_labels=labels_filepaths,
batch_size=4,
batch_size=2,
val_split=0.3, # TODO: this needs to be implemented
image_size=(300, 400), # (600, 800)
num_workers=0,
@@ -70,7 +70,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]:

# 4. Create the trainer.
trainer = flash.Trainer(
max_epochs=1,
max_epochs=10,
gpus=1,
# precision=16,  # TODO: investigate why mixed precision is slower here
)
@@ -79,7 +79,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]:
trainer.finetune(model, datamodule=datamodule, strategy='freeze')

# 6. Predict what's on a few images!
model.serializer = SegmentationLabels(labels_map, visualise=True)
model.serializer = SegmentationLabels(labels_map, visualize=True)

predictions = model.predict([
'data/CameraRGB/F61-1.png',
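
The `labels_map` passed to the serializer above is built earlier in the script (elided from this diff). For quick experiments it can also be generated randomly from the class count using the helper added in this commit; the `num_classes=21` value below is illustrative:

```python
from flash.vision.segmentation.serialization import SegmentationLabels

# Illustrative class count; the real value comes from the dataset.
labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
```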
