From e2f5f20b1c2f4f27e04f6efcf36096cf91960009 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 22 Apr 2021 19:59:17 +0200 Subject: [PATCH 01/53] semantic segmentation skeleton --- flash/vision/__init__.py | 1 + flash/vision/segmentation/__init__.py | 2 + flash/vision/segmentation/data.py | 107 ++++++++++++++++++++++++ flash/vision/segmentation/model.py | 83 ++++++++++++++++++ tests/vision/segmentation/test_data.py | 6 ++ tests/vision/segmentation/test_model.py | 19 +++++ 6 files changed, 218 insertions(+) create mode 100644 flash/vision/segmentation/__init__.py create mode 100644 flash/vision/segmentation/data.py create mode 100644 flash/vision/segmentation/model.py create mode 100644 tests/vision/segmentation/test_data.py create mode 100644 tests/vision/segmentation/test_model.py diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 39dce803d8..18f13e8650 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -2,3 +2,4 @@ from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier from flash.vision.detection import ObjectDetectionData, ObjectDetector from flash.vision.embedding import ImageEmbedder +from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData diff --git a/flash/vision/segmentation/__init__.py b/flash/vision/segmentation/__init__.py new file mode 100644 index 0000000000..d0720ceb44 --- /dev/null +++ b/flash/vision/segmentation/__init__.py @@ -0,0 +1,2 @@ +from flash.vision.segmentation.data import SemanticSegmentationData +from flash.vision.segmentation.model import SemanticSegmentation diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py new file mode 100644 index 0000000000..8de1a8b15a --- /dev/null +++ b/flash/vision/segmentation/data.py @@ -0,0 +1,107 @@ +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union + +import torch +from torch.utils.data import Dataset + +from flash.data.auto_dataset import AutoDataset +from flash.data.data_module import DataModule +from flash.data.process import Preprocess + + +class SemantincSegmentationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Union[Dict[str, Callable]]] = None, + val_transform: Optional[Union[Dict[str, Callable]]] = None, + test_transform: Optional[Union[Dict[str, Callable]]] = None, + predict_transform: Optional[Union[Dict[str, Callable]]] = None, + ) -> 'SemantincSegmentationPreprocess': + + # TODO: implement me + '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( + train_transform, val_transform, test_transform, predict_transform + )''' + super().__init__(train_transform, val_transform, test_transform, predict_transform) + + def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: + pass + + def load_sample(sample) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + def collate(self, samples: Sequence) -> Any: + pass + + def pre_tensor_transform(self, sample: Any) -> Any: + pass + + def to_tensor_transform(self, sample: Any) -> Any: + pass + + def post_tensor_transform(self, sample: Any) -> Any: + pass + + def per_batch_transform(self, sample: Any) -> Any: + pass + + def per_batch_transform_on_device(self, sample: Any) -> Any: + pass + + +class SemanticSegmentationData(DataModule): + """Data module for semantic segmentation tasks.""" + + def __init__( + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + 
test_dataset: Optional[Dataset] = None, + predict_dataset: Optional[Dataset] = None, + batch_size: int = 1, + num_workers: Optional[int] = None, + seed: int = 1234, + train_split: Optional[float] = None, + val_split: Optional[float] = None, + # test_split: Optional[float] = None, ## THIS WILL GO OUT + preprocess: Optional[Preprocess] = None, + ) -> None: + pass + + @classmethod + def from_filepaths( + cls, + train_filepaths: Optional[Sequence[str]] = None, + train_labels: Optional[Sequence[str]] = None, + val_filepaths: Optional[Sequence[str]] = None, + val_labels: Optional[Sequence[str]] = None, + test_filepaths: Optional[Sequence[str]] = None, + test_labels: Optional[Sequence[str]] = None, + predict_filepaths: Optional[Sequence[str]] = None, + train_transform: Union[str, Dict] = 'default', + val_transform: Union[str, Dict] = 'default', + test_transform: Union[str, Dict] = 'default', + predict_transform: Union[str, Dict] = 'default', + batch_size: int = 64, + num_workers: Optional[int] = None, + seed: Optional[int] = 42, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + ) -> 'SemanticSegmentationData': + + preprocess = preprocess or SemantincSegmentationPreprocess( + train_transform, + val_transform, + test_transform, + predict_transform, + ) + + return cls.from_load_data_inputs( + train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, + val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, + test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, + predict_load_data_input=predict_filepaths, + batch_size=batch_size, + num_workers=num_workers, + preprocess=preprocess, + seed=seed, + ) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py new file mode 100644 index 0000000000..98dedc07d4 --- /dev/null +++ b/flash/vision/segmentation/model.py @@ -0,0 +1,83 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from types import FunctionType +from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union + +import torch +from torch import nn +from torch.nn import functional as F +from torchmetrics import Accuracy + +from flash.core.classification import Classes, ClassificationTask +from flash.core.registry import FlashRegistry +from flash.data.process import Preprocess, Serializer + +SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") + + +class SemanticSegmentation(ClassificationTask): + """Task that performs semantic segmentation on images. 
+ """ + + backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", + backbone_kwargs: Optional[Dict] = None, + head: Optional[Union[FunctionType, nn.Module]] = None, + pretrained: bool = True, + loss_fn: Optional[Callable] = None, + optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD, + metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None, + learning_rate: float = 1e-3, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + ): + + if metrics is None: + metrics = Accuracy(subset_accuracy=multi_label) + + if loss_fn is None: + # loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy + loss_fn = F.cross_entropy + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer or Classes(multi_label=multi_label), + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {} + + # TODO: implement first torchvision + self.backbone, num_features = None, 1 + '''if isinstance(backbone, tuple): + self.backbone, num_features = backbone + else: + self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)''' + + head = head(num_features, num_classes) if isinstance(head, FunctionType) else head + self.head = head or nn.Conv2d(num_features, num_classes, kernel_size=1) + + def forward(self, x) -> torch.Tensor: + x = self.backbone(x) + return self.head(x) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py new file mode 100644 index 0000000000..c82ecc6a85 --- /dev/null +++ b/tests/vision/segmentation/test_data.py @@ -0,0 +1,6 @@ +from flash.vision import SemanticSegmentationData + + +def test_smoke(): + dm = SemanticSegmentationData() + assert dm is not None diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py new file mode 100644 index 0000000000..fdc98674bb --- /dev/null +++ b/tests/vision/segmentation/test_model.py @@ -0,0 +1,19 @@ +import pytest +import torch + +from flash.vision import SemanticSegmentation + + +def test_smoke(): + model = SemanticSegmentation(num_classes=1) + assert model is not None + + +@pytest.mark.skip(reason="todo") +def test_forward(): + num_classes = 5 + model = SemanticSegmentation(num_classes) + + img = torch.rand(1, 3, 224, 224) + out = model(img) + assert out.shape == (1, num_classes, 224, 224) From f3ce4c74dd0e83ce493be1426b5d7547815d1750 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Fri, 23 Apr 2021 13:17:23 +0200 Subject: [PATCH 02/53] expose and add smoke tests for preproces and datamodule --- flash/vision/__init__.py | 2 +- flash/vision/segmentation/__init__.py | 2 +- flash/vision/segmentation/data.py | 24 ++++++--- tests/vision/segmentation/__init__.py | 0 tests/vision/segmentation/test_data.py | 75 ++++++++++++++++++++++++-- 5 files changed, 91 insertions(+), 12 deletions(-) create mode 100644 tests/vision/segmentation/__init__.py diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 18f13e8650..526b1f2892 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -2,4 +2,4 @@ from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier from flash.vision.detection import ObjectDetectionData, ObjectDetector from flash.vision.embedding import ImageEmbedder 
-from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData +from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemantincSegmentationPreprocess diff --git a/flash/vision/segmentation/__init__.py b/flash/vision/segmentation/__init__.py index d0720ceb44..520694b412 100644 --- a/flash/vision/segmentation/__init__.py +++ b/flash/vision/segmentation/__init__.py @@ -1,2 +1,2 @@ -from flash.vision.segmentation.data import SemanticSegmentationData +from flash.vision.segmentation.data import SemanticSegmentationData, SemantincSegmentationPreprocess from flash.vision.segmentation.model import SemanticSegmentation diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 8de1a8b15a..baa83165b1 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -1,5 +1,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union +import kornia as K +import numpy as np import torch from torch.utils.data import Dataset @@ -12,18 +14,27 @@ class SemantincSegmentationPreprocess(Preprocess): def __init__( self, - train_transform: Optional[Union[Dict[str, Callable]]] = None, - val_transform: Optional[Union[Dict[str, Callable]]] = None, - test_transform: Optional[Union[Dict[str, Callable]]] = None, - predict_transform: Optional[Union[Dict[str, Callable]]] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, ) -> 'SemantincSegmentationPreprocess': # TODO: implement me '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform )''' + train_transform = dict(to_tensor_transform=self.to_tensor) + val_transform = dict(to_tensor_transform=self.to_tensor) + test_transform = dict(to_tensor_transform=self.to_tensor) + predict_transform = dict(to_tensor_transform=self.to_tensor) + super().__init__(train_transform, val_transform, test_transform, predict_transform) + @staticmethod + def to_tensor(self, x): + return K.utils.image_to_tensor(np.array(x)) + def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: pass @@ -95,7 +106,8 @@ def from_filepaths( predict_transform, ) - return cls.from_load_data_inputs( + return cls() + '''return cls.from_load_data_inputs( train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, @@ -104,4 +116,4 @@ def from_filepaths( num_workers=num_workers, preprocess=preprocess, seed=seed, - ) + )''' diff --git a/tests/vision/segmentation/__init__.py b/tests/vision/segmentation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index c82ecc6a85..682b101299 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -1,6 +1,73 @@ -from flash.vision import SemanticSegmentationData +from pathlib import Path +from typing import List, Tuple +import numpy as np +import pytest +from PIL import Image -def test_smoke(): - dm = SemanticSegmentationData() - assert dm is not None +from flash.vision import 
SemanticSegmentationData, SemantincSegmentationPreprocess
+
+
+def _rand_image(size: Tuple[int, int]):
+    data: np.ndarray = np.random.randint(0, 255, (*size, 3), dtype="uint8")
+    return Image.fromarray(data)
+
+
+# usually labels come as rgb images -> need to map to labels
+def _rand_labels(size: Tuple[int, int]):
+    data: np.ndarray = np.random.randint(0, 255, (*size, 3), dtype="uint8")
+    return Image.fromarray(data)
+
+
+def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int]) -> Image.Image:
+    for img_file in image_files:
+        _rand_image(size).save(img_file)
+
+    for label_file in label_files:
+        _rand_labels(size).save(img_file)
+
+
+class TestSemanticSegmentationPreprocess:
+
+    @pytest.mark.xfail(reason="parameters are marked as optional but it raises a MisconfigurationException error.")
+    def test_smoke(self):
+        prep = SemantincSegmentationPreprocess()
+        assert prep is not None
+
+
+class TestSemanticSegmentationData:
+
+    def test_smoke(self):
+        dm = SemanticSegmentationData()
+        assert dm is not None
+
+    def test_from_filepaths(self, tmpdir):
+        tmp_dir = Path(tmpdir)
+
+        # create random dummy data
+
+        train_images = [
+            tmp_dir / "img1.png",
+            tmp_dir / "img2.png",
+            tmp_dir / "img3.png",
+        ]
+
+        train_labels = [
+            tmp_dir / "labels_img1.png",
+            tmp_dir / "labels_img2.png",
+            tmp_dir / "labels_img3.png",
+        ]
+
+        img_size: Tuple[int, int] = (192, 192)
+        create_random_data(train_images, train_labels, img_size)
+
+        # instantiate the data module
+
+        dm = SemanticSegmentationData.from_filepaths(
+            train_filepaths=train_images,
+            train_labels=train_labels,
+            batch_size=2,
+            num_workers=0,
+        )
+        assert dm is not None
+        # assert dm.train_dataloader() is not None
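[editor's note — patch 02] The preprocess is wired in through
`to_tensor_transform=self.to_tensor`, which delegates to kornia's
`image_to_tensor` to turn an HWC `uint8` array into a CHW tensor. A minimal
sketch of that conversion, with a dummy PIL image standing in for the files
created by the test above:

    import kornia as K
    import numpy as np
    from PIL import Image

    pil_img = Image.new("RGB", (196, 196))               # HxWxC uint8 image
    tensor = K.utils.image_to_tensor(np.array(pil_img))  # CxHxW torch.Tensor
    assert tuple(tensor.shape) == (3, 196, 196)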
From 1ef1b40e8d5298b3dfe8d7a04746e1c29632d75b Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Fri, 23 Apr 2021 17:02:06 +0200
Subject: [PATCH 03/53] data module connections working

---
 flash/vision/segmentation/data.py      | 53 ++++++++++++++----------
 tests/vision/segmentation/test_data.py | 14 +++----
 2 files changed, 39 insertions(+), 28 deletions(-)

diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index baa83165b1..7e28bd4fa8 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -1,11 +1,13 @@
-from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import kornia as K
 import numpy as np
 import torch
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import Dataset
 
 from flash.data.auto_dataset import AutoDataset
+from flash.data.callback import BaseDataFetcher
 from flash.data.data_module import DataModule
 from flash.data.process import Preprocess
 
@@ -18,6 +20,7 @@ def __init__(
         val_transform: Optional[Dict[str, Callable]] = None,
         test_transform: Optional[Dict[str, Callable]] = None,
         predict_transform: Optional[Dict[str, Callable]] = None,
+        image_size: Tuple[int, int] = (196, 196),
     ) -> 'SemantincSegmentationPreprocess':
 
         # TODO: implement me
@@ -63,20 +66,12 @@ def per_batch_transform_on_device(self, sample: Any) -> Any:
 class SemanticSegmentationData(DataModule):
     """Data module for semantic segmentation tasks."""
 
-    def __init__(
-        train_dataset: Optional[Dataset] = None,
-        val_dataset: Optional[Dataset] = None,
-        test_dataset: Optional[Dataset] = None,
-        predict_dataset: Optional[Dataset] = None,
-        batch_size: int = 1,
-        num_workers: Optional[int] = None,
-        seed: int = 1234,
-        train_split: Optional[float] = None,
-        val_split: Optional[float] = None,
-        # test_split: Optional[float] = None, ## THIS WILL GO OUT
-        preprocess: Optional[Preprocess] = None,
-    ) -> None:
-        pass
+    @staticmethod
+    def _check_valid_filepaths(filepaths: List[str]):
+        if filepaths is not None and (
+            not isinstance(filepaths, list) or not all(isinstance(n, str) for n in filepaths)
+        ):
+            raise MisconfigurationException(f"`filepaths` must be of type List[str]. Got: {filepaths}.")
 
     @classmethod
     def from_filepaths(
@@ -92,13 +87,26 @@ def from_filepaths(
         val_transform: Union[str, Dict] = 'default',
         test_transform: Union[str, Dict] = 'default',
         predict_transform: Union[str, Dict] = 'default',
+        image_size: Tuple[int, int] = (196, 196),
         batch_size: int = 64,
         num_workers: Optional[int] = None,
-        seed: Optional[int] = 42,
+        #seed: Optional[int] = 42, # SEED NEVER USED
+        data_fetcher: BaseDataFetcher = None,
         preprocess: Optional[Preprocess] = None,
-        val_split: Optional[float] = None,
+        # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED
+        #**kwargs,
     ) -> 'SemanticSegmentationData':
 
+        # verify input data format
+        SemanticSegmentationData._check_valid_filepaths(train_filepaths)
+        SemanticSegmentationData._check_valid_filepaths(train_labels)
+        SemanticSegmentationData._check_valid_filepaths(val_filepaths)
+        SemanticSegmentationData._check_valid_filepaths(val_labels)
+        SemanticSegmentationData._check_valid_filepaths(test_filepaths)
+        SemanticSegmentationData._check_valid_filepaths(test_labels)
+        SemanticSegmentationData._check_valid_filepaths(predict_filepaths)
+
+        # create the preprocess objects
         preprocess = preprocess or SemantincSegmentationPreprocess(
             train_transform,
             val_transform,
@@ -106,14 +114,17 @@ def from_filepaths(
             predict_transform,
         )
 
-        return cls()
-        '''return cls.from_load_data_inputs(
+        # instantiate the data module class
+        return DataModule.from_load_data_inputs(
             train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None,
             val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None,
             test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None,
             predict_load_data_input=predict_filepaths,
             batch_size=batch_size,
             num_workers=num_workers,
+            data_fetcher=data_fetcher,
             preprocess=preprocess,
-            seed=seed,
-        )'''
+            #seed=seed,
+            #val_split=val_split,
+            #**kwargs
+        )
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
index 682b101299..927a432294 100644
--- a/tests/vision/segmentation/test_data.py
+++ b/tests/vision/segmentation/test_data.py
@@ -47,15 +47,15 @@ def test_from_filepaths(self, tmpdir):
         # create random dummy data
 
         train_images = [
-            tmp_dir / "img1.png",
-            tmp_dir / "img2.png",
-            tmp_dir / "img3.png",
+            str(tmp_dir / "img1.png"),
+            str(tmp_dir / "img2.png"),
+            str(tmp_dir / "img3.png"),
         ]
 
         train_labels = [
-            tmp_dir / "labels_img1.png",
-            tmp_dir / "labels_img2.png",
-            tmp_dir / "labels_img3.png",
+            str(tmp_dir / "labels_img1.png"),
+            str(tmp_dir / "labels_img2.png"),
+            str(tmp_dir / "labels_img3.png"),
         ]
 
         img_size: Tuple[int, int] = (192, 192)
@@ -70,4 +70,4 @@ def test_from_filepaths(self, tmpdir):
             num_workers=0,
         )
         assert dm is not None
-        # assert dm.train_dataloader() is not None
+        assert dm.train_dataloader() is not None
From 7f17fb2b85a67ebd177f762f662e30ef7d97e352 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Fri, 23 Apr 2021 18:08:43 +0200
Subject: [PATCH 04/53] preprocess not crashing (wip)

---
 flash/vision/segmentation/data.py      | 62 ++++++++++++------------
++++++++++++++------------ tests/vision/segmentation/test_data.py | 12 ++++- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 7e28bd4fa8..1e24d3d57c 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -3,6 +3,7 @@ import kornia as K import numpy as np import torch +import torchvision from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset @@ -12,6 +13,10 @@ from flash.data.process import Preprocess +def to_tensor(self, x): + return K.utils.image_to_tensor(np.array(x)) + + class SemantincSegmentationPreprocess(Preprocess): def __init__( @@ -27,38 +32,37 @@ def __init__( '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform )''' - train_transform = dict(to_tensor_transform=self.to_tensor) - val_transform = dict(to_tensor_transform=self.to_tensor) - test_transform = dict(to_tensor_transform=self.to_tensor) - predict_transform = dict(to_tensor_transform=self.to_tensor) + train_transform = dict(per_batch_transform=to_tensor) + val_transform = dict(per_batch_transform=to_tensor) + test_transform = dict(per_batch_transform=to_tensor) + predict_transform = dict(per_batch_transform=to_tensor) super().__init__(train_transform, val_transform, test_transform, predict_transform) - @staticmethod - def to_tensor(self, x): - return K.utils.image_to_tensor(np.array(x)) - - def load_data(self, data: Any, dataset: Optional[AutoDataset] = None) -> Iterable: - pass - - def load_sample(sample) -> Tuple[torch.Tensor, torch.Tensor]: - pass - - def collate(self, samples: Sequence) -> Any: - pass - - def pre_tensor_transform(self, sample: Any) -> Any: - pass - - def to_tensor_transform(self, sample: Any) -> Any: - pass - - def post_tensor_transform(self, sample: Any) -> Any: - pass - - def per_batch_transform(self, sample: Any) -> Any: - pass - + def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tensor]: + if not isinstance(sample, tuple): + raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") + # unpack data paths + img_path: str = sample[0] + img_labels_path: str = sample[1] + + # load images directly to torch tensors + img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + + return img, img_labels + + def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + if not isinstance(sample, list): + raise TypeError(f"Invalid type, expected `tuple`. 
Got: {sample}.") + img, img_labels = sample + # THIS IS CRASHING + # out1 = self.current_transform(img) # images + # out2 = self.current_transform(img_labels) # labels + # return out1, out2 + return img, img_labels + + # TODO: implement me def per_batch_transform_on_device(self, sample: Any) -> Any: pass diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 927a432294..febc5e1472 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -24,7 +24,7 @@ def create_random_data(image_files: List[str], label_files: List[str], size: Tup _rand_image(size).save(img_file) for label_file in label_files: - _rand_labels(size).save(img_file) + _rand_labels(size).save(label_file) class TestSemanticSegmentationPreprocess: @@ -58,7 +58,7 @@ def test_from_filepaths(self, tmpdir): str(tmp_dir / "labels_img3.png"), ] - img_size: Tuple[int, int] = (192, 192) + img_size: Tuple[int, int] = (196, 196) create_random_data(train_images, train_labels, img_size) # instantiate the data module @@ -71,3 +71,11 @@ def test_from_filepaths(self, tmpdir): ) assert dm is not None assert dm.train_dataloader() is not None + assert dm.val_dataloader() is None + assert dm.test_dataloader() is None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 3, 196, 196) From 7d9d46cab714aef031dfb78b17f5adc4db990670 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Mon, 26 Apr 2021 11:29:18 +0200 Subject: [PATCH 05/53] implement segmentation sequential --- flash/vision/segmentation/data.py | 38 ++++++++++++++++++++++---- tests/vision/segmentation/test_data.py | 2 +- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 1e24d3d57c..42e515cb70 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -3,6 +3,7 @@ import kornia as K import numpy as np import torch +import torch.nn as nn import torchvision from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset @@ -13,6 +14,20 @@ from flash.data.process import Preprocess +class SegmentationSequential(nn.Sequential): + + def __init__(self, *args): + super(SegmentationSequential, self).__init__(*args) + + def forward(self, img, mask): + img_out = img.float() + mask_out = mask.float() + for aug in self.children(): + img_out = aug(img_out) + mask_out = aug(mask_out, aug._params) + return img_out[0], mask_out[0] + + def to_tensor(self, x): return K.utils.image_to_tensor(np.array(x)) @@ -32,13 +47,18 @@ def __init__( '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform )''' - train_transform = dict(per_batch_transform=to_tensor) - val_transform = dict(per_batch_transform=to_tensor) - test_transform = dict(per_batch_transform=to_tensor) - predict_transform = dict(per_batch_transform=to_tensor) + augs = SegmentationSequential( + # K.augmentation.RandomResizedCrop((128, 128)), + K.augmentation.RandomHorizontalFlip(), + ) + train_transform = dict(to_tensor_transform=augs) + val_transform = dict(to_tensor_transform=augs) + test_transform = dict(to_tensor_transform=augs) + predict_transform = dict(to_tensor_transform=augs) super().__init__(train_transform, val_transform, test_transform, predict_transform) + # TODO: is it a problem to 
From 7d9d46cab714aef031dfb78b17f5adc4db990670 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Mon, 26 Apr 2021 11:29:18 +0200
Subject: [PATCH 05/53] implement segmentation sequential

---
 flash/vision/segmentation/data.py      | 38 ++++++++++++++++++++++----
 tests/vision/segmentation/test_data.py |  2 +-
 2 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index 1e24d3d57c..42e515cb70 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -3,6 +3,7 @@
 import kornia as K
 import numpy as np
 import torch
+import torch.nn as nn
 import torchvision
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import Dataset
@@ -13,6 +14,20 @@
 from flash.data.process import Preprocess
 
 
+class SegmentationSequential(nn.Sequential):
+
+    def __init__(self, *args):
+        super(SegmentationSequential, self).__init__(*args)
+
+    def forward(self, img, mask):
+        img_out = img.float()
+        mask_out = mask.float()
+        for aug in self.children():
+            img_out = aug(img_out)
+            mask_out = aug(mask_out, aug._params)
+        return img_out[0], mask_out[0]
+
+
 def to_tensor(x):
     return K.utils.image_to_tensor(np.array(x))
 
@@ -27,13 +42,18 @@ def __init__(
         '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms(
             train_transform, val_transform, test_transform, predict_transform
         )'''
-        train_transform = dict(per_batch_transform=to_tensor)
-        val_transform = dict(per_batch_transform=to_tensor)
-        test_transform = dict(per_batch_transform=to_tensor)
-        predict_transform = dict(per_batch_transform=to_tensor)
+        augs = SegmentationSequential(
+            # K.augmentation.RandomResizedCrop((128, 128)),
+            K.augmentation.RandomHorizontalFlip(),
+        )
+        train_transform = dict(to_tensor_transform=augs)
+        val_transform = dict(to_tensor_transform=augs)
+        test_transform = dict(to_tensor_transform=augs)
+        predict_transform = dict(to_tensor_transform=augs)
 
         super().__init__(train_transform, val_transform, test_transform, predict_transform)
 
+    # TODO: is it a problem to load the sample directly as a tensor? What happens in to_tensor_transform?
     def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tensor]:
         if not isinstance(sample, tuple):
             raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.")
@@ -52,6 +72,14 @@ def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tens
 
         return img, img_labels
 
+    def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+        if not isinstance(sample, tuple):
+            raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.")
+        img, img_labels = sample
+        img_out, img_labels_out = self.current_transform(img, img_labels)
+        return img_out, img_labels_out
+
+    # TODO: it is not clear how to forward the labels to the loss once they are transformed from this point
     def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -62,7 +90,7 @@ def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.T
         # return out1, out2
         return img, img_labels
 
-    # TODO: implement me
+    # TODO: it is not clear how to forward the labels to the loss once they are transformed from this point
     def per_batch_transform_on_device(self, sample: Any) -> Any:
         pass
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
index febc5e1472..03149a0874 100644
--- a/tests/vision/segmentation/test_data.py
+++ b/tests/vision/segmentation/test_data.py
@@ -15,7 +15,7 @@ def _rand_image(size: Tuple[int, int]):
 
 # usually labels come as rgb images -> need to map to labels
 def _rand_labels(size: Tuple[int, int]):
-    data: np.ndarray = np.random.randint(0, 255, (*size, 3), dtype="uint8")
+    data: np.ndarray = np.ones((*size, 3), dtype=np.uint8)
     return Image.fromarray(data)
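[editor's note — patch 05] `SegmentationSequential` runs each augmentation on the
image first and then replays it on the mask by passing the parameters kornia
recorded during the image pass (`aug._params`), so both stay geometrically aligned.
A rough usage sketch, assuming a kornia version whose augmentations accept a
`params` argument as the forward above does:

    import torch
    import kornia as K

    augs = SegmentationSequential(
        K.augmentation.RandomHorizontalFlip(p=1.0),
    )
    img = torch.rand(1, 3, 8, 8)
    mask = torch.randint(0, 2, (1, 1, 8, 8)).float()
    img_out, mask_out = augs(img, mask)  # batch dim dropped by the [0] indexing

    # both outputs were flipped with the same sampled parameters
    assert torch.allclose(img_out, torch.flip(img, dims=[-1])[0])
    assert torch.allclose(mask_out, torch.flip(mask, dims=[-1])[0])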
From 6a7524515d073af401cd7c586de3bfe30e024a71 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Mon, 26 Apr 2021 13:23:35 +0200
Subject: [PATCH 06/53] implement torchvision backbone model

---
 flash/vision/segmentation/data.py       |  1 +
 flash/vision/segmentation/model.py      | 24 ++++++++++++------------
 tests/vision/segmentation/test_data.py  | 20 ++++++++++++++++++--
 tests/vision/segmentation/test_model.py | 17 +++++++++++------
 4 files changed, 42 insertions(+), 20 deletions(-)

diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index 42e515cb70..1a5d576b24 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -127,6 +127,7 @@ def from_filepaths(
         preprocess: Optional[Preprocess] = None,
         # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED
         #**kwargs,
+        map_pix_to_labels: Dict[Tuple[int, int, int], int] = None,  # TODO: implement me
     ) -> 'SemanticSegmentationData':
 
         # verify input data format
diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py
index 98dedc07d4..8e33e3345a 100644
--- a/flash/vision/segmentation/model.py
+++ b/flash/vision/segmentation/model.py
@@ -35,7 +35,7 @@ class SemanticSegmentation(ClassificationTask):
     def __init__(
         self,
         num_classes: int,
-        backbone: Union[str, Tuple[nn.Module, int]] = "resnet18",
+        backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50",
         backbone_kwargs: Optional[Dict] = None,
         head: Optional[Union[FunctionType, nn.Module]] = None,
         pretrained: bool = True,
@@ -50,6 +50,7 @@ def __init__(
         if metrics is None:
             metrics = Accuracy(subset_accuracy=multi_label)
 
+        # TODO: do we have any case for this ?
         if loss_fn is None:
             # loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
             loss_fn = F.cross_entropy
@@ -68,16 +69,15 @@ def __init__(
         if not backbone_kwargs:
             backbone_kwargs = {}
 
-        # TODO: implement first torchvision
-        self.backbone, num_features = None, 1
-        '''if isinstance(backbone, tuple):
-            self.backbone, num_features = backbone
-        else:
-            self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)'''
-
-        head = head(num_features, num_classes) if isinstance(head, FunctionType) else head
-        self.head = head or nn.Conv2d(num_features, num_classes, kernel_size=1)
+        # TODO: pretrained to True causes some issues
+        self.model = self.backbones.get(backbone)(pretrained=False, num_classes=num_classes, **backbone_kwargs)
 
     def forward(self, x) -> torch.Tensor:
-        x = self.backbone(x)
-        return self.head(x)
+        return self.model(x)['out']  # TODO: find a proper way to get 'out' from registry
+
+
+@SemanticSegmentation.backbones(name="torchvision/fcn_resnet50")
+def fn(pretrained: bool, num_classes: int):
+    import torchvision
+    model: nn.Module = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained, num_classes=num_classes)
+    return model
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
index 03149a0874..a828f67c04 100644
--- a/tests/vision/segmentation/test_data.py
+++ b/tests/vision/segmentation/test_data.py
@@ -66,16 +66,32 @@ def test_from_filepaths(self, tmpdir):
         dm = SemanticSegmentationData.from_filepaths(
             train_filepaths=train_images,
             train_labels=train_labels,
+            val_filepaths=train_images,
+            val_labels=train_labels,
+            test_filepaths=train_images,
+            test_labels=train_labels,
             batch_size=2,
             num_workers=0,
         )
         assert dm is not None
         assert dm.train_dataloader() is not None
-        assert dm.val_dataloader() is None
-        assert dm.test_dataloader() is None
+        assert dm.val_dataloader() is not None
+        assert dm.test_dataloader() is not None
 
         # check training data
         data = next(iter(dm.train_dataloader()))
         imgs, labels = data
         assert imgs.shape == (2, 3, 196, 196)
         assert labels.shape == (2, 3, 196, 196)
+
+        # check validation data
+        data = next(iter(dm.val_dataloader()))
+        imgs, labels = data
+        assert imgs.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 3, 196, 196)
+
+        # check test data
+        data = next(iter(dm.test_dataloader()))
+        imgs, labels = data
+        assert imgs.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 3, 196, 196)
diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py
index fdc98674bb..92c8532672 100644
--- a/tests/vision/segmentation/test_model.py
+++ b/tests/vision/segmentation/test_model.py
@@ -9,11 +9,16 @@ def test_smoke():
     assert model is not None
 
 
-@pytest.mark.skip(reason="todo")
-def test_forward():
-    num_classes = 5
-    model = SemanticSegmentation(num_classes)
+@pytest.mark.parametrize("num_classes", [8, 256])
+@pytest.mark.parametrize("img_shape", [(1, 3, 224, 192), (2, 3, 127, 212)])
+def test_forward(num_classes, img_shape):
+    model = SemanticSegmentation(
+        num_classes=num_classes,
+        backbone='torchvision/fcn_resnet50',
+    )
+
+    B, C, H, W = img_shape
+    img = torch.rand(B, C, H, W)
 
-    img = torch.rand(1, 3, 224, 224)
     out = model(img)
-    assert out.shape == (1, num_classes, 224, 224)
+    assert out.shape == (B, num_classes, H, W)
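[editor's note — patch 06] Backbones are looked up by name in the `FlashRegistry`
declared at the top of model.py, and `forward` unwraps the `'out'` key because
torchvision's segmentation models return a dict of feature maps. Registering an
alternative backbone would follow the same decorator pattern; the `fcn_resnet101`
entry below is purely illustrative and not part of the patch series:

    import torch.nn as nn
    import torchvision

    @SemanticSegmentation.backbones(name="torchvision/fcn_resnet101")
    def fn(pretrained: bool, num_classes: int) -> nn.Module:
        return torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained, num_classes=num_classes)

    model = SemanticSegmentation(num_classes=5, backbone="torchvision/fcn_resnet101")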
From 56fa4d50087a35054d91dc894dbd42b00b8e33a2 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Mon, 26 Apr 2021 16:42:18 +0200
Subject: [PATCH 07/53] model working

---
 flash/core/classification.py            |  3 +-
 flash/vision/segmentation/model.py      | 18 +++++----
 tests/vision/segmentation/test_model.py | 52 +++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 9 deletions(-)

diff --git a/flash/core/classification.py b/flash/core/classification.py
index 346905b823..8270fec99b 100644
--- a/flash/core/classification.py
+++ b/flash/core/classification.py
@@ -61,7 +61,8 @@ def __init__(
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
         if getattr(self.hparams, "multi_label", False):
             return F.sigmoid(x)
-        return F.softmax(x, -1)
+        # we'll assume that the data always comes as `(B, C, ...)`
+        return F.softmax(x, dim=1)
diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py
index 8e33e3345a..62b39f9409 100644
--- a/flash/vision/segmentation/model.py
+++ b/flash/vision/segmentation/model.py
@@ -17,7 +17,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torchmetrics import Accuracy
+from torchmetrics import Accuracy, IoU
 
 from flash.core.classification import Classes, ClassificationTask
 from flash.core.registry import FlashRegistry
@@ -40,21 +40,23 @@ def __init__(
         head: Optional[Union[FunctionType, nn.Module]] = None,
         pretrained: bool = True,
         loss_fn: Optional[Callable] = None,
-        optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None,
         learning_rate: float = 1e-3,
         multi_label: bool = False,
         serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
-    ):
+    ) -> None:
 
         if metrics is None:
-            metrics = Accuracy(subset_accuracy=multi_label)
+            metrics = IoU(num_classes=num_classes)
 
-        # TODO: do we have any case for this ?
         if loss_fn is None:
-            # loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
             loss_fn = F.cross_entropy
 
+        # TODO: need to check for multi_label
+        if multi_label:
+            raise NotImplementedError("Multi-label not supported yet.")
+
         super().__init__(
             model=None,
             loss_fn=loss_fn,
             optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
             serializer=serializer or Classes(multi_label=multi_label),
         )
@@ -70,10 +72,10 @@ def __init__(
         if not backbone_kwargs:
             backbone_kwargs = {}
 
         # TODO: pretrained to True causes some issues
-        self.model = self.backbones.get(backbone)(pretrained=False, num_classes=num_classes, **backbone_kwargs)
+        self.backbone = self.backbones.get(backbone)(pretrained=False, num_classes=num_classes, **backbone_kwargs)
 
     def forward(self, x) -> torch.Tensor:
-        return self.model(x)['out']  # TODO: find a proper way to get 'out' from registry
+        return self.backbone(x)['out']  # TODO: find a proper way to get 'out' from registry
diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py
index 92c8532672..283d4e4a76 100644
--- a/tests/vision/segmentation/test_model.py
+++ b/tests/vision/segmentation/test_model.py
@@ -1,8 +1,28 @@
+from typing import Tuple
+
 import pytest
 import torch
 
+from flash import Trainer
 from flash.vision import SemanticSegmentation
 
+# ======== Mock functions ========
+
+
+class DummyDataset(torch.utils.data.Dataset):
+    size: Tuple[int, int] = (224, 224)
+    num_classes: int = 8
+
+    def __getitem__(self, index):
+        return torch.rand(3, *self.size), \
+            torch.randint(self.num_classes - 1, self.size)
+
+    def __len__(self) -> int:
+        return 10
+
+
+# ==============================
+
 
 def test_smoke():
     model = SemanticSegmentation(num_classes=1)
     assert model is not None
@@ -22,3 +42,35 @@ def test_forward(num_classes, img_shape):
 
     out = model(img)
     assert out.shape == (B, num_classes, H, W)
+
+
+@pytest.mark.parametrize(
+    "backbone",
+    [
+        "torchvision/fcn_resnet50",
+    ],
+)
+def test_init_train(tmpdir, backbone):
+    model = SemanticSegmentation(num_classes=10, backbone=backbone)
+    train_dl = torch.utils.data.DataLoader(DummyDataset())
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    trainer.finetune(model, train_dl, strategy="freeze_unfreeze")
+
+
+def test_non_existent_backbone():
+    with pytest.raises(KeyError):
+        SemanticSegmentation(2, "i am never going to implement this lol")
+
+
+def test_freeze():
+    model = SemanticSegmentation(2)
+    model.freeze()
+    for p in model.backbone.parameters():
+        assert p.requires_grad is False
+
+
+def test_unfreeze():
+    model = SemanticSegmentation(2)
+    model.unfreeze()
+    for p in model.backbone.parameters():
+        assert p.requires_grad is True
From 950252e985e54147016d481bc92dfeb8470e045c Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Mon, 26 Apr 2021 17:44:01 +0200
Subject: [PATCH 08/53] implement labels mapping

---
 flash/vision/segmentation/data.py      | 25 ++++++++--
 tests/vision/segmentation/test_data.py | 63 ++++++++++++++++++++++----
 2 files changed, 76 insertions(+), 12 deletions(-)

diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index 1a5d576b24..140ee2331a 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -41,7 +41,9 @@ def __init__(
         test_transform: Optional[Dict[str, Callable]] = None,
         predict_transform: Optional[Dict[str, Callable]] = None,
         image_size: Tuple[int, int] = (196, 196),
+        map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None,
     ) -> 'SemantincSegmentationPreprocess':
+        self._map_labels = map_labels
 
         # TODO: implement me
         '''train_transform, val_transform, test_transform, 
predict_transform = self._resolve_transforms( @@ -58,6 +60,16 @@ def __init__( super().__init__(train_transform, val_transform, test_transform, predict_transform) + def _apply_map_labels(self, img) -> torch.Tensor: + assert len(img.shape) == 3, img.shape + C, H, W = img.shape + outs = torch.empty(H, W, dtype=torch.int64) + for label, values in self._map_labels.items(): + vals = torch.tensor(values).view(3, 1, 1) + mask = (img == vals).all(-3) + outs[mask] = label + return outs + # TODO: is it a problem to load sample directly in tensor. What happens in to_tensor_tranform def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tensor]: if not isinstance(sample, tuple): @@ -70,6 +82,10 @@ def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tens img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + # TODO: decide at which point do we apply this + if self._map_labels is not None: + img_labels = self._apply_map_labels(img_labels) + return img, img_labels def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: @@ -127,7 +143,7 @@ def from_filepaths( preprocess: Optional[Preprocess] = None, # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED #**kwargs, - map_pix_to_labels: Dict[Tuple[int, int, int], int] = None, # TODO: implement me + map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, ) -> 'SemanticSegmentationData': # verify input data format @@ -145,6 +161,7 @@ def from_filepaths( val_transform, test_transform, predict_transform, + map_labels=map_labels, ) # instantiate the data module class @@ -157,7 +174,7 @@ def from_filepaths( num_workers=num_workers, data_fetcher=data_fetcher, preprocess=preprocess, - #seed=seed, - #val_split=val_split, - #**kwargs + #seed=seed, # THIS CRASHES + #val_split=val_split, # THIS CRASHES + #**kwargs # THIS CRASHES ) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index a828f67c04..e076e07f42 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import pytest @@ -14,17 +14,27 @@ def _rand_image(size: Tuple[int, int]): # usually labels come as rgb images -> need to map to labels -def _rand_labels(size: Tuple[int, int]): - data: np.ndarray = np.ones((*size, 3), dtype=np.uint8) - return Image.fromarray(data) - - -def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int]) -> Image.Image: +def _rand_labels(size: Tuple[int, int], map_labels: Dict[Tuple[int, int, int], int] = None): + data: np.ndarray = np.random.rand(*size, 3) + if map_labels is not None: + data_bin = (data.mean(-1) > 0.5) + for k, v in map_labels.items(): + mask = (data_bin == k) + data[mask] = v + return Image.fromarray(data.astype(np.uint8)) + + +def create_random_data( + image_files: List[str], + label_files: List[str], + size: Tuple[int, int], + map_labels: Optional[Dict[Tuple[int, int, int], int]] = None, +) -> Image.Image: for img_file in image_files: _rand_image(size).save(img_file) for label_file in label_files: - _rand_labels(size).save(label_file) + _rand_labels(size, map_labels).save(label_file) class TestSemanticSegmentationPreprocess: @@ -95,3 +105,40 @@ def test_from_filepaths(self, tmpdir): imgs, labels = data 
assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 3, 196, 196) + + def test_map_labels(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + train_images = [ + str(tmp_dir / "img1.png"), + str(tmp_dir / "img2.png"), + str(tmp_dir / "img3.png"), + ] + + train_labels = [ + str(tmp_dir / "labels_img1.png"), + str(tmp_dir / "labels_img2.png"), + str(tmp_dir / "labels_img3.png"), + ] + + map_labels: Dict[int, Tuple[int, int, int]] = { + 0: [0, 0, 0], + 1: [255, 255, 255], + } + + img_size: Tuple[int, int] = (196, 196) + create_random_data(train_images, train_labels, img_size, map_labels) + + # instantiate the data module + + dm = SemanticSegmentationData.from_filepaths( + train_filepaths=train_images, train_labels=train_labels, batch_size=2, num_workers=0, map_labels=map_labels + ) + assert dm is not None + assert dm.train_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data From 6a7524515d073af401cd7c586de3bfe30e024a71 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Mon, 26 Apr 2021 18:00:38 +0200 Subject: [PATCH 09/53] add map labels tests --- flash/vision/segmentation/data.py | 9 +++++---- tests/vision/segmentation/test_data.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 140ee2331a..a3ca364ff5 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -82,10 +82,6 @@ def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tens img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW - # TODO: decide at which point do we apply this - if self._map_labels is not None: - img_labels = self._apply_map_labels(img_labels) - return img, img_labels def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: @@ -93,6 +89,11 @@ def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tupl raise TypeError(f"Invalid type, expected `tuple`. 
Got: {sample}.") img, img_labels = sample img_out, img_labels_out = self.current_transform(img, img_labels) + + # TODO: decide at which point do we apply this + if self._map_labels is not None: + img_labels_out = self._apply_map_labels(img_labels_out) + return img_out, img_labels_out # TODO: the labels are not clear how to forward to the loss once are transform from this point diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index e076e07f42..500174fce8 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -3,6 +3,7 @@ import numpy as np import pytest +import torch from PIL import Image from flash.vision import SemanticSegmentationData, SemantincSegmentationPreprocess @@ -14,7 +15,7 @@ def _rand_image(size: Tuple[int, int]): # usually labels come as rgb images -> need to map to labels -def _rand_labels(size: Tuple[int, int], map_labels: Dict[Tuple[int, int, int], int] = None): +def _rand_labels(size: Tuple[int, int], map_labels: Dict[int, Tuple[int, int, int]] = None): data: np.ndarray = np.random.rand(*size, 3) if map_labels is not None: data_bin = (data.mean(-1) > 0.5) @@ -28,8 +29,8 @@ def create_random_data( image_files: List[str], label_files: List[str], size: Tuple[int, int], - map_labels: Optional[Dict[Tuple[int, int, int], int]] = None, -) -> Image.Image: + map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, +): for img_file in image_files: _rand_image(size).save(img_file) @@ -142,3 +143,8 @@ def test_map_labels(self, tmpdir): # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + assert labels.min().item() == 0 + assert labels.max().item() == 1 + assert labels.dtype == torch.int64 From 7a7f855024b0bea08f85f77caca2e2992569d44a Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Mon, 26 Apr 2021 18:15:23 +0200 Subject: [PATCH 10/53] from filepaths training test not crashing --- flash/vision/segmentation/data.py | 5 +++-- tests/vision/segmentation/test_data.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index a3ca364ff5..f46af0eb92 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -14,6 +14,7 @@ from flash.data.process import Preprocess +# container to apply augmentations at both image and mask reusing the same parameters class SegmentationSequential(nn.Sequential): def __init__(self, *args): @@ -97,7 +98,7 @@ def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tupl return img_out, img_labels_out # TODO: the labels are not clear how to forward to the loss once are transform from this point - def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + '''def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: if not isinstance(sample, list): raise TypeError(f"Invalid type, expected `tuple`. 
Got: {sample}.") img, img_labels = sample @@ -109,7 +110,7 @@ def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tupl # TODO: the labels are not clear how to forward to the loss once are transform from this point def per_batch_transform_on_device(self, sample: Any) -> Any: - pass + pass''' class SemanticSegmentationData(DataModule): diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 500174fce8..aae1756279 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -6,7 +6,8 @@ import torch from PIL import Image -from flash.vision import SemanticSegmentationData, SemantincSegmentationPreprocess +from flash import Trainer +from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemantincSegmentationPreprocess def _rand_image(size: Tuple[int, int]): @@ -135,7 +136,13 @@ def test_map_labels(self, tmpdir): # instantiate the data module dm = SemanticSegmentationData.from_filepaths( - train_filepaths=train_images, train_labels=train_labels, batch_size=2, num_workers=0, map_labels=map_labels + train_filepaths=train_images, + train_labels=train_labels, + val_filepaths=train_images, + val_labels=train_labels, + batch_size=2, + num_workers=0, + map_labels=map_labels ) assert dm is not None assert dm.train_dataloader() is not None @@ -148,3 +155,8 @@ def test_map_labels(self, tmpdir): assert labels.min().item() == 0 assert labels.max().item() == 1 assert labels.dtype == torch.int64 + + # now train with `fast_dev_run` + model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50") + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.finetune(model, dm, strategy="freeze_unfreeze") From def1ea07921e82a7f4b0ffdb6be53f4db132c269 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Tue, 27 Apr 2021 11:58:59 +0200 Subject: [PATCH 11/53] non working visualiser --- flash/data/data_module.py | 3 +- flash/vision/segmentation/data.py | 88 +++++++++++++++++++++++++- tests/vision/segmentation/test_data.py | 2 + 3 files changed, 90 insertions(+), 3 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index bcb3787268..b40f21d437 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -434,7 +434,8 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline - data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() + #data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() + data_fetcher: BaseDataFetcher = data_fetcher or DataModule.configure_data_fetcher() data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index f46af0eb92..aa8801a5c6 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -5,13 +5,21 @@ import torch import torch.nn as nn import torchvision +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset from flash.data.auto_dataset import AutoDataset +from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.process import Preprocess +from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + plt = 
None # container to apply augmentations at both image and mask reusing the same parameters @@ -116,6 +124,10 @@ def per_batch_transform_on_device(self, sample: Any) -> Any: class SemanticSegmentationData(DataModule): """Data module for semantic segmentation tasks.""" + # TODO: figure out if this needed + #def __init__(self, **kwargs) -> None: + # super().__init__(**kwargs) + @staticmethod def _check_valid_filepaths(filepaths: List[str]): if filepaths is not None and ( @@ -123,6 +135,10 @@ def _check_valid_filepaths(filepaths: List[str]): ): raise MisconfigurationException(f"`filepaths` must be of type List[str]. Got: {filepaths}.") + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + return _MatplotlibVisualization(*args, **kwargs) + @classmethod def from_filepaths( cls, @@ -144,8 +160,8 @@ def from_filepaths( data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED - #**kwargs, map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, + **kwargs, # TODO: remove and make explicit params ) -> 'SemanticSegmentationData': # verify input data format @@ -178,5 +194,73 @@ def from_filepaths( preprocess=preprocess, #seed=seed, # THIS CRASHES #val_split=val_split, # THIS CRASHES - #**kwargs # THIS CRASHES + **kwargs, # TODO: remove and make explicit params ) + + +class _MatplotlibVisualization(BaseVisualization): + """Process and show the image batch and its associated label using matplotlib. + """ + max_cols: int = 4 # maximum number of columns we accept + block_viz_window: bool = True # parameter to allow user to block visualisation windows + '''@staticmethod + def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: + out: np.ndarray + if isinstance(img, Image.Image): + out = np.array(img) + elif isinstance(img, torch.Tensor): + out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() + else: + raise TypeError(f"Unknown image type. Got: {type(img)}.") + return out''' + + def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): + # define the image grid + cols: int = min(num_samples, self.max_cols) + rows: int = num_samples // cols + + if not _MATPLOTLIB_AVAILABLE: + raise MisconfigurationException("You need matplotlib to visualise. Please, pip install matplotlib") + + # create figure and set title + fig, axs = plt.subplots(rows, cols) + fig.suptitle(title) + + for i, ax in enumerate(axs.ravel()): + # unpack images and labels + if isinstance(data, list): + _img, _label = data[i] + elif isinstance(data, tuple): + imgs, labels = data + _img, _label = imgs[i], labels[i] + else: + raise TypeError(f"Unknown data type. 
Got: {type(data)}.") + # convert images to numpy + _img: np.ndarray = self._to_numpy(_img) + if isinstance(_label, torch.Tensor): + _label = _label.squeeze().tolist() + # show image and set label as subplot title + ax.imshow(_img) + ax.set_title(str(_label)) + ax.axis('off') + plt.show(block=self.block_viz_window) + + def show_load_sample(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_load_sample" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_pre_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_to_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_post_tensor_transform" + self._show_images_and_labels(samples, len(samples), win_title) + + def show_per_batch_transform(self, batch: List[Any], running_stage): + win_title: str = f"{running_stage} - show_per_batch_transform" + self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index aae1756279..1e4429c48e 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -147,6 +147,8 @@ def test_map_labels(self, tmpdir): assert dm is not None assert dm.train_dataloader() is not None + dm.show_train_batch() + # check training data data = next(iter(dm.train_dataloader())) imgs, labels = data From ed17eb0aa01c22ef792319ebbde9c85bd8b10e1e Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Tue, 27 Apr 2021 13:21:06 +0200 Subject: [PATCH 12/53] fix visualiser --- flash/data/data_module.py | 4 +- flash/vision/segmentation/data.py | 79 ++++++++++++++------------ tests/vision/segmentation/test_data.py | 20 ++++++- 3 files changed, 63 insertions(+), 40 deletions(-) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index b40f21d437..a4ede7d12d 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -434,8 +434,8 @@ def from_load_data_inputs( else: data_pipeline = cls(**kwargs).data_pipeline - #data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() - data_fetcher: BaseDataFetcher = data_fetcher or DataModule.configure_data_fetcher() + data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() + #data_fetcher: BaseDataFetcher = data_fetcher or DataModule.configure_data_fetcher() data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index aa8801a5c6..abdc7b6296 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torchvision +from PIL import Image from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset @@ -69,7 +70,7 @@ def __init__( super().__init__(train_transform, val_transform, test_transform, predict_transform) - def _apply_map_labels(self, img) -> torch.Tensor: + def _image_to_labels(self, img) -> 
torch.Tensor: assert len(img.shape) == 3, img.shape C, H, W = img.shape outs = torch.empty(H, W, dtype=torch.int64) @@ -101,7 +102,7 @@ def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tupl # TODO: decide at which point do we apply this if self._map_labels is not None: - img_labels_out = self._apply_map_labels(img_labels_out) + img_labels_out = self._image_to_labels(img_labels_out) return img_out, img_labels_out @@ -139,16 +140,23 @@ def _check_valid_filepaths(filepaths: List[str]): def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return _MatplotlibVisualization(*args, **kwargs) + def set_map_labels(self, map_labels): + self.data_fetcher.map_labels = map_labels + + def set_block_viz_window(self, value: bool) -> None: + """Setter method to switch on/off matplotlib to pop up windows.""" + self.data_fetcher.block_viz_window = value + @classmethod def from_filepaths( cls, - train_filepaths: Optional[Sequence[str]] = None, - train_labels: Optional[Sequence[str]] = None, - val_filepaths: Optional[Sequence[str]] = None, - val_labels: Optional[Sequence[str]] = None, - test_filepaths: Optional[Sequence[str]] = None, - test_labels: Optional[Sequence[str]] = None, - predict_filepaths: Optional[Sequence[str]] = None, + train_filepaths: Optional[List[str]] = None, + train_labels: Optional[List[str]] = None, + val_filepaths: Optional[List[str]] = None, + val_labels: Optional[List[str]] = None, + test_filepaths: Optional[List[str]] = None, + test_labels: Optional[List[str]] = None, + predict_filepaths: Optional[List[str]] = None, train_transform: Union[str, Dict] = 'default', val_transform: Union[str, Dict] = 'default', test_transform: Union[str, Dict] = 'default', @@ -182,8 +190,8 @@ def from_filepaths( map_labels=map_labels, ) - # instantiate the data module class - return DataModule.from_load_data_inputs( + # this functions overrides `DataModule.from_load_data_inputs` + return cls.from_load_data_inputs( train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, @@ -203,7 +211,9 @@ class _MatplotlibVisualization(BaseVisualization): """ max_cols: int = 4 # maximum number of columns we accept block_viz_window: bool = True # parameter to allow user to block visualisation windows - '''@staticmethod + map_labels = {} + + @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: out: np.ndarray if isinstance(img, Image.Image): @@ -212,7 +222,17 @@ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() else: raise TypeError(f"Unknown image type. 
Got: {type(img)}.") - return out''' + return out + + def _labels_to_image(self, img_labels: torch.Tensor) -> torch.Tensor: + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in self.map_labels.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): # define the image grid @@ -229,19 +249,20 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels if isinstance(data, list): - _img, _label = data[i] + _img, _img_labels = data[i] elif isinstance(data, tuple): - imgs, labels = data - _img, _label = imgs[i], labels[i] + imgs, imgs_labels = data + _img, _img_labels = imgs[i], imgs_labels[i] else: raise TypeError(f"Unknown data type. Got: {type(data)}.") - # convert images to numpy - _img: np.ndarray = self._to_numpy(_img) - if isinstance(_label, torch.Tensor): - _label = _label.squeeze().tolist() - # show image and set label as subplot title - ax.imshow(_img) - ax.set_title(str(_label)) + # convert images and labels to numpy and stack horizontally + img_vis: np.ndarray = self._to_numpy(_img) + if len(_img_labels.shape) == 2: + _img_labels = self._labels_to_image(_img_labels) + img_labels_vis: np.ndarray = self._to_numpy(_img_labels) + img_vis = np.hstack((img_vis, img_labels_vis)) + # send to visualiser + ax.imshow(img_vis) ax.axis('off') plt.show(block=self.block_viz_window) @@ -249,18 +270,6 @@ def show_load_sample(self, samples: List[Any], running_stage: RunningStage): win_title: str = f"{running_stage} - show_load_sample" self._show_images_and_labels(samples, len(samples), win_title) - def show_pre_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_pre_tensor_transform" - self._show_images_and_labels(samples, len(samples), win_title) - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): win_title: str = f"{running_stage} - show_to_tensor_transform" self._show_images_and_labels(samples, len(samples), win_title) - - def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_post_tensor_transform" - self._show_images_and_labels(samples, len(samples), win_title) - - def show_per_batch_transform(self, batch: List[Any], running_stage): - win_title: str = f"{running_stage} - show_per_batch_transform" - self._show_images_and_labels(batch[0], batch[0][0].shape[0], win_title) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 1e4429c48e..789b5800a0 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -10,14 +10,21 @@ from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemantincSegmentationPreprocess +def build_checkboard(n, m, k=8): + x = np.zeros((n, m)) + x[k::k * 2, ::k] = 1 + x[::k * 2, k::k * 2] = 1 + return x + + def _rand_image(size: Tuple[int, int]): - data: np.ndarray = np.random.randint(0, 255, (*size, 3), dtype="uint8") + data = build_checkboard(*size).astype(np.uint8)[..., None].repeat(3, -1) return Image.fromarray(data) # usually labels come as rgb images -> need to map to labels def _rand_labels(size: Tuple[int, int], map_labels: Dict[int, Tuple[int, int, int]] = None): - 
data: np.ndarray = np.random.rand(*size, 3) + data: np.ndarray = np.random.rand(*size, 3) / .5 if map_labels is not None: data_bin = (data.mean(-1) > 0.5) for k, v in map_labels.items(): @@ -147,7 +154,14 @@ def test_map_labels(self, tmpdir): assert dm is not None assert dm.train_dataloader() is not None - dm.show_train_batch() + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + dm.set_map_labels(map_labels) + dm.show_train_batch("load_sample") + dm.show_train_batch("to_tensor_transform") # check training data data = next(iter(dm.train_dataloader())) From 3eb6417446876d519041c744abc200e79c305390 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Tue, 27 Apr 2021 16:27:55 +0200 Subject: [PATCH 13/53] training working --- flash/vision/segmentation/data.py | 23 ++-- .../finetuning/semantic_segmentation.py | 113 ++++++++++++++++++ 2 files changed, 128 insertions(+), 8 deletions(-) create mode 100644 flash_examples/finetuning/semantic_segmentation.py diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index abdc7b6296..5a60167c3e 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -24,6 +24,8 @@ # container to apply augmentations at both image and mask reusing the same parameters +# TODO: we have to figure out how to decide what transforms are applied to mask +# For instance, color transforms cannot be applied to masks class SegmentationSequential(nn.Sequential): def __init__(self, *args): @@ -31,11 +33,15 @@ def __init__(self, *args): def forward(self, img, mask): img_out = img.float() - mask_out = mask.float() + mask_out = mask[None].float() for aug in self.children(): img_out = aug(img_out) - mask_out = aug(mask_out, aug._params) - return img_out[0], mask_out[0] + # some transforms don't have params + if hasattr(aug, "_params"): + mask_out = aug(mask_out, aug._params) + else: + mask_out = aug(mask_out) + return img_out[0], mask_out[0, 0].long() def to_tensor(self, x): @@ -60,8 +66,8 @@ def __init__( train_transform, val_transform, test_transform, predict_transform )''' augs = SegmentationSequential( - # K.augmentation.RandomResizedCrop((128, 128)), - K.augmentation.RandomHorizontalFlip(), + K.geometry.Resize((128, 128), interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), ) train_transform = dict(to_tensor_transform=augs) val_transform = dict(to_tensor_transform=augs) @@ -91,6 +97,8 @@ def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tens # load images directly to torch tensors img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + # TODO: need to figure best api for this + img_labels = img_labels[0] # HxW return img, img_labels @@ -256,9 +264,8 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) else: raise TypeError(f"Unknown data type. 
Got: {type(data)}.") # convert images and labels to numpy and stack horizontally - img_vis: np.ndarray = self._to_numpy(_img) - if len(_img_labels.shape) == 2: - _img_labels = self._labels_to_image(_img_labels) + img_vis: np.ndarray = self._to_numpy(_img.byte()) + _img_labels = self._labels_to_image(_img_labels.byte()) img_labels_vis: np.ndarray = self._to_numpy(_img_labels) img_vis = np.hstack((img_vis, img_labels_vis)) # send to visualiser diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py new file mode 100644 index 0000000000..bce36e74c4 --- /dev/null +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -0,0 +1,113 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import List, Tuple + +import pandas as pd +import torch + +import flash +from flash.core.classification import Labels +from flash.core.finetuning import FreezeUnfreeze +from flash.data.utils import download_data +from flash.vision import SemanticSegmentation, SemanticSegmentationData + +# 1. Download the data +# This is a subset of the movie poster genre prediction data set from the paper +# “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo. +# Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ +#download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/") + +# 2. Load the data +# TODO: define labels maps +num_classes = 21 + +labels_map = {} +for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + +ROOT_DIR = '/home/edgar/data/archive/dataA/dataA' + + +def load_data(data: str, root: str = '') -> Tuple[List[str], List[str]]: + images: List[str] = [] + labels: List[str] = [] + + rgb_path = os.path.join(ROOT_DIR, "CameraRGB") + seg_path = os.path.join(ROOT_DIR, "CameraSeg") + + for fname in os.listdir(rgb_path): + images.append(os.path.join(rgb_path, fname)) + labels.append(os.path.join(seg_path, fname)) + + return images, labels + + +train_filepaths, train_labels = load_data('train') +val_filepaths, val_labels = load_data('val') +test_filepaths, test_labels = load_data('test') + +datamodule = SemanticSegmentationData.from_filepaths( + train_filepaths=train_filepaths, + train_labels=train_labels, + val_filepaths=val_filepaths, + val_labels=val_labels, + test_filepaths=test_filepaths, + test_labels=test_labels, + batch_size=16 + #preprocess=ImageClassificationPreprocess(), +) +datamodule.set_map_labels(labels_map) +'''datamodule.set_block_viz_window(False) +datamodule.show_train_batch("load_sample") +datamodule.set_block_viz_window(True)''' +datamodule.show_train_batch("load_sample") +datamodule.show_train_batch("to_tensor_transform") + +# 3. Build the model +model = SemanticSegmentation( + backbone="torchvision/fcn_resnet50", + num_classes=num_classes, +) + +# 4. Create the trainer. 
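Before the trainer is created, a note on what keeps image and mask geometrically aligned in the pipeline above: the `SegmentationSequential` container added earlier in this patch relies on kornia augmentations caching the parameters they sample in `_params`, and on a second call accepting that cache to replay the identical transform. A minimal sketch, assuming that replay behaviour of the kornia augmentation API:

    import kornia as K
    import torch

    flip = K.augmentation.RandomHorizontalFlip(p=1.0)
    img = torch.rand(1, 3, 4, 4)
    mask = torch.randint(0, 2, (1, 1, 4, 4)).float()

    img_out = flip(img)                  # samples parameters, caches them in flip._params
    mask_out = flip(mask, flip._params)  # replays the same parameters on the mask

The per-sample randomness is preserved, but within one sample the image and its mask always receive the same geometric transform.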
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) + +# 5. Train the model +trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) + +# 6a. Predict what's on a few images! + +# Serialize predictions as labels. +'''model.serializer = Labels(genres, multi_label=True) + +predictions = model.predict([ + "data/movie_posters/val/tt0361500.jpg", + "data/movie_posters/val/tt0361748.jpg", + "data/movie_posters/val/tt0362478.jpg", +]) + +print(predictions) + +datamodule = ImageClassificationData.from_folders( + predict_folder="data/movie_posters/predict/", + preprocess=model.preprocess, +) + +# 6b. Or generate predictions with a whole folder! +predictions = trainer.predict(model, datamodule=datamodule) +print(predictions)''' + +# 7. Save it! +trainer.save_checkpoint("semantic_segmentation_model.pt") From d529d9e411712b796b73803b7748248ebabe50ea Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Tue, 27 Apr 2021 21:10:14 +0200 Subject: [PATCH 14/53] training not crashing --- flash/vision/__init__.py | 2 +- flash/vision/segmentation/__init__.py | 2 +- flash/vision/segmentation/data.py | 37 ++++-- flash/vision/segmentation/model.py | 2 +- .../finetuning/semantic_segmentation.py | 108 +++++++++++------- 5 files changed, 98 insertions(+), 53 deletions(-) diff --git a/flash/vision/__init__.py b/flash/vision/__init__.py index 526b1f2892..346c84870a 100644 --- a/flash/vision/__init__.py +++ b/flash/vision/__init__.py @@ -2,4 +2,4 @@ from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier from flash.vision.detection import ObjectDetectionData, ObjectDetector from flash.vision.embedding import ImageEmbedder -from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemantincSegmentationPreprocess +from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess diff --git a/flash/vision/segmentation/__init__.py b/flash/vision/segmentation/__init__.py index 520694b412..08f9742e47 100644 --- a/flash/vision/segmentation/__init__.py +++ b/flash/vision/segmentation/__init__.py @@ -1,2 +1,2 @@ -from flash.vision.segmentation.data import SemanticSegmentationData, SemantincSegmentationPreprocess +from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess from flash.vision.segmentation.model import SemanticSegmentation diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 5a60167c3e..ade1da7343 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -31,6 +31,7 @@ class SegmentationSequential(nn.Sequential): def __init__(self, *args): super(SegmentationSequential, self).__init__(*args) + @torch.no_grad() def forward(self, img, mask): img_out = img.float() mask_out = mask[None].float() @@ -48,7 +49,7 @@ def to_tensor(self, x): return K.utils.image_to_tensor(np.array(x)) -class SemantincSegmentationPreprocess(Preprocess): +class SemanticSegmentationPreprocess(Preprocess): def __init__( self, @@ -58,21 +59,26 @@ def __init__( predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, - ) -> 'SemantincSegmentationPreprocess': + ) -> 'SemanticSegmentationPreprocess': self._map_labels = map_labels # TODO: implement me '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( 
train_transform, val_transform, test_transform, predict_transform )''' - augs = SegmentationSequential( - K.geometry.Resize((128, 128), interpolation='nearest'), + augs_train = SegmentationSequential( + K.geometry.Resize(image_size, interpolation='nearest'), K.augmentation.RandomHorizontalFlip(p=0.75), ) - train_transform = dict(to_tensor_transform=augs) + augs = SegmentationSequential( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.), + ) + augs_pred = nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), ) + train_transform = dict(to_tensor_transform=augs_train) val_transform = dict(to_tensor_transform=augs) test_transform = dict(to_tensor_transform=augs) - predict_transform = dict(to_tensor_transform=augs) + predict_transform = dict(to_tensor_transform=augs_pred) super().__init__(train_transform, val_transform, test_transform, predict_transform) @@ -87,9 +93,17 @@ def _image_to_labels(self, img) -> torch.Tensor: return outs # TODO: is it a problem to load sample directly in tensor. What happens in to_tensor_tranform - def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tensor]: - if not isinstance(sample, tuple): + def load_sample(self, sample: Union[str, Tuple[str, + str]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if not isinstance(sample, ( + str, + tuple, + )): raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") + + if isinstance(sample, str): # case for predict + return torchvision.io.read_image(sample) + # unpack data paths img_path: str = sample[0] img_labels_path: str = sample[1] @@ -103,6 +117,10 @@ def load_sample(self, sample: Tuple[str, str]) -> Tuple[torch.Tensor, torch.Tens return img, img_labels def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(sample, torch.Tensor): # case for predict + out = sample.float() / 255. # TODO: define predict transforms + return out + if not isinstance(sample, tuple): raise TypeError(f"Invalid type, expected `tuple`. 
Got: {sample}.") img, img_labels = sample @@ -190,11 +208,12 @@ def from_filepaths( SemanticSegmentationData._check_valid_filepaths(predict_filepaths) # create the preprocess objects - preprocess = preprocess or SemantincSegmentationPreprocess( + preprocess = preprocess or SemanticSegmentationPreprocess( train_transform, val_transform, test_transform, predict_transform, + image_size=image_size, map_labels=map_labels, ) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index 62b39f9409..c0db768e41 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -72,7 +72,7 @@ def __init__( backbone_kwargs = {} # TODO: pretrained to True causes some issues - self.backbone = self.backbones.get(backbone)(pretrained=False, num_classes=num_classes, **backbone_kwargs) + self.backbone = self.backbones.get(backbone)(pretrained=True, num_classes=num_classes, **backbone_kwargs) def forward(self, x) -> torch.Tensor: return self.backbone(x)['out'] # TODO: find a proper way to get 'out' from registry diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index bce36e74c4..598dd22e57 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -21,7 +21,7 @@ from flash.core.classification import Labels from flash.core.finetuning import FreezeUnfreeze from flash.data.utils import download_data -from flash.vision import SemanticSegmentation, SemanticSegmentationData +from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess # 1. Download the data # This is a subset of the movie poster genre prediction data set from the paper @@ -29,49 +29,45 @@ # Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ #download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/") +# download from: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge + # 2. 
Load the data -# TODO: define labels maps -num_classes = 21 +num_classes: int = 21 labels_map = {} for i in range(num_classes): labels_map[i] = torch.randint(0, 255, (3, )) -ROOT_DIR = '/home/edgar/data/archive/dataA/dataA' +root_dir = '/home/edgar/data/archive' +datasets = ['dataA', 'dataB', 'dataC', 'dataD', 'dataE'] -def load_data(data: str, root: str = '') -> Tuple[List[str], List[str]]: +def load_data(data_root: str, datasets: List[str]) -> Tuple[List[str], List[str]]: images: List[str] = [] labels: List[str] = [] - - rgb_path = os.path.join(ROOT_DIR, "CameraRGB") - seg_path = os.path.join(ROOT_DIR, "CameraSeg") - - for fname in os.listdir(rgb_path): - images.append(os.path.join(rgb_path, fname)) - labels.append(os.path.join(seg_path, fname)) - + for data in datasets: + data_dir = os.path.join(root_dir, data, data) + rgb_path = os.path.join(data_dir, "CameraRGB") + seg_path = os.path.join(data_dir, "CameraSeg") + for fname in os.listdir(rgb_path): + images.append(os.path.join(rgb_path, fname)) + labels.append(os.path.join(seg_path, fname)) return images, labels -train_filepaths, train_labels = load_data('train') -val_filepaths, val_labels = load_data('val') -test_filepaths, test_labels = load_data('test') +train_filepaths, train_labels = load_data(root_dir, datasets[:3]) +val_filepaths, val_labels = load_data(root_dir, datasets[3:4]) +predict_filepaths, predict_labels = load_data(root_dir, datasets[4:5]) datamodule = SemanticSegmentationData.from_filepaths( train_filepaths=train_filepaths, train_labels=train_labels, val_filepaths=val_filepaths, val_labels=val_labels, - test_filepaths=test_filepaths, - test_labels=test_labels, - batch_size=16 - #preprocess=ImageClassificationPreprocess(), + batch_size=4, + image_size=(300, 400), # (600, 800) ) datamodule.set_map_labels(labels_map) -'''datamodule.set_block_viz_window(False) -datamodule.show_train_batch("load_sample") -datamodule.set_block_viz_window(True)''' datamodule.show_train_batch("load_sample") datamodule.show_train_batch("to_tensor_transform") @@ -82,32 +78,62 @@ def load_data(data: str, root: str = '') -> Tuple[List[str], List[str]]: ) # 4. Create the trainer. -trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1) +#trainer = flash.Trainer(max_epochs=5, limit_train_batches=1, limit_val_batches=1) +trainer = flash.Trainer( + max_epochs=20, + gpus=1, + #precision=16, # why slower ? :) +) # 5. Train the model -trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) +trainer.finetune(model, datamodule=datamodule, strategy='freeze') +# TODO: getting error: BrokenPipeError: [Errno 32] Broken pipe -# 6a. Predict what's on a few images! +# 6. Predict what's on a few images! -# Serialize predictions as labels. 
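The serializer defined just below reduces the CxHxW network output to per-pixel class ids with an argmax over the class dimension. A worked micro-example of that step (the shapes here are arbitrary):

    import torch

    logits = torch.randn(21, 300, 400)     # CxHxW class scores for one image
    labels = torch.argmax(logits, dim=-3)  # HxW tensor of class ids in [0, 20]
    assert labels.shape == (300, 400)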
-'''model.serializer = Labels(genres, multi_label=True) +import kornia as K +import matplotlib.pyplot as plt -predictions = model.predict([ - "data/movie_posters/val/tt0361500.jpg", - "data/movie_posters/val/tt0361748.jpg", - "data/movie_posters/val/tt0362478.jpg", -]) +from flash.data.process import ProcessState, Serializer -print(predictions) -datamodule = ImageClassificationData.from_folders( - predict_folder="data/movie_posters/predict/", - preprocess=model.preprocess, -) +class SegmentationLabels(Serializer): + + def __init__(self, map_labels, visualise): + super().__init__() + self.map_labels = map_labels + self.visualise = visualise + + def _labels_to_image(self, img_labels: torch.Tensor) -> torch.Tensor: + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in self.map_labels.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out + + def serialize(self, sample: torch.Tensor) -> torch.Tensor: + assert len(sample.shape) == 3, sample.shape + labels = torch.argmax(sample, dim=-3) # HxW + if self.visualise: + labels_vis = self._labels_to_image(labels) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + return labels + + +model.serializer = SegmentationLabels(labels_map, visualise=True) + +predictions = model.predict([ + predict_filepaths[0], + predict_filepaths[1], + predict_filepaths[2], +], datamodule.data_pipeline) -# 6b. Or generate predictions with a whole folder! -predictions = trainer.predict(model, datamodule=datamodule) -print(predictions)''' +#print(predictions) # 7. Save it! trainer.save_checkpoint("semantic_segmentation_model.pt") From 13095e6d25c18044b19de58f7e07a3762372a915 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 11:49:29 +0200 Subject: [PATCH 15/53] cleanup example and move serializer to core --- flash/core/classification.py | 57 +++++++++- flash/vision/segmentation/data.py | 21 +--- .../finetuning/semantic_segmentation.py | 103 +++++------------- 3 files changed, 89 insertions(+), 92 deletions(-) diff --git a/flash/core/classification.py b/flash/core/classification.py index 8270fec99b..53a6f39fa9 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -12,8 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union +# for visualisation +import kornia as K +import matplotlib.pyplot as plt import torch import torch.nn.functional as F import torchmetrics @@ -164,3 +167,55 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning ) return classes + + +class SegmentationLabels(Serializer): + + def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualise: bool = False): + """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification + per pixel in the image for semantic segmentation tasks. + + Args: + labels_map: A dictionary that map the labels ids to pixel intensities. + visualise: Wether to visualise the image labels. 
+ """ + super().__init__() + self.labels_map = labels_map + self.visualise = visualise + + @staticmethod + def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: + """Function that given an image with labels ids and their pixels intrensity mapping, + creates a RGB representation for visualisation purposes. + """ + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in labels_map.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out + + @staticmethod + def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + + def serialize(self, sample: torch.Tensor) -> torch.Tensor: + assert len(sample.shape) == 3, sample.shape + labels = torch.argmax(sample, dim=-3) # HxW + if self.visualise: + if self.labels_map is None: + # create random colors map + num_classes = sample.shape[-3] + labels_map = self.create_random_labels_map(num_classes) + else: + labels_map = self.labels_map + labels_vis = self.labels_to_image(labels, labels_map) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + return labels diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index ade1da7343..82fbe63e43 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -10,6 +10,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset +from flash.core.classification import SegmentationLabels from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher @@ -166,8 +167,8 @@ def _check_valid_filepaths(filepaths: List[str]): def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return _MatplotlibVisualization(*args, **kwargs) - def set_map_labels(self, map_labels): - self.data_fetcher.map_labels = map_labels + def set_map_labels(self, labels_map: Dict[int, Tuple[int, int, int]]): + self.data_fetcher.labels_map = labels_map def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" @@ -193,7 +194,7 @@ def from_filepaths( #seed: Optional[int] = 42, # SEED NEVER USED data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, - # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED + val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, **kwargs, # TODO: remove and make explicit params ) -> 'SemanticSegmentationData': @@ -238,7 +239,7 @@ class _MatplotlibVisualization(BaseVisualization): """ max_cols: int = 4 # maximum number of columns we accept block_viz_window: bool = True # parameter to allow user to block visualisation windows - map_labels = {} + labels_map: Dict[int, Tuple[int, int, int]] = {} @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: @@ -251,16 +252,6 @@ def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: raise TypeError(f"Unknown image type. 
Got: {type(img)}.") return out - def _labels_to_image(self, img_labels: torch.Tensor) -> torch.Tensor: - assert len(img_labels.shape) == 2, img_labels.shape - H, W = img_labels.shape - out = torch.empty(3, H, W, dtype=torch.uint8) - for label_id, label_val in self.map_labels.items(): - mask = (img_labels == label_id) - for i in range(3): - out[i].masked_fill_(mask, label_val[i]) - return out - def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str): # define the image grid cols: int = min(num_samples, self.max_cols) @@ -284,7 +275,7 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) raise TypeError(f"Unknown data type. Got: {type(data)}.") # convert images and labels to numpy and stack horizontally img_vis: np.ndarray = self._to_numpy(_img.byte()) - _img_labels = self._labels_to_image(_img_labels.byte()) + _img_labels = SegmentationLabels.labels_to_image(_img_labels.byte(), self.labels_map) img_labels_vis: np.ndarray = self._to_numpy(_img_labels) img_vis = np.hstack((img_vis, img_labels_vis)) # send to visualiser diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 598dd22e57..e1ebdfae54 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -18,55 +18,45 @@ import torch import flash -from flash.core.classification import Labels -from flash.core.finetuning import FreezeUnfreeze +from flash.core.classification import SegmentationLabels from flash.data.utils import download_data from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess # 1. Download the data -# This is a subset of the movie poster genre prediction data set from the paper -# “Movie Genre Classification based on Poster Images with Deep Neural Networks” by Wei-Ta Chu and Hung-Jui Guo. -# Please consider citing their paper if you use it. More here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/ -#download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip", "data/") - -# download from: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge - -# 2. Load the data -num_classes: int = 21 - -labels_map = {} -for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) +# This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. +# The data was generated as part of the Lyft Udacity Challenge. 
+# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge +download_data( + 'https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip', "data/" +) -root_dir = '/home/edgar/data/archive' -datasets = ['dataA', 'dataB', 'dataC', 'dataD', 'dataE'] +# 2.1 Load the data -def load_data(data_root: str, datasets: List[str]) -> Tuple[List[str], List[str]]: +def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: images: List[str] = [] labels: List[str] = [] - for data in datasets: - data_dir = os.path.join(root_dir, data, data) - rgb_path = os.path.join(data_dir, "CameraRGB") - seg_path = os.path.join(data_dir, "CameraSeg") - for fname in os.listdir(rgb_path): - images.append(os.path.join(rgb_path, fname)) - labels.append(os.path.join(seg_path, fname)) + rgb_path = os.path.join(data_root, "CameraRGB") + seg_path = os.path.join(data_root, "CameraSeg") + for fname in os.listdir(rgb_path): + images.append(os.path.join(rgb_path, fname)) + labels.append(os.path.join(seg_path, fname)) return images, labels -train_filepaths, train_labels = load_data(root_dir, datasets[:3]) -val_filepaths, val_labels = load_data(root_dir, datasets[3:4]) -predict_filepaths, predict_labels = load_data(root_dir, datasets[4:5]) +images_filepaths, labels_filepaths = load_data() +# create the data module datamodule = SemanticSegmentationData.from_filepaths( - train_filepaths=train_filepaths, - train_labels=train_labels, - val_filepaths=val_filepaths, - val_labels=val_labels, + train_filepaths=images_filepaths, + train_labels=labels_filepaths, batch_size=4, + val_split=0.3, image_size=(300, 400), # (600, 800) ) + +# 2.2 Visualise the samples +labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_map_labels(labels_map) datamodule.show_train_batch("load_sample") datamodule.show_train_batch("to_tensor_transform") @@ -74,66 +64,27 @@ def load_data(data_root: str, datasets: List[str]) -> Tuple[List[str], List[str] # 3. Build the model model = SemanticSegmentation( backbone="torchvision/fcn_resnet50", - num_classes=num_classes, + num_classes=21, ) # 4. Create the trainer. -#trainer = flash.Trainer(max_epochs=5, limit_train_batches=1, limit_val_batches=1) trainer = flash.Trainer( max_epochs=20, gpus=1, - #precision=16, # why slower ? :) + # precision=16, # why slower ? :) ) # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy='freeze') -# TODO: getting error: BrokenPipeError: [Errno 32] Broken pipe # 6. Predict what's on a few images! 
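The serializer assigned just below renders its visualisation by painting every pixel of a class with that class's RGB colour from `labels_map`. A tiny sketch of that mapping, using a hypothetical two-class map:

    import torch

    labels_map = {0: (0, 0, 0), 1: (255, 0, 0)}   # hypothetical class id -> RGB colour
    labels = torch.tensor([[0, 1], [1, 0]])       # HxW class ids
    rgb = torch.zeros(3, 2, 2, dtype=torch.uint8)
    for class_id, colour in labels_map.items():
        mask = labels == class_id
        for c in range(3):
            rgb[c][mask] = colour[c]              # paint each channel where the class occurs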
- -import kornia as K -import matplotlib.pyplot as plt - -from flash.data.process import ProcessState, Serializer - - -class SegmentationLabels(Serializer): - - def __init__(self, map_labels, visualise): - super().__init__() - self.map_labels = map_labels - self.visualise = visualise - - def _labels_to_image(self, img_labels: torch.Tensor) -> torch.Tensor: - assert len(img_labels.shape) == 2, img_labels.shape - H, W = img_labels.shape - out = torch.empty(3, H, W, dtype=torch.uint8) - for label_id, label_val in self.map_labels.items(): - mask = (img_labels == label_id) - for i in range(3): - out[i].masked_fill_(mask, label_val[i]) - return out - - def serialize(self, sample: torch.Tensor) -> torch.Tensor: - assert len(sample.shape) == 3, sample.shape - labels = torch.argmax(sample, dim=-3) # HxW - if self.visualise: - labels_vis = self._labels_to_image(labels) - labels_vis = K.utils.tensor_to_image(labels_vis) - plt.imshow(labels_vis) - plt.show() - return labels - - model.serializer = SegmentationLabels(labels_map, visualise=True) predictions = model.predict([ - predict_filepaths[0], - predict_filepaths[1], - predict_filepaths[2], + 'data/CameraRGB/F61-1.png', + 'data/CameraRGB/F62-1.png', + 'data/CameraRGB/F63-1.png', ], datamodule.data_pipeline) -#print(predictions) - # 7. Save it! trainer.save_checkpoint("semantic_segmentation_model.pt") From 2f9ede54bd94b2edf9f07dc427c1de156303471e Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 12:23:53 +0200 Subject: [PATCH 16/53] cleanup model code, tests and docs --- flash/vision/segmentation/model.py | 38 +++++++++++++++---- .../finetuning/semantic_segmentation.py | 2 +- tests/vision/segmentation/test_data.py | 2 +- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index c0db768e41..527f2764bd 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -19,15 +19,41 @@ from torch.nn import functional as F from torchmetrics import Accuracy, IoU -from flash.core.classification import Classes, ClassificationTask +from flash.core.classification import ClassificationTask, SegmentationLabels from flash.core.registry import FlashRegistry from flash.data.process import Preprocess, Serializer +from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + import torchvision SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") class SemanticSegmentation(ClassificationTask): """Task that performs semantic segmentation on images. + + Use a built in backbone + + Example:: + + from flash.vision import SemanticSegmentation + + segmentation = SemanticSegmentation( + num_classes=21, backbone="torchvision/fcn_resnet50" + ) + + Args: + num_classes: Number of classes to classify. + backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"torchvision/fcn_resnet50"``. + backbone_kwargs: Additional arguments for the backbone configuration. + pretrained: Use a pretrained backbone, defaults to ``False``. + loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. + optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`. + metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`. + learning_rate: Learning rate to use for training, defaults to ``1e-3``. + multi_label: Whether the targets are multi-label or not. 
+ serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs. """ backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES @@ -37,8 +63,7 @@ def __init__( num_classes: int, backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50", backbone_kwargs: Optional[Dict] = None, - head: Optional[Union[FunctionType, nn.Module]] = None, - pretrained: bool = True, + pretrained: bool = False, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None, @@ -63,7 +88,7 @@ def __init__( optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, - serializer=serializer or Classes(multi_label=multi_label), + serializer=serializer or SegmentationLabels(), ) self.save_hyperparameters() @@ -72,14 +97,13 @@ def __init__( backbone_kwargs = {} # TODO: pretrained to True causes some issues - self.backbone = self.backbones.get(backbone)(pretrained=True, num_classes=num_classes, **backbone_kwargs) + self.backbone = self.backbones.get(backbone)(pretrained=pretrained, num_classes=num_classes, **backbone_kwargs) def forward(self, x) -> torch.Tensor: return self.backbone(x)['out'] # TODO: find a proper way to get 'out' from registry @SemanticSegmentation.backbones(name="torchvision/fcn_resnet50") -def fn(pretrained: bool, num_classes: int): - import torchvision +def fn(pretrained: bool, num_classes: int) -> nn.Module: model: nn.Module = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained, num_classes=num_classes) return model diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index e1ebdfae54..493c741de8 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -51,7 +51,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: train_filepaths=images_filepaths, train_labels=labels_filepaths, batch_size=4, - val_split=0.3, + val_split=0.3, # TODO: this needs to be implemented image_size=(300, 400), # (600, 800) ) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 789b5800a0..2ff739243f 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -7,7 +7,7 @@ from PIL import Image from flash import Trainer -from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemantincSegmentationPreprocess +from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess def build_checkboard(n, m, k=8): From dc9b2b88c4e27cf2ffa23ce46513d70276603647 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 13:27:02 +0200 Subject: [PATCH 17/53] move transforms apart --- flash/data/data_module.py | 1 - flash/vision/segmentation/data.py | 165 ++++++++++-------- flash/vision/segmentation/transforms.py | 58 ++++++ .../finetuning/semantic_segmentation.py | 2 +- 4 files changed, 149 insertions(+), 77 deletions(-) create mode 100644 flash/vision/segmentation/transforms.py diff --git a/flash/data/data_module.py b/flash/data/data_module.py index a4ede7d12d..bcb3787268 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -435,7 +435,6 @@ def from_load_data_inputs( data_pipeline = cls(**kwargs).data_pipeline data_fetcher: BaseDataFetcher = data_fetcher or cls.configure_data_fetcher() - #data_fetcher: BaseDataFetcher = data_fetcher or 
DataModule.configure_data_fetcher() data_fetcher.attach_to_preprocess(data_pipeline._preprocess_pipeline) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 82fbe63e43..34a9cc947f 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -1,3 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union import kornia as K @@ -18,38 +31,14 @@ from flash.data.process import Preprocess from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE +from . import transforms as T + if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: plt = None -# container to apply augmentations at both image and mask reusing the same parameters -# TODO: we have to figure out how to decide what transforms are applied to mask -# For instance, color transforms cannot be applied to masks -class SegmentationSequential(nn.Sequential): - - def __init__(self, *args): - super(SegmentationSequential, self).__init__(*args) - - @torch.no_grad() - def forward(self, img, mask): - img_out = img.float() - mask_out = mask[None].float() - for aug in self.children(): - img_out = aug(img_out) - # some transforms don't have params - if hasattr(aug, "_params"): - mask_out = aug(mask_out, aug._params) - else: - mask_out = aug(mask_out) - return img_out[0], mask_out[0, 0].long() - - -def to_tensor(self, x): - return K.utils.image_to_tensor(np.array(x)) - - class SemanticSegmentationPreprocess(Preprocess): def __init__( @@ -59,48 +48,56 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), - map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, - ) -> 'SemanticSegmentationPreprocess': - self._map_labels = map_labels - - # TODO: implement me - '''train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( - train_transform, val_transform, test_transform, predict_transform - )''' - augs_train = SegmentationSequential( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.75), - ) - augs = SegmentationSequential( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.), + ) -> None: + """Preprocess pipeline for semantic segmentation tasks. + + Args: + train_transform: Dictionary with the set of transforms to apply during training. + val_transform: Dictionary with the set of transforms to apply during validation. + test_transform: Dictionary with the set of transforms to apply during testing. + predict_transform: Dictionary with the set of transforms to apply during prediction. + image_size: A tuple with the expected output image size. 
+ """ + train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( + train_transform, val_transform, test_transform, predict_transform, image_size ) - augs_pred = nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), ) - train_transform = dict(to_tensor_transform=augs_train) - val_transform = dict(to_tensor_transform=augs) - test_transform = dict(to_tensor_transform=augs) - predict_transform = dict(to_tensor_transform=augs_pred) - super().__init__(train_transform, val_transform, test_transform, predict_transform) - def _image_to_labels(self, img) -> torch.Tensor: - assert len(img.shape) == 3, img.shape - C, H, W = img.shape - outs = torch.empty(H, W, dtype=torch.int64) - for label, values in self._map_labels.items(): - vals = torch.tensor(values).view(3, 1, 1) - mask = (img == vals).all(-3) - outs[mask] = label - return outs - - # TODO: is it a problem to load sample directly in tensor. What happens in to_tensor_tranform + def _resolve_transforms( + self, + train_transform: Optional[Union[str, Dict]] = 'default', + val_transform: Optional[Union[str, Dict]] = 'default', + test_transform: Optional[Union[str, Dict]] = 'default', + predict_transform: Optional[Union[str, Dict]] = 'default', + image_size: Tuple[int, int] = (196, 196), + ): + + if not train_transform or train_transform == 'default': + train_transform = T.default_train_transforms(image_size) + + if not val_transform or val_transform == 'default': + val_transform = T.default_val_transforms(image_size) + + if not test_transform or test_transform == 'default': + test_transform = T.default_val_transforms(image_size) + + if not predict_transform or predict_transform == 'default': + predict_transform = T.default_val_transforms(image_size) + + return ( + train_transform, + val_transform, + test_transform, + predict_transform, + ) + def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if not isinstance(sample, ( str, tuple, )): - raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") + raise TypeError(f"Invalid type, expected `str` or `tuple`. Got: {sample}.") if isinstance(sample, str): # case for predict return torchvision.io.read_image(sample) @@ -117,7 +114,7 @@ def load_sample(self, sample: Union[str, Tuple[str, return img, img_labels - def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(sample, torch.Tensor): # case for predict out = sample.float() / 255. 
# TODO: define predict transforms return out @@ -127,10 +124,6 @@ def to_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tupl img, img_labels = sample img_out, img_labels_out = self.current_transform(img, img_labels) - # TODO: decide at which point do we apply this - if self._map_labels is not None: - img_labels_out = self._image_to_labels(img_labels_out) - return img_out, img_labels_out # TODO: the labels are not clear how to forward to the loss once are transform from this point @@ -152,10 +145,6 @@ def per_batch_transform_on_device(self, sample: Any) -> Any: class SemanticSegmentationData(DataModule): """Data module for semantic segmentation tasks.""" - # TODO: figure out if this needed - #def __init__(self, **kwargs) -> None: - # super().__init__(**kwargs) - @staticmethod def _check_valid_filepaths(filepaths: List[str]): if filepaths is not None and ( @@ -177,8 +166,8 @@ def set_block_viz_window(self, value: bool) -> None: @classmethod def from_filepaths( cls, - train_filepaths: Optional[List[str]] = None, - train_labels: Optional[List[str]] = None, + train_filepaths: List[str], + train_labels: List[str], val_filepaths: Optional[List[str]] = None, val_labels: Optional[List[str]] = None, test_filepaths: Optional[List[str]] = None, @@ -195,9 +184,36 @@ def from_filepaths( data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED - map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None, **kwargs, # TODO: remove and make explicit params ) -> 'SemanticSegmentationData': + """Creates a Semantic SegmentationData object from a given list of paths to images and labels. + + Args: + train_filepaths: List of file paths for training images. + train_labels: List of file path for the training image labels. + val_filepaths: List of file paths for validation images. + val_labels: List of file path for the validation image labels. + test_filepaths: List of file paths for testing images. + test_labels: List of file path for the testing image labels. + predict_filepaths: List of file paths for predicting images. + train_transform: Image and mask transform to use for the train set. + val_transform: Image and mask transform to use for the validation set. + test_transform: Image and mask transform to use for the test set. + predict_transform: Image transform to use for the predict set. + image_size: A tuple with the expected output image size. + batch_size: The batch size to use for parallel loading. + num_workers: The number of workers to use for parallelized loading. + Defaults to ``None`` which equals the number of available CPU threads. + data_fetcher: An optional data fetcher object instance. + preprocess: An optional `SemanticSegmentationPreprocess` object instance. + val_split: Float number to control the percentage of train/validation samples + from the ``train_filepaths`` and ``train_labels`` list. + + + Returns: + SemanticSegmentationData: The constructed data module. 
+ + """ # verify input data format SemanticSegmentationData._check_valid_filepaths(train_filepaths) @@ -215,7 +231,6 @@ def from_filepaths( test_transform, predict_transform, image_size=image_size, - map_labels=map_labels, ) # this functions overrides `DataModule.from_load_data_inputs` @@ -223,7 +238,7 @@ def from_filepaths( train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - predict_load_data_input=predict_filepaths, + # predict_load_data_input=predict_filepaths, # TODO: is it really used ? batch_size=batch_size, num_workers=num_workers, data_fetcher=data_fetcher, @@ -287,6 +302,6 @@ def show_load_sample(self, samples: List[Any], running_stage: RunningStage): win_title: str = f"{running_stage} - show_load_sample" self._show_images_and_labels(samples, len(samples), win_title) - def show_to_tensor_transform(self, samples: List[Any], running_stage: RunningStage): - win_title: str = f"{running_stage} - show_to_tensor_transform" + def show_post_tensor_transform(self, samples: List[Any], running_stage: RunningStage): + win_title: str = f"{running_stage} - show_post_tensor_transform" self._show_images_and_labels(samples, len(samples), win_title) diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py new file mode 100644 index 0000000000..f4bce6690f --- /dev/null +++ b/flash/vision/segmentation/transforms.py @@ -0,0 +1,58 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
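Before the body of the new transforms module, a note on how its defaults are consumed: the preprocess resolves a missing or 'default' entry to the dictionaries defined below, while a user-supplied dictionary with a `post_tensor_transform` entry passes through unchanged. A sketch of overriding the training transform, assuming the module layout this patch introduces (the resize shape and vertical flip are arbitrary illustrative choices):

    import kornia as K

    from flash.vision.segmentation.data import SemanticSegmentationPreprocess
    from flash.vision.segmentation.transforms import SegmentationSequential

    custom_train = {
        'post_tensor_transform': SegmentationSequential(
            K.geometry.Resize((256, 256), interpolation='nearest'),
            K.augmentation.RandomVerticalFlip(p=0.5),
        ),
    }
    preprocess = SemanticSegmentationPreprocess(train_transform=custom_train)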
+from typing import Callable, Dict, Tuple + +import kornia as K +import torch +import torch.nn as nn + + +# container to apply augmentations at both image and mask reusing the same parameters +# TODO: we have to figure out how to decide what transforms are applied to mask +# For instance, color transforms cannot be applied to masks +class SegmentationSequential(nn.Sequential): + + def __init__(self, *args): + super(SegmentationSequential, self).__init__(*args) + + @torch.no_grad() + def forward(self, img, mask): + img_out = img.float() + mask_out = mask[None].float() + for aug in self.children(): + img_out = aug(img_out) + # some transforms don't have params + if hasattr(aug, "_params"): + mask_out = aug(mask_out, aug._params) + else: + mask_out = aug(mask_out) + return img_out[0], mask_out[0, 0].long() + + +def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + return { + "post_tensor_transform": SegmentationSequential( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), + ), + } + + +def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: + return { + "post_tensor_transform": SegmentationSequential( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.), # #TODO: bug somewhere with shapes + ), + } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 493c741de8..7ed7fc13de 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -59,7 +59,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_map_labels(labels_map) datamodule.show_train_batch("load_sample") -datamodule.show_train_batch("to_tensor_transform") +datamodule.show_train_batch("post_tensor_transform") # 3. Build the model model = SemanticSegmentation( From e767a5371b00b16e376d2241019f2183b0369e58 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 16:13:34 +0200 Subject: [PATCH 18/53] implement ApplytransformsToKey augmentations --- flash/vision/segmentation/data.py | 53 +++++++++++----------- flash/vision/segmentation/transforms.py | 60 +++++++++++++++---------- tests/vision/segmentation/test_data.py | 19 ++++---- 3 files changed, 73 insertions(+), 59 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 34a9cc947f..2e93c18b74 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union -import kornia as K import numpy as np import torch import torch.nn as nn @@ -29,7 +28,7 @@ from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.process import Preprocess -from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE +from flash.utils.imports import _MATPLOTLIB_AVAILABLE from . 
import transforms as T @@ -65,12 +64,12 @@ def __init__( def _resolve_transforms( self, - train_transform: Optional[Union[str, Dict]] = 'default', - val_transform: Optional[Union[str, Dict]] = 'default', - test_transform: Optional[Union[str, Dict]] = 'default', - predict_transform: Optional[Union[str, Dict]] = 'default', + train_transform: Optional[Union[str, Dict]] = None, + val_transform: Optional[Union[str, Dict]] = None, + test_transform: Optional[Union[str, Dict]] = None, + predict_transform: Optional[Union[str, Dict]] = None, image_size: Tuple[int, int] = (196, 196), - ): + ) -> Tuple[Dict[str, Callable], ...]: if not train_transform or train_transform == 'default': train_transform = T.default_train_transforms(image_size) @@ -109,22 +108,24 @@ def load_sample(self, sample: Union[str, Tuple[str, # load images directly to torch tensors img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW - # TODO: need to figure best api for this - img_labels = img_labels[0] # HxW - return img, img_labels + return {'images': img, 'masks': img_labels} def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(sample, torch.Tensor): # case for predict out = sample.float() / 255. # TODO: define predict transforms return out - if not isinstance(sample, tuple): - raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") - img, img_labels = sample - img_out, img_labels_out = self.current_transform(img, img_labels) + if not isinstance(sample, dict): + raise TypeError(f"Invalid type, expected `dict`. Got: {sample}.") + + # arrange data as floating point and batch before the augmentations + sample['images'] = sample['images'][None].float().contiguous() # 1xCxHxW + sample['masks'] = sample['masks'][None, :1].float().contiguous() # 1x1xHxW + + out: Dict[str, torch.Tensor] = self.current_transform(sample) - return img_out, img_labels_out + return out['images'][0], out['masks'][0, 0].long() # TODO: the labels are not clear how to forward to the loss once are transform from this point '''def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: @@ -156,7 +157,7 @@ def _check_valid_filepaths(filepaths: List[str]): def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return _MatplotlibVisualization(*args, **kwargs) - def set_map_labels(self, labels_map: Dict[int, Tuple[int, int, int]]): + def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): self.data_fetcher.labels_map = labels_map def set_block_viz_window(self, value: bool) -> None: @@ -281,18 +282,20 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) for i, ax in enumerate(axs.ravel()): # unpack images and labels - if isinstance(data, list): - _img, _img_labels = data[i] - elif isinstance(data, tuple): - imgs, imgs_labels = data - _img, _img_labels = imgs[i], imgs_labels[i] + sample = data[i] + if isinstance(sample, dict): + image = sample['images'] + label = sample['masks'] + elif isinstance(sample, tuple): + image = sample[0] + label = sample[1] else: raise TypeError(f"Unknown data type. 
Got: {type(data)}.") # convert images and labels to numpy and stack horizontally - img_vis: np.ndarray = self._to_numpy(_img.byte()) - _img_labels = SegmentationLabels.labels_to_image(_img_labels.byte(), self.labels_map) - img_labels_vis: np.ndarray = self._to_numpy(_img_labels) - img_vis = np.hstack((img_vis, img_labels_vis)) + image_vis: np.ndarray = self._to_numpy(image.byte()) + label_tmp: torch.Tensor = SegmentationLabels.labels_to_image(label.squeeze().byte(), self.labels_map) + label_vis: np.ndarray = self._to_numpy(label_tmp) + img_vis = np.hstack((image_vis, label_vis)) # send to visualiser ax.imshow(img_vis) ax.axis('off') diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index f4bce6690f..b02778dd91 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -11,48 +11,60 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, List, Tuple import kornia as K import torch import torch.nn as nn -# container to apply augmentations at both image and mask reusing the same parameters -# TODO: we have to figure out how to decide what transforms are applied to mask -# For instance, color transforms cannot be applied to masks -class SegmentationSequential(nn.Sequential): +class ApplyTransformToKeys(nn.Sequential): - def __init__(self, *args): - super(SegmentationSequential, self).__init__(*args) + def __init__(self, keys: List[str], *args): + super().__init__(*args) + self.keys = keys @torch.no_grad() - def forward(self, img, mask): - img_out = img.float() - mask_out = mask[None].float() - for aug in self.children(): - img_out = aug(img_out) - # some transforms don't have params - if hasattr(aug, "_params"): - mask_out = aug(mask_out, aug._params) - else: - mask_out = aug(mask_out) - return img_out[0], mask_out[0, 0].long() + def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + for seq in self.children(): + for aug in seq: + for key in self.keys: + # check wether the transform was applied + # and apply transform + if hasattr(aug, "_params") and bool(aug._params): + params = aug._params + x[key] = aug(x[key], params) + else: # case for non random transforms + x[key] = aug(x[key]) + return x def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": SegmentationSequential( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.75), + "post_tensor_transform": nn.Sequential( + ApplyTransformToKeys(['images', 'masks'], + nn.Sequential( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), + )), + ApplyTransformToKeys( + ['images'], + nn.Sequential( + K.enhance.Normalize(0., 255.), + K.augmentation.ColorJitter(0.4, p=0.5), + # NOTE: uncomment to visualise better + # K.enhance.Denormalize(0., 255.), + ) + ), ), } def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": SegmentationSequential( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.), # #TODO: bug somewhere with shapes + "post_tensor_transform": nn.Sequential( + ApplyTransformToKeys(['images', 'masks'], + nn.Sequential(K.geometry.Resize(image_size, 
interpolation='nearest'), )),
+        ApplyTransformToKeys(['images'], nn.Sequential(K.enhance.Normalize(0., 255.), )),
         ),
     }
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
index 2ff739243f..0c3100426e 100644
--- a/tests/vision/segmentation/test_data.py
+++ b/tests/vision/segmentation/test_data.py
@@ -50,7 +50,7 @@ class TestSemanticSegmentationPreprocess:
 
     @pytest.mark.xfail(reason="parameters are marked as optional but it returns a Misconfiguration error.")
     def test_smoke(self):
-        prep = SemantincSegmentationPreprocess()
+        prep = SemanticSegmentationPreprocess()
 
         assert prep is not None
 
@@ -101,19 +101,19 @@ def test_from_filepaths(self, tmpdir):
         data = next(iter(dm.train_dataloader()))
         imgs, labels = data
         assert imgs.shape == (2, 3, 196, 196)
-        assert labels.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
 
         # check val data
         data = next(iter(dm.val_dataloader()))
         imgs, labels = data
         assert imgs.shape == (2, 3, 196, 196)
-        assert labels.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
 
         # check val data
         data = next(iter(dm.val_dataloader()))
         imgs, labels = data
         assert imgs.shape == (2, 3, 196, 196)
-        assert labels.shape == (2, 3, 196, 196)
+        assert labels.shape == (2, 196, 196)
 
     def test_map_labels(self, tmpdir):
         tmp_dir = Path(tmpdir)
 
@@ -132,13 +132,13 @@ def test_map_labels(self, tmpdir):
             str(tmp_dir / "labels_img3.png"),
         ]
 
-        map_labels: Dict[int, Tuple[int, int, int]] = {
+        labels_map: Dict[int, Tuple[int, int, int]] = {
             0: [0, 0, 0],
             1: [255, 255, 255],
         }
 
         img_size: Tuple[int, int] = (196, 196)
-        create_random_data(train_images, train_labels, img_size, map_labels)
+        create_random_data(train_images, train_labels, img_size, labels_map)
 
         # instantiate the data module
 
@@ -149,7 +149,6 @@ def test_map_labels(self, tmpdir):
             val_labels=train_labels,
             batch_size=2,
             num_workers=0,
-            map_labels=map_labels
         )
         assert dm is not None
         assert dm.train_dataloader() is not None
 
@@ -159,7 +158,7 @@
         dm.set_block_viz_window(False)
         assert dm.data_fetcher.block_viz_window is False
 
-        dm.set_map_labels(map_labels)
+        dm.set_labels_map(labels_map)
 
         dm.show_train_batch("load_sample")
         dm.show_train_batch("to_tensor_transform")
 
@@ -169,10 +168,10 @@
         assert imgs.shape == (2, 3, 196, 196)
         assert labels.shape == (2, 196, 196)
         assert labels.min().item() == 0
-        assert labels.max().item() == 1
+        assert labels.max().item() == 255.
assert labels.dtype == torch.int64 # now train with `fast_dev_run` model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - trainer.finetune(model, dm, strategy="freeze_unfreeze") + #trainer.finetune(model, dm, strategy="freeze_unfreeze") From f268b62194558a430da8d8d577b529c38f85c4ce Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 16:33:06 +0200 Subject: [PATCH 19/53] relative path --- flash/vision/segmentation/data.py | 3 +-- flash_examples/finetuning/semantic_segmentation.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 2e93c18b74..9092b9b449 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -22,6 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset +import flash.vision.segmentation.transforms as T from flash.core.classification import SegmentationLabels from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz @@ -30,8 +31,6 @@ from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE -from . import transforms as T - if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 7ed7fc13de..3ceace062e 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -57,7 +57,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 2.2 Visualise the samples labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) -datamodule.set_map_labels(labels_map) +datamodule.set_labels_map(labels_map) datamodule.show_train_batch("load_sample") datamodule.show_train_batch("post_tensor_transform") @@ -69,7 +69,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=20, + max_epochs=5, gpus=1, # precision=16, # why slower ? 
:) ) From 99b99f00748adedbd521fbb885b14746caf3f1bf Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 18:23:02 +0200 Subject: [PATCH 20/53] fix load from pretrained and add resnet 101 --- flash/vision/segmentation/model.py | 14 +++++++++++--- flash_examples/finetuning/semantic_segmentation.py | 9 +++++---- tests/vision/segmentation/test_model.py | 1 + 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index 527f2764bd..2785eb65a3 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -63,7 +63,7 @@ def __init__( num_classes: int, backbone: Union[str, Tuple[nn.Module, int]] = "torchvision/fcn_resnet50", backbone_kwargs: Optional[Dict] = None, - pretrained: bool = False, + pretrained: bool = True, loss_fn: Optional[Callable] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None, @@ -104,6 +104,14 @@ def forward(self, x) -> torch.Tensor: @SemanticSegmentation.backbones(name="torchvision/fcn_resnet50") -def fn(pretrained: bool, num_classes: int) -> nn.Module: - model: nn.Module = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained, num_classes=num_classes) +def load_torchvision_fcn_resnet50(pretrained: bool, num_classes: int) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + return model + + +@SemanticSegmentation.backbones(name="torchvision/fcn_resnet101") +def load_torchvision_fcn_resnet101(pretrained: bool, num_classes: int) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) return model diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 3ceace062e..fb91be7039 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -53,13 +53,14 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: batch_size=4, val_split=0.3, # TODO: this needs to be implemented image_size=(300, 400), # (600, 800) + num_workers=0, ) # 2.2 Visualise the samples labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_labels_map(labels_map) -datamodule.show_train_batch("load_sample") -datamodule.show_train_batch("post_tensor_transform") +#datamodule.show_train_batch("load_sample") +#datamodule.show_train_batch("post_tensor_transform") # 3. Build the model model = SemanticSegmentation( @@ -69,9 +70,9 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=5, + max_epochs=1, gpus=1, - # precision=16, # why slower ? :) + #precision=16, # why slower ? :) ) # 5. 
Train the model diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index 283d4e4a76..1b4d2d45e9 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -48,6 +48,7 @@ def test_forward(num_classes, img_shape): "backbone", [ "torchvision/fcn_resnet50", + "torchvision/fcn_resnet101", ], ) def test_init_train(tmpdir, backbone): From d1a91fda0ded743694a30299361fbeaa4de5ac94 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 20:41:57 +0200 Subject: [PATCH 21/53] create segmentation keys enum --- flash/core/classification.py | 55 ------------ flash/vision/segmentation/data.py | 48 +++++----- flash/vision/segmentation/model.py | 3 +- flash/vision/segmentation/serialization.py | 87 +++++++++++++++++++ flash/vision/segmentation/transforms.py | 16 ++-- .../finetuning/semantic_segmentation.py | 8 +- 6 files changed, 124 insertions(+), 93 deletions(-) create mode 100644 flash/vision/segmentation/serialization.py diff --git a/flash/core/classification.py b/flash/core/classification.py index 53a6f39fa9..f5fcfaa138 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -14,9 +14,6 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union -# for visualisation -import kornia as K -import matplotlib.pyplot as plt import torch import torch.nn.functional as F import torchmetrics @@ -167,55 +164,3 @@ def serialize(self, sample: Any) -> Union[int, List[int], str, List[str]]: "No ClassificationState was found, this serializer will act as a Classes serializer.", UserWarning ) return classes - - -class SegmentationLabels(Serializer): - - def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualise: bool = False): - """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification - per pixel in the image for semantic segmentation tasks. - - Args: - labels_map: A dictionary that map the labels ids to pixel intensities. - visualise: Wether to visualise the image labels. - """ - super().__init__() - self.labels_map = labels_map - self.visualise = visualise - - @staticmethod - def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: - """Function that given an image with labels ids and their pixels intrensity mapping, - creates a RGB representation for visualisation purposes. 
- """ - assert len(img_labels.shape) == 2, img_labels.shape - H, W = img_labels.shape - out = torch.empty(3, H, W, dtype=torch.uint8) - for label_id, label_val in labels_map.items(): - mask = (img_labels == label_id) - for i in range(3): - out[i].masked_fill_(mask, label_val[i]) - return out - - @staticmethod - def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: - labels_map: Dict[int, Tuple[int, int, int]] = {} - for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) - return labels_map - - def serialize(self, sample: torch.Tensor) -> torch.Tensor: - assert len(sample.shape) == 3, sample.shape - labels = torch.argmax(sample, dim=-3) # HxW - if self.visualise: - if self.labels_map is None: - # create random colors map - num_classes = sample.shape[-3] - labels_map = self.create_random_labels_map(num_classes) - else: - labels_map = self.labels_map - labels_vis = self.labels_to_image(labels, labels_map) - labels_vis = K.utils.tensor_to_image(labels_vis) - plt.imshow(labels_vis) - plt.show() - return labels diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 9092b9b449..8584e537ee 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -23,13 +23,13 @@ from torch.utils.data import Dataset import flash.vision.segmentation.transforms as T -from flash.core.classification import SegmentationLabels from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE +from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -89,8 +89,7 @@ def _resolve_transforms( predict_transform, ) - def load_sample(self, sample: Union[str, Tuple[str, - str]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: if not isinstance(sample, ( str, tuple, @@ -108,9 +107,13 @@ def load_sample(self, sample: Union[str, Tuple[str, img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW - return {'images': img, 'masks': img_labels} + return {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: img_labels} - def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # TODO: this routine should be moved to `per_batch_transform` once we have a way to + # forward the labels to the loss function:. + def post_tensor_transform( + self, sample: Union[torch.Tensor, Dict[str, torch.Tensor]] + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if isinstance(sample, torch.Tensor): # case for predict out = sample.float() / 255. # TODO: define predict transforms return out @@ -119,27 +122,15 @@ def post_tensor_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tu raise TypeError(f"Invalid type, expected `dict`. 
Got: {sample}.") # arrange data as floating point and batch before the augmentations - sample['images'] = sample['images'][None].float().contiguous() # 1xCxHxW - sample['masks'] = sample['masks'][None, :1].float().contiguous() # 1x1xHxW + sample[SegmentationKeys.IMAGES] = sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW + sample[SegmentationKeys.MASKS] = sample[SegmentationKeys.MASKS][None, :1].float().contiguous() # 1x1xHxW out: Dict[str, torch.Tensor] = self.current_transform(sample) - return out['images'][0], out['masks'][0, 0].long() + return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() - # TODO: the labels are not clear how to forward to the loss once are transform from this point - '''def per_batch_transform(self, sample: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - if not isinstance(sample, list): - raise TypeError(f"Invalid type, expected `tuple`. Got: {sample}.") - img, img_labels = sample - # THIS IS CRASHING - # out1 = self.current_transform(img) # images - # out2 = self.current_transform(img_labels) # labels - # return out1, out2 - return img, img_labels - - # TODO: the labels are not clear how to forward to the loss once are transform from this point - def per_batch_transform_on_device(self, sample: Any) -> Any: - pass''' + # TODO: implement `per_batch_transform` and `per_batch_transform_on_device` + ## class SemanticSegmentationData(DataModule): @@ -154,7 +145,7 @@ def _check_valid_filepaths(filepaths: List[str]): @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - return _MatplotlibVisualization(*args, **kwargs) + return SegmentationMatplotlibVisualization(*args, **kwargs) def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): self.data_fetcher.labels_map = labels_map @@ -249,12 +240,15 @@ def from_filepaths( ) -class _MatplotlibVisualization(BaseVisualization): +class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. 
""" - max_cols: int = 4 # maximum number of columns we accept - block_viz_window: bool = True # parameter to allow user to block visualisation windows - labels_map: Dict[int, Tuple[int, int, int]] = {} + + def __init__(self): + super().__init__(self) + self.max_cols: int = 4 # maximum number of columns we accept + self.block_viz_window: bool = True # parameter to allow user to block visualisation windows + self.labels_map: Dict[int, Tuple[int, int, int]] = {} @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index 2785eb65a3..acc5a8d609 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -19,10 +19,11 @@ from torch.nn import functional as F from torchmetrics import Accuracy, IoU -from flash.core.classification import ClassificationTask, SegmentationLabels +from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry from flash.data.process import Preprocess, Serializer from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.vision.segmentation.serialization import SegmentationLabels if _TORCHVISION_AVAILABLE: import torchvision diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py new file mode 100644 index 0000000000..907ed937c0 --- /dev/null +++ b/flash/vision/segmentation/serialization.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum +from typing import Dict, Optional, Tuple + +import torch + +from flash.data.process import ProcessState, Serializer +from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE + +if _MATPLOTLIB_AVAILABLE: + import matplotlib.pyplot as plt +else: + plt = None + +if _KORNIA_AVAILABLE: + import kornia as K +else: + K = None + + +class SegmentationKeys(Enum): + IMAGES = 'images' + MASKS = 'masks' + + +class SegmentationLabels(Serializer): + + def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): + """A :class:`.Serializer` which converts the model outputs to the label of the argmax classification + per pixel in the image for semantic segmentation tasks. + + Args: + labels_map: A dictionary that map the labels ids to pixel intensities. + visualise: Wether to visualise the image labels. + """ + super().__init__() + self.labels_map = labels_map + self.visualize = visualize + + @staticmethod + def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, int, int]]) -> torch.Tensor: + """Function that given an image with labels ids and their pixels intrensity mapping, + creates a RGB representation for visualisation purposes. 
+ """ + assert len(img_labels.shape) == 2, img_labels.shape + H, W = img_labels.shape + out = torch.empty(3, H, W, dtype=torch.uint8) + for label_id, label_val in labels_map.items(): + mask = (img_labels == label_id) + for i in range(3): + out[i].masked_fill_(mask, label_val[i]) + return out + + @staticmethod + def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + + def serialize(self, sample: torch.Tensor) -> torch.Tensor: + assert len(sample.shape) == 3, sample.shape + labels = torch.argmax(sample, dim=-3) # HxW + if self.visualize: + if self.labels_map is None: + # create random colors map + num_classes = sample.shape[-3] + labels_map = self.create_random_labels_map(num_classes) + else: + labels_map = self.labels_map + labels_vis = self.labels_to_image(labels, labels_map) + labels_vis = K.utils.tensor_to_image(labels_vis) + plt.imshow(labels_vis) + plt.show() + return labels diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index b02778dd91..9c56898dcf 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -17,6 +17,8 @@ import torch import torch.nn as nn +from flash.vision.segmentation.serialization import SegmentationKeys + class ApplyTransformToKeys(nn.Sequential): @@ -29,8 +31,10 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: for seq in self.children(): for aug in seq: for key in self.keys: - # check wether the transform was applied - # and apply transform + # kornia caches the random parameters in `_params` after every + # forward call in the augmentation object. We check whether the + # random parameters have been generated so that we can apply the + # same parameters to images and masks. 
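+                    # note: `_params` is only populated once a random transform has
+                    # sampled its parameters; deterministic ops such as `Resize`
+                    # carry no `_params` and fall through to the `else` branch.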
if hasattr(aug, "_params") and bool(aug._params): params = aug._params x[key] = aug(x[key], params) @@ -42,13 +46,13 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys(['images', 'masks'], + ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS], nn.Sequential( K.geometry.Resize(image_size, interpolation='nearest'), K.augmentation.RandomHorizontalFlip(p=0.75), )), ApplyTransformToKeys( - ['images'], + [SegmentationKeys.IMAGES], nn.Sequential( K.enhance.Normalize(0., 255.), K.augmentation.ColorJitter(0.4, p=0.5), @@ -63,8 +67,8 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys(['images', 'masks'], + ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS], nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), )), - ApplyTransformToKeys(['images'], nn.Sequential(K.enhance.Normalize(0., 255.), )), + ApplyTransformToKeys([SegmentationKeys.IMAGES], nn.Sequential(K.enhance.Normalize(0., 255.), )), ), } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index fb91be7039..00e1378cf4 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -18,9 +18,9 @@ import torch import flash -from flash.core.classification import SegmentationLabels from flash.data.utils import download_data from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess +from flash.vision.segmentation.serialization import SegmentationLabels # 1. Download the data # This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. @@ -50,7 +50,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: datamodule = SemanticSegmentationData.from_filepaths( train_filepaths=images_filepaths, train_labels=labels_filepaths, - batch_size=4, + batch_size=2, val_split=0.3, # TODO: this needs to be implemented image_size=(300, 400), # (600, 800) num_workers=0, @@ -70,7 +70,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=1, + max_epochs=10, gpus=1, #precision=16, # why slower ? :) ) @@ -79,7 +79,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: trainer.finetune(model, datamodule=datamodule, strategy='freeze') # 6. Predict what's on a few images! 
-model.serializer = SegmentationLabels(labels_map, visualise=True) +model.serializer = SegmentationLabels(labels_map, visualize=True) predictions = model.predict([ 'data/CameraRGB/F61-1.png', From 7343887ac780a2cd16a8d0b0f1d1d05bbe7d4b1d Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Wed, 28 Apr 2021 21:40:59 +0200 Subject: [PATCH 22/53] sync with master and fix val_split --- flash/vision/segmentation/data.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 8584e537ee..ebc6e67cb6 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -56,11 +56,29 @@ def __init__( predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. """ + self.image_size = image_size + train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( train_transform, val_transform, test_transform, predict_transform, image_size ) super().__init__(train_transform, val_transform, test_transform, predict_transform) + # TODO: this is kind of boilerplate, let's simplify + def get_state_dict(self) -> Dict[str, Any]: + return { + "train_transform": self._train_transform, + "val_transform": self._val_transform, + "test_transform": self._test_transform, + "predict_transform": self._predict_transform, + "image_size": self.image_size + } + + # TODO: this is kind of boilerplate, let's simplify + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + return cls(**state_dict) + + # TODO: this is kind of boilerplate, let's simplify def _resolve_transforms( self, train_transform: Optional[Union[str, Dict]] = None, @@ -171,7 +189,6 @@ def from_filepaths( image_size: Tuple[int, int] = (196, 196), batch_size: int = 64, num_workers: Optional[int] = None, - #seed: Optional[int] = 42, # SEED NEVER USED data_fetcher: BaseDataFetcher = None, preprocess: Optional[Preprocess] = None, val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED @@ -229,13 +246,12 @@ def from_filepaths( train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - # predict_load_data_input=predict_filepaths, # TODO: is it really used ? + predict_load_data_input=predict_filepaths, # TODO: is it really used ? batch_size=batch_size, num_workers=num_workers, data_fetcher=data_fetcher, preprocess=preprocess, - #seed=seed, # THIS CRASHES - #val_split=val_split, # THIS CRASHES + val_split=val_split, **kwargs, # TODO: remove and make explicit params ) From febe7f02d50b3af0be288ad59f4112701aabee29 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 11:18:29 +0200 Subject: [PATCH 23/53] move apart segmentation backbones --- flash/vision/segmentation/backbones.py | 36 ++++++++++++++++++++++++ flash/vision/segmentation/model.py | 38 +++++++++++--------------- 2 files changed, 52 insertions(+), 22 deletions(-) create mode 100644 flash/vision/segmentation/backbones.py diff --git a/flash/vision/segmentation/backbones.py b/flash/vision/segmentation/backbones.py new file mode 100644 index 0000000000..b8f4519bae --- /dev/null +++ b/flash/vision/segmentation/backbones.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch.nn as nn + +from flash.core.registry import FlashRegistry +from flash.utils.imports import _TORCHVISION_AVAILABLE + +if _TORCHVISION_AVAILABLE: + import torchvision + +SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") + + +@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50") +def load_torchvision_fcn_resnet50(pretrained: bool, num_classes: int) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + return model + + +@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101") +def load_torchvision_fcn_resnet101(pretrained: bool, num_classes: int) -> nn.Module: + model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained) + model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) + return model diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index acc5a8d609..c3ba135479 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -22,14 +22,9 @@ from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry from flash.data.process import Preprocess, Serializer -from flash.utils.imports import _TIMM_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.vision.segmentation.serialization import SegmentationLabels -if _TORCHVISION_AVAILABLE: - import torchvision - -SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones") - class SemanticSegmentation(ClassificationTask): """Task that performs semantic segmentation on images. @@ -46,7 +41,8 @@ class SemanticSegmentation(ClassificationTask): Args: num_classes: Number of classes to classify. - backbone: A string or (model, num_features) tuple to use to compute image features, defaults to ``"torchvision/fcn_resnet50"``. + backbone: A string or (model, num_features) tuple to use to compute image features, + defaults to ``"torchvision/fcn_resnet50"``. backbone_kwargs: Additional arguments for the backbone configuration. pretrained: Use a pretrained backbone, defaults to ``False``. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. 
@@ -101,18 +97,16 @@ def __init__(
         self.backbone = self.backbones.get(backbone)(pretrained=pretrained, num_classes=num_classes, **backbone_kwargs)
 
     def forward(self, x) -> torch.Tensor:
-        return self.backbone(x)['out']  # TODO: find a proper way to get 'out' from registry
-
-
-@SemanticSegmentation.backbones(name="torchvision/fcn_resnet50")
-def load_torchvision_fcn_resnet50(pretrained: bool, num_classes: int) -> nn.Module:
-    model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
-    model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
-    return model
-
-
-@SemanticSegmentation.backbones(name="torchvision/fcn_resnet101")
-def load_torchvision_fcn_resnet101(pretrained: bool, num_classes: int) -> nn.Module:
-    model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
-    model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
-    return model
+        # run the input image through the backbone
+        res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x)
+
+        # some frameworks, like torchvision, return a dict.
+        # In particular, torchvision segmentation models return the output logits
+        # in the key `out`.
+        out: torch.Tensor
+        if isinstance(res, dict):
+            out = res['out']
+        else:
+            raise NotImplementedError(f"Unsupported output type: {type(res)}")
+
+        return out

From 248145bea9d720b4cb9c1b45a339c4d7f76200d8 Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Thu, 29 Apr 2021 13:41:08 +0200
Subject: [PATCH 24/53] fix tests

---
 flash/vision/segmentation/data.py      |  4 ++--
 tests/vision/segmentation/test_data.py | 29 ++++++++++----------------
 2 files changed, 13 insertions(+), 20 deletions(-)

diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index ebc6e67cb6..f24aff61eb 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -293,8 +293,8 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
 
         # unpack images and labels
         sample = data[i]
         if isinstance(sample, dict):
-            image = sample['images']
-            label = sample['masks']
+            image = sample[SegmentationKeys.IMAGES]
+            label = sample[SegmentationKeys.MASKS]
         elif isinstance(sample, tuple):
             image = sample[0]
             label = sample[1]
diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py
index 0c3100426e..432b972ac4 100644
--- a/tests/vision/segmentation/test_data.py
+++ b/tests/vision/segmentation/test_data.py
@@ -23,27 +23,18 @@ def _rand_image(size: Tuple[int, int]):
 
 
 # usually labels come as rgb images -> need to map to labels
-def _rand_labels(size: Tuple[int, int], map_labels: Dict[int, Tuple[int, int, int]] = None):
-    data: np.ndarray = np.random.rand(*size, 3) / .5
-    if map_labels is not None:
-        data_bin = (data.mean(-1) > 0.5)
-        for k, v in map_labels.items():
-            mask = (data_bin == k)
-            data[mask] = v
+def _rand_labels(size: Tuple[int, int], num_classes: int):
+    data: np.ndarray = np.random.randint(0, num_classes, (*size, 1))
+    data = data.repeat(3, axis=-1)
     return Image.fromarray(data.astype(np.uint8))
 
 
-def create_random_data(
-    image_files: List[str],
-    label_files: List[str],
-    size: Tuple[int, int],
-    map_labels: Optional[Dict[int, Tuple[int, int, int]]] = None,
-):
+def create_random_data(image_files: List[str], label_files: List[str], size: Tuple[int, int], num_classes: int):
     for img_file in image_files:
         _rand_image(size).save(img_file)
 
     for label_file in label_files:
-        _rand_labels(size, map_labels).save(label_file)
+        _rand_labels(size,
num_classes).save(label_file) class TestSemanticSegmentationPreprocess: @@ -77,8 +68,9 @@ def test_from_filepaths(self, tmpdir): str(tmp_dir / "labels_img3.png"), ] + num_classes: int = 2 img_size: Tuple[int, int] = (196, 196) - create_random_data(train_images, train_labels, img_size) + create_random_data(train_images, train_labels, img_size, num_classes) # instantiate the data module @@ -137,8 +129,9 @@ def test_map_labels(self, tmpdir): 1: [255, 255, 255], } + num_classes: int = len(labels_map.keys()) img_size: Tuple[int, int] = (196, 196) - create_random_data(train_images, train_labels, img_size, labels_map) + create_random_data(train_images, train_labels, img_size, num_classes) # instantiate the data module @@ -168,10 +161,10 @@ def test_map_labels(self, tmpdir): assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) assert labels.min().item() == 0 - assert labels.max().item() == 255. + assert labels.max().item() == 1 assert labels.dtype == torch.int64 # now train with `fast_dev_run` model = SemanticSegmentation(num_classes=2, backbone="torchvision/fcn_resnet50") trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - #trainer.finetune(model, dm, strategy="freeze_unfreeze") + trainer.finetune(model, dm, strategy="freeze_unfreeze") From 6d635dbe35308c5c7881c7f002a0af50b95438a4 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 13:42:13 +0200 Subject: [PATCH 25/53] fix tests --- .../finetuning/semantic_segmentation.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 00e1378cf4..8703ff2f2b 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -50,17 +50,17 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: datamodule = SemanticSegmentationData.from_filepaths( train_filepaths=images_filepaths, train_labels=labels_filepaths, - batch_size=2, - val_split=0.3, # TODO: this needs to be implemented + batch_size=4, + val_split=0.3, image_size=(300, 400), # (600, 800) - num_workers=0, + num_workers=4, ) # 2.2 Visualise the samples labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_labels_map(labels_map) -#datamodule.show_train_batch("load_sample") -#datamodule.show_train_batch("post_tensor_transform") +datamodule.show_train_batch("load_sample") +datamodule.show_train_batch("post_tensor_transform") # 3. Build the model model = SemanticSegmentation( @@ -70,9 +70,8 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=10, - gpus=1, - #precision=16, # why slower ? :) + max_epochs=2, + gpus=0, ) # 5. 
Train the model From ca970342393c01bc51b78beb7eb7da819ad4da8a Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 13:48:19 +0200 Subject: [PATCH 26/53] fix tests --- flash/vision/segmentation/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index 9c56898dcf..66491afe2b 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -50,7 +50,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] nn.Sequential( K.geometry.Resize(image_size, interpolation='nearest'), K.augmentation.RandomHorizontalFlip(p=0.75), - )), + )), # noqa: E126 ApplyTransformToKeys( [SegmentationKeys.IMAGES], nn.Sequential( From da8325105eeccb47f6676a2a97d821fd29eac44b Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 20:48:44 +0200 Subject: [PATCH 27/53] fix memory leak issues --- flash/data/callback.py | 5 ++++- flash/data/data_module.py | 7 ++++++- flash/vision/segmentation/data.py | 19 +++++++++++++------ .../finetuning/semantic_segmentation.py | 5 ++--- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/flash/data/callback.py b/flash/data/callback.py index a479a6e59e..9eedbad0c7 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -153,8 +153,11 @@ def __init__(self, enabled: bool = False): self._preprocess = None self.reset() + self._func_names = [] # store the functions to cache the data + def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: - if self.enabled: + # we check om `func_names` to prevent from memory overflow issues + if self.enabled and (fn_name in self._func_names): store = self.batches[_STAGES_PREFIX[running_stage]] store.setdefault(fn_name, []) store[fn_name].append(data) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index bcb3787268..6967a712b0 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -168,6 +168,11 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re if isinstance(func_names, str): func_names = [func_names] + # the data fetcher tries to store all the data batches per stage, + # we store the function names that we want to visualize/cache to + # avoid duplicity in the visualization and reduce the memory footprint. 
+ self.data_fetcher._func_names = func_names + iter_dataloader = getattr(self, iter_name) with self.data_fetcher.enable(): try: @@ -178,7 +183,7 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re data_fetcher: BaseVisualization = self.data_fetcher data_fetcher._show(stage, func_names) if reset: - self.viz.batches[stage] = {} + self.data_fetcher.batches[stage] = {} def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None: """This function is used to visualize a batch from the train dataloader.""" diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index f24aff61eb..499c26af13 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -107,7 +107,8 @@ def _resolve_transforms( predict_transform, ) - def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + def load_sample(self, sample: Union[str, Tuple[str, + str]]) -> Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]]: if not isinstance(sample, ( str, tuple, @@ -124,13 +125,14 @@ def load_sample(self, sample: Union[str, Tuple[str, str]]) -> Union[torch.Tensor # load images directly to torch tensors img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + img_labels = img_labels[0] # HxW return {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: img_labels} # TODO: this routine should be moved to `per_batch_transform` once we have a way to # forward the labels to the loss function:. def post_tensor_transform( - self, sample: Union[torch.Tensor, Dict[str, torch.Tensor]] + self, sample: Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]] ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if isinstance(sample, torch.Tensor): # case for predict out = sample.float() / 255. # TODO: define predict transforms @@ -139,11 +141,16 @@ def post_tensor_transform( if not isinstance(sample, dict): raise TypeError(f"Invalid type, expected `dict`. 
Got: {sample}.") - # arrange data as floating point and batch before the augmentations - sample[SegmentationKeys.IMAGES] = sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW - sample[SegmentationKeys.MASKS] = sample[SegmentationKeys.MASKS][None, :1].float().contiguous() # 1x1xHxW + # pass to the transforms a dictionary with copies to handle potential memory leaks + sample_in: Dict[SegmentationKeys, torch.Tensor] = {} + sample_in[SegmentationKeys.IMAGES] = ( + sample[SegmentationKeys.IMAGES][None].float().contiguous().clone() # 1xCxHxW + ) + sample_in[SegmentationKeys.MASKS] = ( + sample[SegmentationKeys.MASKS][None, None].float().contiguous().clone() # 1x1xHxW + ) - out: Dict[str, torch.Tensor] = self.current_transform(sample) + out: Dict[SegmentationKeys, torch.Tensor] = self.current_transform(sample_in) return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 8703ff2f2b..f11fd9ac24 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -52,15 +52,14 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: train_labels=labels_filepaths, batch_size=4, val_split=0.3, - image_size=(300, 400), # (600, 800) + image_size=(200, 200), # (600, 800) num_workers=4, ) # 2.2 Visualise the samples labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) datamodule.set_labels_map(labels_map) -datamodule.show_train_batch("load_sample") -datamodule.show_train_batch("post_tensor_transform") +datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) # 3. Build the model model = SemanticSegmentation( From 2ef8c88587c06c5ddad9a0952049837d55c7bd13 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 21:23:34 +0200 Subject: [PATCH 28/53] undo function filtering --- flash/data/callback.py | 5 +---- flash/data/data_module.py | 5 ----- flash_examples/finetuning/semantic_segmentation.py | 2 +- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/flash/data/callback.py b/flash/data/callback.py index 9eedbad0c7..a479a6e59e 100644 --- a/flash/data/callback.py +++ b/flash/data/callback.py @@ -153,11 +153,8 @@ def __init__(self, enabled: bool = False): self._preprocess = None self.reset() - self._func_names = [] # store the functions to cache the data - def _store(self, data: Any, fn_name: str, running_stage: RunningStage) -> None: - # we check om `func_names` to prevent from memory overflow issues - if self.enabled and (fn_name in self._func_names): + if self.enabled: store = self.batches[_STAGES_PREFIX[running_stage]] store.setdefault(fn_name, []) store[fn_name].append(data) diff --git a/flash/data/data_module.py b/flash/data/data_module.py index 6967a712b0..aaee7f7bc7 100644 --- a/flash/data/data_module.py +++ b/flash/data/data_module.py @@ -168,11 +168,6 @@ def _show_batch(self, stage: RunningStage, func_names: Union[str, List[str]], re if isinstance(func_names, str): func_names = [func_names] - # the data fetcher tries to store all the data batches per stage, - # we store the function names that we want to visualize/cache to - # avoid duplicity in the visualization and reduce the memory footprint. 
- self.data_fetcher._func_names = func_names - iter_dataloader = getattr(self, iter_name) with self.data_fetcher.enable(): try: diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index f11fd9ac24..ddea8d626e 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -69,7 +69,7 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( - max_epochs=2, + max_epochs=1, gpus=0, ) From 87a92f383f4c2d98f099b6cc4ed13c34a37cafa3 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 21:37:40 +0200 Subject: [PATCH 29/53] fix import --- flash/vision/segmentation/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 499c26af13..601b7d4c08 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -22,13 +22,13 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset -import flash.vision.segmentation.transforms as T from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE +from flash.vision.segmentation import transforms as T from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels if _MATPLOTLIB_AVAILABLE: From 73d462bc71862ab3f458b7898e7e03904117b12d Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Thu, 29 Apr 2021 22:06:20 +0200 Subject: [PATCH 30/53] more fixes for memory leaks --- flash/data/batch.py | 6 ++++++ flash/vision/segmentation/data.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/flash/data/batch.py b/flash/data/batch.py index ea6ce1e9ca..a738ed27c3 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -136,6 +136,12 @@ def __init__( self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) def forward(self, samples: Sequence[Any]) -> Any: + # we create a new dict to prevent from potential memory leaks + # assuming that the dictionary samples are stored in between and + # potentially modified before the transforms are applied. 
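+        # note: `dict(samples.items())` is a shallow copy -- the tensors are still
+        # shared, but re-binding keys on the original dict can no longer affect us.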
+        if isinstance(samples, dict):
+            samples = dict(samples.items())
+
         with self._current_stage_context:
 
             if self.apply_per_sample_transform:
diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index 601b7d4c08..be6363fca8 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -144,10 +144,10 @@ def post_tensor_transform(
         # pass to the transforms a dictionary with copies to handle potential memory leaks
         sample_in: Dict[SegmentationKeys, torch.Tensor] = {}
         sample_in[SegmentationKeys.IMAGES] = (
-            sample[SegmentationKeys.IMAGES][None].float().contiguous().clone()  # 1xCxHxW
+            sample[SegmentationKeys.IMAGES][None].float().contiguous()  # 1xCxHxW
         )
         sample_in[SegmentationKeys.MASKS] = (
-            sample[SegmentationKeys.MASKS][None, None].float().contiguous().clone()  # 1x1xHxW
+            sample[SegmentationKeys.MASKS][None, None].float().contiguous()  # 1x1xHxW
         )
 
         out: Dict[SegmentationKeys, torch.Tensor] = self.current_transform(sample_in)

From 8b971a4e008565efcf014abfeedc4616bee492ff Mon Sep 17 00:00:00 2001
From: Edgar Riba
Date: Fri, 30 Apr 2021 12:18:41 +0200
Subject: [PATCH 31/53] add segmentation to docs

---
 docs/source/index.rst                         |   1 +
 .../reference/semantic_segmentation.rst       | 150 ++++++++++++++++++
 .../finetuning/semantic_segmentation.py       |   5 +-
 3 files changed, 153 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/reference/semantic_segmentation.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5cc7636482..89f6d5291a 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -28,6 +28,7 @@ Lightning Flash
    reference/tabular_classification
    reference/translation
    reference/object_detection
+   reference/semantic_segmentation
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst
new file mode 100644
index 0000000000..fe8b543b27
--- /dev/null
+++ b/docs/source/reference/semantic_segmentation.rst
@@ -0,0 +1,150 @@
+
+.. _semantic_segmentation:
+
+#####################
+Semantic Segmentation
+#####################
+
+********
+The task
+********
+Semantic segmentation, or image segmentation, is the task of clustering parts of an image together which belong to the same object class. It is a form of pixel-level prediction because each pixel in an image is classified according to a category.
+
+------
+
+*********
+Inference
+*********
+
+The :class:`~flash.vision.SemanticSegmentation` task is already pre-trained on a dataset generated with the `CARLA `_ simulator.
+
+
+Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inference on any image using :func:`~flash.vision.SemanticSegmentation.predict`:
+
+.. code-block:: python
+
+    # import our libraries
+    from flash import Trainer
+    from flash.data.utils import download_data
+    from flash.vision import SemanticSegmentation, SemanticSegmentationData
+    from flash.vision.segmentation.serialization import SegmentationLabels
+
+    # 1. Download the data
+    download_data(
+        "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
+        "data/"
+    )
+
+    # 2. Load the model from a checkpoint
+    model = SemanticSegmentation.load_from_checkpoint(
+        "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
+    )
+    model.serializer = SegmentationLabels(visualize=True)
+
+    # 3. Predict what's on a few images and visualize!
+ predictions = model.predict([ + 'data/CameraRGB/F61-1.png', + 'data/CameraRGB/F62-1.png', + 'data/CameraRGB/F63-1.png', + ]) + +For more advanced inference options, see :ref:`predictions`. + +------ + +********** +Finetuning +********** + +Say you now want to customise your model with new data using the same dataset. +Once we download the data using :func:`~flash.data.utils.download_data`, all we need is the train data and validation data folders to create the :class:`~flash.vision.SemanticSegmentationData`. + +.. note:: The dataset is structured so that each sample (an image and its corresponding labels) is stored in separate directories but keeps the same filename. + +.. code-block:: + + data + ├── CameraRGB + │ ├── F61-1.png + │ ├── F61-2.png + │ ... + └── CameraSeg + ├── F61-1.png + ├── F61-2.png + ... + + +Now all we need is a few lines of code to build and train our task! + +.. code-block:: python + + import os + import flash + from flash.data.utils import download_data + from flash.vision import SemanticSegmentation, SemanticSegmentationData + + # 1. Download the data + download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", + "data/" + ) + + # 2.1 Load the data + + + def load_data(data_root: str = 'data/'): + images = [] + labels = [] + rgb_path = os.path.join(data_root, "CameraRGB") + seg_path = os.path.join(data_root, "CameraSeg") + for fname in os.listdir(rgb_path): + images.append(os.path.join(rgb_path, fname)) + labels.append(os.path.join(seg_path, fname)) + return images, labels + + + images_filepaths, labels_filepaths = load_data() + + # create the data module + datamodule = SemanticSegmentationData.from_filepaths( + train_filepaths=images_filepaths, + train_labels=labels_filepaths, + ) + + # 3. Build the model + model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21) + + # 4. Create the trainer. + trainer = flash.Trainer(max_epochs=1) + + # 5. Train the model + trainer.finetune(model, datamodule=datamodule, strategy='freeze') + + # 6. Save it! + trainer.save_checkpoint("semantic_segmentation_model.pt") + +------ + +************* +API reference +************* + +.. _segmentation: + +SemanticSegmentation +-------------------- + +.. autoclass:: flash.vision.SemanticSegmentation + :members: + :exclude-members: forward + +.. _segmentation_data: + +SemanticSegmentationData +------------------------ + +.. autoclass:: flash.vision.SemanticSegmentationData + +.. automethod:: flash.vision.SemanticSegmentationData.from_filepaths + +.. autoclass:: flash.vision.SemanticSegmentationPreprocess diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index ddea8d626e..1c8d94d5ee 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -14,12 +14,11 @@ import os from typing import List, Tuple -import pandas as pd import torch import flash from flash.data.utils import download_data -from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess +from flash.vision import SemanticSegmentation, SemanticSegmentationData from flash.vision.segmentation.serialization import SegmentationLabels # 1. Download the data # This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. # The data was generated as part of the Lyft Udacity Challenge.
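# Each image in CameraRGB has a label mask with the same file name in CameraSeg.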
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge download_data( - 'https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip', "data/" + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" ) # 2.1 Load the data From 69358a63dea043eaccd67489fa187311862b0097 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Fri, 30 Apr 2021 12:22:56 +0200 Subject: [PATCH 32/53] add inference example --- .../reference/semantic_segmentation.rst | 2 +- .../predict/semantic_segmentation.py | 38 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 flash_examples/predict/semantic_segmentation.py diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index fe8b543b27..0425b50619 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -26,7 +26,7 @@ Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inferen # import our libraries from flash import Trainer from flash.data.utils import download_data - from flash.vision import SemanticSegmentation, SemanticSegmentationData + from flash.vision import SemanticSegmentation from flash.vision.segmentation.serialization import SegmentationLabels # 1. Download the data diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py new file mode 100644 index 0000000000..d1d20830c5 --- /dev/null +++ b/flash_examples/predict/semantic_segmentation.py @@ -0,0 +1,38 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.data.utils import download_data +from flash.vision import SemanticSegmentation +from flash.vision.segmentation.serialization import SegmentationLabels + +# 1. Download the data +# This is a Dataset with Semantic Segmentation Labels generated via CARLA self-driving simulator. +# The data was generated as part of the Lyft Udacity Challenge. +# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge +download_data( + "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip", "data/" +) + +# 2. Load the model from a checkpoint +model = SemanticSegmentation.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" +) +model.serializer = SegmentationLabels(visualize=True) + +# 3. Predict what's on a few images and visualize! 
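+# (with visualize=True the serializer also renders the predicted masks using matplotlib)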
+predictions = model.predict([ + 'data/CameraRGB/F61-1.png', + 'data/CameraRGB/F62-1.png', + 'data/CameraRGB/F63-1.png', +]) From caabfb6d67e90815cf8731d1a56c5a6daf6ea049 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Fri, 30 Apr 2021 12:52:48 +0200 Subject: [PATCH 33/53] add image to docs and update with AdamW --- docs/source/reference/semantic_segmentation.rst | 10 ++++++++++ flash/vision/segmentation/model.py | 4 ++-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 0425b50619..d6b3518c03 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -10,6 +10,16 @@ The task ******** Semantic segmentation, or image segmentation, is the task of clustering parts of an image together which belong to the same object class. It is a form of pixel-level prediction because each pixel in an image is classified according to a category. +See more: https://paperswithcode.com/task/semantic-segmentation + +.. raw:: html + +

+    <!-- example image: semantic segmentation predictions (embedded via raw HTML) -->

+ ------ ********* diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index c3ba135479..ad0c2aed5a 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -46,7 +46,7 @@ class SemanticSegmentation(ClassificationTask): backbone_kwargs: Additional arguments for the backbone configuration. pretrained: Use a pretrained backbone, defaults to ``False``. loss_fn: Loss function for training, defaults to :func:`torch.nn.functional.cross_entropy`. - optimizer: Optimizer to use for training, defaults to :class:`torch.optim.Adam`. + optimizer: Optimizer to use for training, defaults to :class:`torch.optim.AdamW`. metrics: Metrics to compute for training and evaluation, defaults to :class:`torchmetrics.IoU`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. multi_label: Whether the targets are multi-label or not. @@ -62,7 +62,7 @@ def __init__( backbone_kwargs: Optional[Dict] = None, pretrained: bool = True, loss_fn: Optional[Callable] = None, - optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, + optimizer: Type[torch.optim.Optimizer] = torch.optim.AdamW, metrics: Optional[Union[Callable, Mapping, Sequence, None]] = None, learning_rate: float = 1e-3, multi_label: bool = False, From e8e92d190a77cf770cd6149595304446ca22ff5a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 19:07:39 +0100 Subject: [PATCH 34/53] Make pretrained arg kwarg --- flash/vision/segmentation/backbones.py | 4 ++-- flash/vision/segmentation/model.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/vision/segmentation/backbones.py b/flash/vision/segmentation/backbones.py index b8f4519bae..2a1661be6c 100644 --- a/flash/vision/segmentation/backbones.py +++ b/flash/vision/segmentation/backbones.py @@ -23,14 +23,14 @@ @SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50") -def load_torchvision_fcn_resnet50(pretrained: bool, num_classes: int) -> nn.Module: +def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module: model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained) model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) return model @SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101") -def load_torchvision_fcn_resnet101(pretrained: bool, num_classes: int) -> nn.Module: +def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module: model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained) model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1)) return model diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index ad0c2aed5a..ab8b121844 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -94,7 +94,7 @@ def __init__( backbone_kwargs = {} # TODO: pretrained to True causes some issues - self.backbone = self.backbones.get(backbone)(pretrained=pretrained, num_classes=num_classes, **backbone_kwargs) + self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs) def forward(self, x) -> torch.Tensor: # infer the image to the model From cf430f3544b1c7cd389550c1fb0cdd0721d9f232 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 May 2021 18:08:00 +0000 Subject: [PATCH 35/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more 
information, see https://pre-commit.ci --- .readthedocs.yml | 2 +- Makefile | 2 +- docs/source/_static/images/logo.svg | 2 +- requirements/extras.txt | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index 6eeec2eb41..bb9ffdbdd8 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,4 +20,4 @@ python: version: 3.7 install: - requirements: requirements/docs.txt - #- requirements: requirements.txt \ No newline at end of file + #- requirements: requirements.txt diff --git a/Makefile b/Makefile index 9bcddf4426..6fcee001e6 100644 --- a/Makefile +++ b/Makefile @@ -25,4 +25,4 @@ clean: rm -rf .pytest_cache rm -rf ./docs/build rm -rf ./docs/source/**/generated - rm -rf ./docs/source/api \ No newline at end of file + rm -rf ./docs/source/api diff --git a/docs/source/_static/images/logo.svg b/docs/source/_static/images/logo.svg index ad94333a72..2c3e330bbf 100644 --- a/docs/source/_static/images/logo.svg +++ b/docs/source/_static/images/logo.svg @@ -1 +1 @@ - \ No newline at end of file + diff --git a/requirements/extras.txt b/requirements/extras.txt index 78882bf8f3..ccb0a4be9b 100644 --- a/requirements/extras.txt +++ b/requirements/extras.txt @@ -1 +1 @@ -timm>=0.4.5 \ No newline at end of file +timm>=0.4.5 From 74ce6dc6a5c655edad3a647119be20e69c6ca4e1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 20:37:37 +0100 Subject: [PATCH 36/53] Data sources initial commit --- .gitignore | 2 + flash/vision/segmentation/data.py | 318 +++++++++++---------- flash/vision/segmentation/serialization.py | 2 +- 3 files changed, 166 insertions(+), 156 deletions(-) diff --git a/.gitignore b/.gitignore index 73b96a16dd..26ab5033dc 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,5 @@ wmt_en_ro action_youtube_naudio kinetics movie_posters +CameraRGB +CameraSeg diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index be6363fca8..8580416ab1 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -11,24 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union +import os +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np import torch -import torch.nn as nn import torchvision from PIL import Image from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -from torch.utils.data import Dataset +from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS -from flash.data.auto_dataset import AutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule +from flash.data.data_source import DefaultDataKeys, DefaultDataSources, PathsDataSource from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE -from flash.vision.segmentation import transforms as T from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels if _MATPLOTLIB_AVAILABLE: @@ -37,6 +37,65 @@ plt = None +class SemanticSegmentationPathsDataSource(PathsDataSource): + + def __init__(self): + super().__init__(IMG_EXTENSIONS) + + def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]: + input_data, target_data = data + + if self.isdir(input_data) and self.isdir(target_data): + files = os.listdir(input_data) + input_files = [os.path.join(input_data, file) for file in files] + target_files = [os.path.join(target_data, file) for file in files] + + target_files = list(filter(os.path.isfile, target_files)) + + if len(input_files) != len(target_files): + rank_zero_warn( + f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. 
" + f"The following files have been dropped: " + f"{list(set(input_files).difference(set(target_files)))}", + UserWarning, + ) + + input_data = input_files + target_data = target_files + + if not isinstance(input_data, list) and not isinstance(target_data, list): + input_data = [input_data] + target_data = [target_data] + + data = filter( + lambda input, target: ( + has_file_allowed_extension(input, self.extensions) and + has_file_allowed_extension(target, self.extensions) + ), + zip(input_data, target_data), + ) + + return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] + + def predict_load_data(self, data: Union[str, List[str]]): + return super().predict_load_data(data) + + def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: + # unpack data paths + img_path = sample[DefaultDataKeys.INPUT] + img_labels_path = sample[DefaultDataKeys.TARGET] + + # load images directly to torch tensors + img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW + img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW + img_labels = img_labels[0] # HxW + + return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: img_labels} + + def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: + return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT])} + + class SemanticSegmentationPreprocess(Preprocess): def __init__( @@ -58,77 +117,24 @@ def __init__( """ self.image_size = image_size - train_transform, val_transform, test_transform, predict_transform = self._resolve_transforms( - train_transform, val_transform, test_transform, predict_transform, image_size + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource()} ) - super().__init__(train_transform, val_transform, test_transform, predict_transform) - # TODO: this is kind of boilerplate, let's simplify def get_state_dict(self) -> Dict[str, Any]: return { - "train_transform": self._train_transform, - "val_transform": self._val_transform, - "test_transform": self._test_transform, - "predict_transform": self._predict_transform, - "image_size": self.image_size + **self.transforms, + "image_size": self.image_size, } - # TODO: this is kind of boilerplate, let's simplify @classmethod - def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool): + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) - # TODO: this is kind of boilerplate, let's simplify - def _resolve_transforms( - self, - train_transform: Optional[Union[str, Dict]] = None, - val_transform: Optional[Union[str, Dict]] = None, - test_transform: Optional[Union[str, Dict]] = None, - predict_transform: Optional[Union[str, Dict]] = None, - image_size: Tuple[int, int] = (196, 196), - ) -> Tuple[Dict[str, Callable], ...]: - - if not train_transform or train_transform == 'default': - train_transform = T.default_train_transforms(image_size) - - if not val_transform or val_transform == 'default': - val_transform = T.default_val_transforms(image_size) - - if not test_transform or test_transform == 'default': - test_transform = T.default_val_transforms(image_size) - - if not predict_transform or predict_transform == 'default': - predict_transform = T.default_val_transforms(image_size) - - return ( - train_transform, - val_transform, - 
test_transform, - predict_transform, - ) - - def load_sample(self, sample: Union[str, Tuple[str, - str]]) -> Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]]: - if not isinstance(sample, ( - str, - tuple, - )): - raise TypeError(f"Invalid type, expected `str` or `tuple`. Got: {sample}.") - - if isinstance(sample, str): # case for predict - return torchvision.io.read_image(sample) - - # unpack data paths - img_path: str = sample[0] - img_labels_path: str = sample[1] - - # load images directly to torch tensors - img: torch.Tensor = torchvision.io.read_image(img_path) # CxHxW - img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW - img_labels = img_labels[0] # HxW - - return {SegmentationKeys.IMAGES: img, SegmentationKeys.MASKS: img_labels} - # TODO: this routine should be moved to `per_batch_transform` once we have a way to # forward the labels to the loss function:. def post_tensor_transform( @@ -161,12 +167,14 @@ def post_tensor_transform( class SemanticSegmentationData(DataModule): """Data module for semantic segmentation tasks.""" - @staticmethod - def _check_valid_filepaths(filepaths: List[str]): - if filepaths is not None and ( - not isinstance(filepaths, list) or not all(isinstance(n, str) for n in filepaths) - ): - raise MisconfigurationException(f"`filepaths` must be of type List[str]. Got: {filepaths}.") + preprocess_cls = SemanticSegmentationPreprocess + + # @staticmethod + # def _check_valid_filepaths(filepaths: List[str]): + # if filepaths is not None and ( + # not isinstance(filepaths, list) or not all(isinstance(n, str) for n in filepaths) + # ): + # raise MisconfigurationException(f"`filepaths` must be of type List[str]. Got: {filepaths}.") @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: @@ -179,88 +187,88 @@ def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value - @classmethod - def from_filepaths( - cls, - train_filepaths: List[str], - train_labels: List[str], - val_filepaths: Optional[List[str]] = None, - val_labels: Optional[List[str]] = None, - test_filepaths: Optional[List[str]] = None, - test_labels: Optional[List[str]] = None, - predict_filepaths: Optional[List[str]] = None, - train_transform: Union[str, Dict] = 'default', - val_transform: Union[str, Dict] = 'default', - test_transform: Union[str, Dict] = 'default', - predict_transform: Union[str, Dict] = 'default', - image_size: Tuple[int, int] = (196, 196), - batch_size: int = 64, - num_workers: Optional[int] = None, - data_fetcher: BaseDataFetcher = None, - preprocess: Optional[Preprocess] = None, - val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED - **kwargs, # TODO: remove and make explicit params - ) -> 'SemanticSegmentationData': - """Creates a Semantic SegmentationData object from a given list of paths to images and labels. - - Args: - train_filepaths: List of file paths for training images. - train_labels: List of file path for the training image labels. - val_filepaths: List of file paths for validation images. - val_labels: List of file path for the validation image labels. - test_filepaths: List of file paths for testing images. - test_labels: List of file path for the testing image labels. - predict_filepaths: List of file paths for predicting images. - train_transform: Image and mask transform to use for the train set. - val_transform: Image and mask transform to use for the validation set. 
- test_transform: Image and mask transform to use for the test set. - predict_transform: Image transform to use for the predict set. - image_size: A tuple with the expected output image size. - batch_size: The batch size to use for parallel loading. - num_workers: The number of workers to use for parallelized loading. - Defaults to ``None`` which equals the number of available CPU threads. - data_fetcher: An optional data fetcher object instance. - preprocess: An optional `SemanticSegmentationPreprocess` object instance. - val_split: Float number to control the percentage of train/validation samples - from the ``train_filepaths`` and ``train_labels`` list. - - - Returns: - SemanticSegmentationData: The constructed data module. - - """ - - # verify input data format - SemanticSegmentationData._check_valid_filepaths(train_filepaths) - SemanticSegmentationData._check_valid_filepaths(train_labels) - SemanticSegmentationData._check_valid_filepaths(val_filepaths) - SemanticSegmentationData._check_valid_filepaths(val_labels) - SemanticSegmentationData._check_valid_filepaths(test_filepaths) - SemanticSegmentationData._check_valid_filepaths(test_labels) - SemanticSegmentationData._check_valid_filepaths(predict_filepaths) - - # create the preprocess objects - preprocess = preprocess or SemanticSegmentationPreprocess( - train_transform, - val_transform, - test_transform, - predict_transform, - image_size=image_size, - ) - - # this functions overrides `DataModule.from_load_data_inputs` - return cls.from_load_data_inputs( - train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, - val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, - test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - predict_load_data_input=predict_filepaths, # TODO: is it really used ? - batch_size=batch_size, - num_workers=num_workers, - data_fetcher=data_fetcher, - preprocess=preprocess, - val_split=val_split, - **kwargs, # TODO: remove and make explicit params - ) + # @classmethod + # def from_filepaths( + # cls, + # train_filepaths: List[str], + # train_labels: List[str], + # val_filepaths: Optional[List[str]] = None, + # val_labels: Optional[List[str]] = None, + # test_filepaths: Optional[List[str]] = None, + # test_labels: Optional[List[str]] = None, + # predict_filepaths: Optional[List[str]] = None, + # train_transform: Union[str, Dict] = 'default', + # val_transform: Union[str, Dict] = 'default', + # test_transform: Union[str, Dict] = 'default', + # predict_transform: Union[str, Dict] = 'default', + # image_size: Tuple[int, int] = (196, 196), + # batch_size: int = 64, + # num_workers: Optional[int] = None, + # data_fetcher: BaseDataFetcher = None, + # preprocess: Optional[Preprocess] = None, + # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED + # **kwargs, # TODO: remove and make explicit params + # ) -> 'SemanticSegmentationData': + # """Creates a Semantic SegmentationData object from a given list of paths to images and labels. + # + # Args: + # train_filepaths: List of file paths for training images. + # train_labels: List of file path for the training image labels. + # val_filepaths: List of file paths for validation images. + # val_labels: List of file path for the validation image labels. + # test_filepaths: List of file paths for testing images. + # test_labels: List of file path for the testing image labels. + # predict_filepaths: List of file paths for predicting images. 
+ # train_transform: Image and mask transform to use for the train set. + # val_transform: Image and mask transform to use for the validation set. + # test_transform: Image and mask transform to use for the test set. + # predict_transform: Image transform to use for the predict set. + # image_size: A tuple with the expected output image size. + # batch_size: The batch size to use for parallel loading. + # num_workers: The number of workers to use for parallelized loading. + # Defaults to ``None`` which equals the number of available CPU threads. + # data_fetcher: An optional data fetcher object instance. + # preprocess: An optional `SemanticSegmentationPreprocess` object instance. + # val_split: Float number to control the percentage of train/validation samples + # from the ``train_filepaths`` and ``train_labels`` list. + # + # + # Returns: + # SemanticSegmentationData: The constructed data module. + # + # """ + # + # # verify input data format + # SemanticSegmentationData._check_valid_filepaths(train_filepaths) + # SemanticSegmentationData._check_valid_filepaths(train_labels) + # SemanticSegmentationData._check_valid_filepaths(val_filepaths) + # SemanticSegmentationData._check_valid_filepaths(val_labels) + # SemanticSegmentationData._check_valid_filepaths(test_filepaths) + # SemanticSegmentationData._check_valid_filepaths(test_labels) + # SemanticSegmentationData._check_valid_filepaths(predict_filepaths) + # + # # create the preprocess objects + # preprocess = preprocess or SemanticSegmentationPreprocess( + # train_transform, + # val_transform, + # test_transform, + # predict_transform, + # image_size=image_size, + # ) + # + # # this functions overrides `DataModule.from_load_data_inputs` + # return cls.from_load_data_inputs( + # train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, + # val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, + # test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, + # predict_load_data_input=predict_filepaths, # TODO: is it really used ? + # batch_size=batch_size, + # num_workers=num_workers, + # data_fetcher=data_fetcher, + # preprocess=preprocess, + # val_split=val_split, + # **kwargs, # TODO: remove and make explicit params + # ) class SegmentationMatplotlibVisualization(BaseVisualization): diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 907ed937c0..f5d67dd69e 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -16,7 +16,7 @@ import torch -from flash.data.process import ProcessState, Serializer +from flash.data.process import Serializer from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE if _MATPLOTLIB_AVAILABLE: From df2b9897f28feae8d68539deb26fa89a01923d26 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 21:00:36 +0100 Subject: [PATCH 37/53] Update transforms --- flash/vision/segmentation/serialization.py | 6 -- flash/vision/segmentation/transforms.py | 72 ++++++++++------------ 2 files changed, 31 insertions(+), 47 deletions(-) diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index f5d67dd69e..f12aea68e7 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from enum import Enum from typing import Dict, Optional, Tuple import torch @@ -30,11 +29,6 @@ K = None -class SegmentationKeys(Enum): - IMAGES = 'images' - MASKS = 'masks' - - class SegmentationLabels(Serializer): def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index 66491afe2b..b90aadab1c 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -11,64 +11,54 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, Tuple import kornia as K -import torch import torch.nn as nn -from flash.vision.segmentation.serialization import SegmentationKeys +from flash.data.data_source import DefaultDataKeys +from flash.data.transforms import ApplyToKeys -class ApplyTransformToKeys(nn.Sequential): +class KorniaParallelTransforms(nn.Sequential): - def __init__(self, keys: List[str], *args): - super().__init__(*args) - self.keys = keys - - @torch.no_grad() - def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - for seq in self.children(): - for aug in seq: - for key in self.keys: - # kornia caches the random parameters in `_params` after every - # forward call in the augmentation object. We check whether the - # random parameters have been generated so that we can apply the - # same parameters to images and masks. - if hasattr(aug, "_params") and bool(aug._params): - params = aug._params - x[key] = aug(x[key], params) - else: # case for non random transforms - x[key] = aug(x[key]) - return x + def forward(self, *inputs: Any): + result = list(inputs) + for transform in self.children(): + inputs = result + for i, input in enumerate(inputs): + if hasattr(transform, "_params") and bool(transform._params): + params = transform._params + result[i] = transform(input, params) + else: # case for non random transforms + result[i] = transform(input) + if hasattr(transform, "_params") and bool(transform._params): + transform._params = None + return result def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys([SegmentationKeys.IMAGES, SegmentationKeys.MASKS], - nn.Sequential( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.75), - )), # noqa: E126 - ApplyTransformToKeys( - [SegmentationKeys.IMAGES], - nn.Sequential( - K.enhance.Normalize(0., 255.), - K.augmentation.ColorJitter(0.4, p=0.5), - # NOTE: uncomment to visualise better - # K.enhance.Denormalize(0., 255.), - ) + "post_tensor_transform": ApplyToKeys( + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + KorniaParallelTransforms( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), ), ), + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + K.enhance.Normalize(0., 255.), + K.augmentation.ColorJitter(0.4, p=0.5), + ), } def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": nn.Sequential( - ApplyTransformToKeys([SegmentationKeys.IMAGES, 
SegmentationKeys.MASKS], - nn.Sequential(K.geometry.Resize(image_size, interpolation='nearest'), )), - ApplyTransformToKeys([SegmentationKeys.IMAGES], nn.Sequential(K.enhance.Normalize(0., 255.), )), + "post_tensor_transform": ApplyToKeys( + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + K.geometry.Resize(image_size, interpolation='nearest'), ), + "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), } From bb95b8f1eb5a1ccd9216979a51571a2a97e15ad6 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 7 May 2021 22:45:31 +0100 Subject: [PATCH 38/53] Updates --- flash/data/transforms.py | 6 +- flash/vision/segmentation/data.py | 120 +++++++++++++----- flash/vision/segmentation/model.py | 24 +++- flash/vision/segmentation/transforms.py | 29 ++++- .../finetuning/semantic_segmentation.py | 7 +- 5 files changed, 137 insertions(+), 49 deletions(-) diff --git a/flash/data/transforms.py b/flash/data/transforms.py index 0a26224791..cfa927c52c 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -29,8 +29,10 @@ def __init__(self, keys: Union[str, Sequence[str]], *args): def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: inputs = [x[key] for key in filter(lambda key: key in x, self.keys)] if len(inputs) > 0: - outputs = super().forward(*inputs) - if not isinstance(outputs, tuple): + if len(inputs) == 1: + inputs = inputs[0] + outputs = super().forward(inputs) + if not isinstance(outputs, Sequence): outputs = (outputs, ) result = {} diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 8580416ab1..007427ca15 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -29,7 +29,8 @@ from flash.data.data_source import DefaultDataKeys, DefaultDataSources, PathsDataSource from flash.data.process import Preprocess from flash.utils.imports import _MATPLOTLIB_AVAILABLE -from flash.vision.segmentation.serialization import SegmentationKeys, SegmentationLabels +from flash.vision.segmentation.serialization import SegmentationLabels +from flash.vision.segmentation.transforms import default_train_transforms, default_val_transforms if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt @@ -68,9 +69,9 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - target_data = [target_data] data = filter( - lambda input, target: ( - has_file_allowed_extension(input, self.extensions) and - has_file_allowed_extension(target, self.extensions) + lambda sample: ( + has_file_allowed_extension(sample[0], self.extensions) and + has_file_allowed_extension(sample[1], self.extensions) ), zip(input_data, target_data), ) @@ -90,10 +91,10 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW - return {DefaultDataKeys.INPUT: img, DefaultDataKeys.TARGET: img_labels} + return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()} def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: - return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT])} + return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()} class SemanticSegmentationPreprocess(Preprocess): @@ -135,30 +136,46 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return 
cls(**state_dict) - # TODO: this routine should be moved to `per_batch_transform` once we have a way to - # forward the labels to the loss function:. - def post_tensor_transform( - self, sample: Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]] - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if isinstance(sample, torch.Tensor): # case for predict - out = sample.float() / 255. # TODO: define predict transforms - return out - - if not isinstance(sample, dict): - raise TypeError(f"Invalid type, expected `dict`. Got: {sample}.") - - # pass to the transforms a dictionary with copies to handle potential memory leaks - sample_in: Dict[SegmentationKeys, torch.Tensor] = {} - sample_in[SegmentationKeys.IMAGES] = ( - sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW - ) - sample_in[SegmentationKeys.MASKS] = ( - sample[SegmentationKeys.MASKS][None, None].float().contiguous() # 1x1xHxW - ) - - out: Dict[SegmentationKeys, torch.Tensor] = self.current_transform(sample_in) - - return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() + @property + def default_train_transforms(self) -> Optional[Dict[str, Callable]]: + return default_train_transforms(self.image_size) + + @property + def default_val_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + @property + def default_test_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + @property + def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: + return default_val_transforms(self.image_size) + + # # TODO: this routine should be moved to `per_batch_transform` once we have a way to + # # forward the labels to the loss function:. + # def post_tensor_transform( + # self, sample: Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]] + # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # if isinstance(sample, torch.Tensor): # case for predict + # out = sample.float() / 255. # TODO: define predict transforms + # return out + # + # if not isinstance(sample, dict): + # raise TypeError(f"Invalid type, expected `dict`. 
Got: {sample}.") + # + # # pass to the transforms a dictionary with copies to handle potential memory leaks + # sample_in: Dict[SegmentationKeys, torch.Tensor] = {} + # sample_in[SegmentationKeys.IMAGES] = ( + # sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW + # ) + # sample_in[SegmentationKeys.MASKS] = ( + # sample[SegmentationKeys.MASKS][None, None].float().contiguous() # 1x1xHxW + # ) + # + # out: Dict[SegmentationKeys, torch.Tensor] = self.current_transform(sample_in) + # + # return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() # TODO: implement `per_batch_transform` and `per_batch_transform_on_device` ## @@ -187,6 +204,45 @@ def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value + @classmethod + def from_folders( + cls, + train_folder: Optional[str] = None, + train_target_folder: Optional[str] = None, + val_folder: Optional[str] = None, + val_target_folder: Optional[str] = None, + test_folder: Optional[str] = None, + test_target_folder: Optional[str] = None, + predict_folder: Optional[str] = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: BaseDataFetcher = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + return cls.from_data_source( + DefaultDataSources.PATHS, + (train_folder, train_target_folder), + (val_folder, val_target_folder), + (test_folder, test_target_folder), + predict_folder, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs, + ) + # @classmethod # def from_filepaths( # cls, @@ -308,8 +364,8 @@ def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str) # unpack images and labels sample = data[i] if isinstance(sample, dict): - image = sample[SegmentationKeys.IMAGES] - label = sample[SegmentationKeys.MASKS] + image = sample[DefaultDataKeys.INPUT] + label = sample[DefaultDataKeys.TARGET] elif isinstance(sample, tuple): image = sample[0] label = sample[1] diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index ab8b121844..e543b341ed 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -11,17 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from types import FunctionType -from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union import torch from torch import nn from torch.nn import functional as F -from torchmetrics import Accuracy, IoU +from torchmetrics import IoU from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry -from flash.data.process import Preprocess, Serializer +from flash.data.data_source import DefaultDataKeys +from flash.data.process import Serializer from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.vision.segmentation.serialization import SegmentationLabels @@ -96,6 +96,22 @@ def __init__( # TODO: pretrained to True causes some issues self.backbone = self.backbones.get(backbone)(num_classes, pretrained=pretrained, **backbone_kwargs) + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.TARGET]) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch = (batch[DefaultDataKeys.INPUT]) + return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + def forward(self, x) -> torch.Tensor: # infer the image to the model res: Union[torch.Tensor, Dict[str, torch.Tensor]] = self.backbone(x) diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index b90aadab1c..d6ced9fbef 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -14,15 +14,20 @@ from typing import Any, Callable, Dict, Tuple import kornia as K +import torch import torch.nn as nn from flash.data.data_source import DefaultDataKeys from flash.data.transforms import ApplyToKeys +from flash.data.utils import convert_to_modules class KorniaParallelTransforms(nn.Sequential): - def forward(self, *inputs: Any): + def __init__(self, *args): + super().__init__(*[convert_to_modules(arg) for arg in args]) + + def forward(self, inputs: Any): result = list(inputs) for transform in self.children(): inputs = result @@ -37,6 +42,10 @@ def forward(self, *inputs: Any): return result +def to_long(tensor: torch.Tensor) -> torch.Tensor: + return tensor.long() + + def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": ApplyToKeys( @@ -46,10 +55,13 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] K.augmentation.RandomHorizontalFlip(p=0.75), ), ), - "per_batch_transform_on_device": ApplyToKeys( - DefaultDataKeys.INPUT, - K.enhance.Normalize(0., 255.), - K.augmentation.ColorJitter(0.4, p=0.5), + "per_batch_transform_on_device": nn.Sequential( + ApplyToKeys( + DefaultDataKeys.INPUT, + K.enhance.Normalize(0., 255.), + K.augmentation.ColorJitter(0.4, p=0.5), + ), + ApplyToKeys(DefaultDataKeys.TARGET, to_long), ), } @@ -58,7 +70,10 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { "post_tensor_transform": ApplyToKeys( 
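# the same resize parameters are applied to the image and its mask, so they stay aligned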
[DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], - K.geometry.Resize(image_size, interpolation='nearest'), + KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest'), ), + ), + "per_batch_transform_on_device": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), + ApplyToKeys(DefaultDataKeys.TARGET, to_long), ), - "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 1c8d94d5ee..313b1ad7c9 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -46,13 +46,12 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: images_filepaths, labels_filepaths = load_data() # create the data module -datamodule = SemanticSegmentationData.from_filepaths( - train_filepaths=images_filepaths, - train_labels=labels_filepaths, +datamodule = SemanticSegmentationData.from_files( + train_files=images_filepaths, + train_targets=labels_filepaths, batch_size=4, val_split=0.3, image_size=(200, 200), # (600, 800) - num_workers=4, ) # 2.2 Visualise the samples From 3a8842d7d420bee21f1ce8eee1bf3ffe628bb8d4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sat, 8 May 2021 13:09:45 +0100 Subject: [PATCH 39/53] Fixes --- flash/data/transforms.py | 5 +- flash/vision/classification/data.py | 2 +- flash/vision/segmentation/data.py | 129 ++---------------- flash/vision/segmentation/serialization.py | 3 +- flash/vision/segmentation/transforms.py | 12 +- .../finetuning/semantic_segmentation.py | 39 ++---- .../predict/semantic_segmentation.py | 6 +- tests/examples/test_scripts.py | 2 + 8 files changed, 36 insertions(+), 162 deletions(-) diff --git a/flash/data/transforms.py b/flash/data/transforms.py index cfa927c52c..67b457f62b 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -27,7 +27,8 @@ def __init__(self, keys: Union[str, Sequence[str]], *args): self.keys = keys def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: - inputs = [x[key] for key in filter(lambda key: key in x, self.keys)] + keys = list(filter(lambda key: key in x, self.keys)) + inputs = [x[key] for key in keys] if len(inputs) > 0: if len(inputs) == 1: inputs = inputs[0] @@ -37,7 +38,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: result = {} result.update(x) - for i, key in enumerate(self.keys): + for i, key in enumerate(keys): result[key] = outputs[i] return result return x diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 928605b244..79d0fca863 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -73,7 +73,7 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: for key in sample.keys(): if torch.is_tensor(sample[key]): sample[key] = sample[key].squeeze(0) - return default_collate(samples) + return super().collate(samples) @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 007427ca15..4333e7c268 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -123,7 +123,8 @@ def __init__( val_transform=val_transform, test_transform=test_transform, predict_transform=predict_transform, - data_sources={DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource()} + 
data_sources={DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource()}, + default_data_source=DefaultDataSources.PATHS, ) def get_state_dict(self) -> Dict[str, Any]: @@ -136,6 +137,14 @@ def get_state_dict(self) -> Dict[str, Any]: def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): return cls(**state_dict) + def collate(self, samples: Sequence[Dict[str, Any]]) -> Any: + # todo: Kornia transforms add batch dimension which need to be removed + for sample in samples: + for key in sample.keys(): + if torch.is_tensor(sample[key]): + sample[key] = sample[key].squeeze(0) + return super().collate(samples) + @property def default_train_transforms(self) -> Optional[Dict[str, Callable]]: return default_train_transforms(self.image_size) @@ -152,47 +161,12 @@ def default_test_transforms(self) -> Optional[Dict[str, Callable]]: def default_predict_transforms(self) -> Optional[Dict[str, Callable]]: return default_val_transforms(self.image_size) - # # TODO: this routine should be moved to `per_batch_transform` once we have a way to - # # forward the labels to the loss function:. - # def post_tensor_transform( - # self, sample: Union[torch.Tensor, Dict[SegmentationKeys, torch.Tensor]] - # ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - # if isinstance(sample, torch.Tensor): # case for predict - # out = sample.float() / 255. # TODO: define predict transforms - # return out - # - # if not isinstance(sample, dict): - # raise TypeError(f"Invalid type, expected `dict`. Got: {sample}.") - # - # # pass to the transforms a dictionary with copies to handle potential memory leaks - # sample_in: Dict[SegmentationKeys, torch.Tensor] = {} - # sample_in[SegmentationKeys.IMAGES] = ( - # sample[SegmentationKeys.IMAGES][None].float().contiguous() # 1xCxHxW - # ) - # sample_in[SegmentationKeys.MASKS] = ( - # sample[SegmentationKeys.MASKS][None, None].float().contiguous() # 1x1xHxW - # ) - # - # out: Dict[SegmentationKeys, torch.Tensor] = self.current_transform(sample_in) - # - # return out[SegmentationKeys.IMAGES][0], out[SegmentationKeys.MASKS][0, 0].long() - - # TODO: implement `per_batch_transform` and `per_batch_transform_on_device` - ## - class SemanticSegmentationData(DataModule): """Data module for semantic segmentation tasks.""" preprocess_cls = SemanticSegmentationPreprocess - # @staticmethod - # def _check_valid_filepaths(filepaths: List[str]): - # if filepaths is not None and ( - # not isinstance(filepaths, list) or not all(isinstance(n, str) for n in filepaths) - # ): - # raise MisconfigurationException(f"`filepaths` must be of type List[str]. 
Got: {filepaths}.") - @staticmethod def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return SegmentationMatplotlibVisualization(*args, **kwargs) @@ -243,89 +217,6 @@ def from_folders( **preprocess_kwargs, ) - # @classmethod - # def from_filepaths( - # cls, - # train_filepaths: List[str], - # train_labels: List[str], - # val_filepaths: Optional[List[str]] = None, - # val_labels: Optional[List[str]] = None, - # test_filepaths: Optional[List[str]] = None, - # test_labels: Optional[List[str]] = None, - # predict_filepaths: Optional[List[str]] = None, - # train_transform: Union[str, Dict] = 'default', - # val_transform: Union[str, Dict] = 'default', - # test_transform: Union[str, Dict] = 'default', - # predict_transform: Union[str, Dict] = 'default', - # image_size: Tuple[int, int] = (196, 196), - # batch_size: int = 64, - # num_workers: Optional[int] = None, - # data_fetcher: BaseDataFetcher = None, - # preprocess: Optional[Preprocess] = None, - # val_split: Optional[float] = None, # MAKES IT CRASH. NEED TO BE FIXED - # **kwargs, # TODO: remove and make explicit params - # ) -> 'SemanticSegmentationData': - # """Creates a Semantic SegmentationData object from a given list of paths to images and labels. - # - # Args: - # train_filepaths: List of file paths for training images. - # train_labels: List of file path for the training image labels. - # val_filepaths: List of file paths for validation images. - # val_labels: List of file path for the validation image labels. - # test_filepaths: List of file paths for testing images. - # test_labels: List of file path for the testing image labels. - # predict_filepaths: List of file paths for predicting images. - # train_transform: Image and mask transform to use for the train set. - # val_transform: Image and mask transform to use for the validation set. - # test_transform: Image and mask transform to use for the test set. - # predict_transform: Image transform to use for the predict set. - # image_size: A tuple with the expected output image size. - # batch_size: The batch size to use for parallel loading. - # num_workers: The number of workers to use for parallelized loading. - # Defaults to ``None`` which equals the number of available CPU threads. - # data_fetcher: An optional data fetcher object instance. - # preprocess: An optional `SemanticSegmentationPreprocess` object instance. - # val_split: Float number to control the percentage of train/validation samples - # from the ``train_filepaths`` and ``train_labels`` list. - # - # - # Returns: - # SemanticSegmentationData: The constructed data module. 
- # - # """ - # - # # verify input data format - # SemanticSegmentationData._check_valid_filepaths(train_filepaths) - # SemanticSegmentationData._check_valid_filepaths(train_labels) - # SemanticSegmentationData._check_valid_filepaths(val_filepaths) - # SemanticSegmentationData._check_valid_filepaths(val_labels) - # SemanticSegmentationData._check_valid_filepaths(test_filepaths) - # SemanticSegmentationData._check_valid_filepaths(test_labels) - # SemanticSegmentationData._check_valid_filepaths(predict_filepaths) - # - # # create the preprocess objects - # preprocess = preprocess or SemanticSegmentationPreprocess( - # train_transform, - # val_transform, - # test_transform, - # predict_transform, - # image_size=image_size, - # ) - # - # # this functions overrides `DataModule.from_load_data_inputs` - # return cls.from_load_data_inputs( - # train_load_data_input=list(zip(train_filepaths, train_labels)) if train_filepaths else None, - # val_load_data_input=list(zip(val_filepaths, val_labels)) if val_filepaths else None, - # test_load_data_input=list(zip(test_filepaths, test_labels)) if test_filepaths else None, - # predict_load_data_input=predict_filepaths, # TODO: is it really used ? - # batch_size=batch_size, - # num_workers=num_workers, - # data_fetcher=data_fetcher, - # preprocess=preprocess, - # val_split=val_split, - # **kwargs, # TODO: remove and make explicit params - # ) - class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index f12aea68e7..50ba5be9a9 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Dict, Optional, Tuple import torch @@ -67,7 +68,7 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] def serialize(self, sample: torch.Tensor) -> torch.Tensor: assert len(sample.shape) == 3, sample.shape labels = torch.argmax(sample, dim=-3) # HxW - if self.visualize: + if self.visualize and os.getenv("FLASH_TESTING", "0") == "0": if self.labels_map is None: # create random colors map num_classes = sample.shape[-3] diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index d6ced9fbef..2a91f53db4 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable, Dict, Tuple +from typing import Any, Callable, Dict, Sequence, Tuple import kornia as K import torch @@ -28,7 +28,7 @@ def __init__(self, *args): super().__init__(*[convert_to_modules(arg) for arg in args]) def forward(self, inputs: Any): - result = list(inputs) + result = list(inputs) if isinstance(inputs, Sequence) else [inputs] for transform in self.children(): inputs = result for i, input in enumerate(inputs): @@ -42,8 +42,8 @@ def forward(self, inputs: Any): return result -def to_long(tensor: torch.Tensor) -> torch.Tensor: - return tensor.long() +def prepare_target(tensor: torch.Tensor) -> torch.Tensor: + return tensor.long().squeeze() def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: @@ -61,7 +61,7 @@ def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable] K.enhance.Normalize(0., 255.), K.augmentation.ColorJitter(0.4, p=0.5), ), - ApplyToKeys(DefaultDataKeys.TARGET, to_long), + ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), ), } @@ -74,6 +74,6 @@ def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: ), "per_batch_transform_on_device": nn.Sequential( ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), - ApplyToKeys(DefaultDataKeys.TARGET, to_long), + ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), ), } diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 313b1ad7c9..3676353ec8 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -11,11 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from typing import List, Tuple - -import torch - import flash from flash.data.utils import download_data from flash.vision import SemanticSegmentation, SemanticSegmentationData @@ -30,25 +25,9 @@ ) # 2.1 Load the data - - -def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: - images: List[str] = [] - labels: List[str] = [] - rgb_path = os.path.join(data_root, "CameraRGB") - seg_path = os.path.join(data_root, "CameraSeg") - for fname in os.listdir(rgb_path): - images.append(os.path.join(rgb_path, fname)) - labels.append(os.path.join(seg_path, fname)) - return images, labels - - -images_filepaths, labels_filepaths = load_data() - -# create the data module -datamodule = SemanticSegmentationData.from_files( - train_files=images_filepaths, - train_targets=labels_filepaths, +datamodule = SemanticSegmentationData.from_folders( + train_folder="data/CameraRGB", + train_target_folder="data/CameraSeg", batch_size=4, val_split=0.3, image_size=(200, 200), # (600, 800) @@ -68,20 +47,20 @@ def load_data(data_root: str = 'data/') -> Tuple[List[str], List[str]]: # 4. Create the trainer. trainer = flash.Trainer( max_epochs=1, - gpus=0, + fast_dev_run=1, ) # 5. Train the model -trainer.finetune(model, datamodule=datamodule, strategy='freeze') +trainer.finetune(model, datamodule=datamodule, strategy="freeze") # 6. Predict what's on a few images! model.serializer = SegmentationLabels(labels_map, visualize=True) predictions = model.predict([ - 'data/CameraRGB/F61-1.png', - 'data/CameraRGB/F62-1.png', - 'data/CameraRGB/F63-1.png', -], datamodule.data_pipeline) + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", +]) # 7. Save it! 
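# (this checkpoint can be reloaded later with SemanticSegmentation.load_from_checkpoint)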
trainer.save_checkpoint("semantic_segmentation_model.pt") diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index d1d20830c5..b338ed95ec 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -32,7 +32,7 @@ # 3. Predict what's on a few images and visualize! predictions = model.predict([ - 'data/CameraRGB/F61-1.png', - 'data/CameraRGB/F62-1.png', - 'data/CameraRGB/F63-1.png', + "data/CameraRGB/F61-1.png", + "data/CameraRGB/F62-1.png", + "data/CameraRGB/F63-1.png", ]) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index 2fc4ee18f3..a60ecbf021 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -58,6 +58,7 @@ def run_test(filepath): ("finetuning", "image_classification.py"), ("finetuning", "image_classification_multi_label.py"), # ("finetuning", "object_detection.py"), # TODO: takes too long. + ("finetuning", "semantic_segmentation.py"), # ("finetuning", "summarization.py"), # TODO: takes too long. ("finetuning", "tabular_classification.py"), # ("finetuning", "video_classification.py"), @@ -65,6 +66,7 @@ def run_test(filepath): ("finetuning", "translation.py"), ("predict", "image_classification.py"), ("predict", "image_classification_multi_label.py"), + ("predict", "semantic_segmentation.py"), ("predict", "tabular_classification.py"), # ("predict", "text_classification.py"), ("predict", "image_embedder.py"), From 3596e16a69a949c4a834f05ae73de703fb05cb7a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sat, 8 May 2021 13:23:35 +0100 Subject: [PATCH 40/53] Fix tests --- flash/vision/segmentation/transforms.py | 34 +++++++-------- tests/vision/segmentation/test_data.py | 55 +++++++++++-------------- tests/vision/segmentation/test_model.py | 7 +++- 3 files changed, 47 insertions(+), 49 deletions(-) diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index 2a91f53db4..11bb87cde3 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -48,32 +48,32 @@ def prepare_target(tensor: torch.Tensor) -> torch.Tensor: def default_train_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": ApplyToKeys( - [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], - KorniaParallelTransforms( - K.geometry.Resize(image_size, interpolation='nearest'), - K.augmentation.RandomHorizontalFlip(p=0.75), - ), - ), - "per_batch_transform_on_device": nn.Sequential( + "post_tensor_transform": nn.Sequential( ApplyToKeys( - DefaultDataKeys.INPUT, - K.enhance.Normalize(0., 255.), - K.augmentation.ColorJitter(0.4, p=0.5), + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + KorniaParallelTransforms( + K.geometry.Resize(image_size, interpolation='nearest'), + K.augmentation.RandomHorizontalFlip(p=0.75), + ), ), ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), ), + "per_batch_transform_on_device": ApplyToKeys( + DefaultDataKeys.INPUT, + K.enhance.Normalize(0., 255.), + K.augmentation.ColorJitter(0.4, p=0.5), + ), } def default_val_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]: return { - "post_tensor_transform": ApplyToKeys( - [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], - KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest'), ), - ), - "per_batch_transform_on_device": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), + 
"post_tensor_transform": nn.Sequential( + ApplyToKeys( + [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET], + KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')), + ), ApplyToKeys(DefaultDataKeys.TARGET, prepare_target), ), + "per_batch_transform_on_device": ApplyToKeys(DefaultDataKeys.INPUT, K.enhance.Normalize(0., 255.)), } diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 432b972ac4..3fcf3e9504 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import numpy as np import pytest @@ -7,6 +7,7 @@ from PIL import Image from flash import Trainer +from flash.data.data_source import DefaultDataKeys from flash.vision import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess @@ -51,18 +52,18 @@ def test_smoke(self): dm = SemanticSegmentationData() assert dm is not None - def test_from_filepaths(self, tmpdir): + def test_from_files(self, tmpdir): tmp_dir = Path(tmpdir) # create random dummy data - train_images = [ + images = [ str(tmp_dir / "img1.png"), str(tmp_dir / "img2.png"), str(tmp_dir / "img3.png"), ] - train_labels = [ + targets = [ str(tmp_dir / "labels_img1.png"), str(tmp_dir / "labels_img2.png"), str(tmp_dir / "labels_img3.png"), @@ -70,17 +71,17 @@ def test_from_filepaths(self, tmpdir): num_classes: int = 2 img_size: Tuple[int, int] = (196, 196) - create_random_data(train_images, train_labels, img_size, num_classes) + create_random_data(images, targets, img_size, num_classes) # instantiate the data module - dm = SemanticSegmentationData.from_filepaths( - train_filepaths=train_images, - train_labels=train_labels, - val_filepaths=train_images, - val_labels=train_labels, - test_filepaths=train_images, - test_labels=train_labels, + dm = SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets, + val_files=images, + val_targets=targets, + test_files=images, + test_targets=targets, batch_size=2, num_workers=0, ) @@ -91,19 +92,13 @@ def test_from_filepaths(self, tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) - # check training data - data = next(iter(dm.val_dataloader())) - imgs, labels = data - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) - - # check training data + # check val data data = next(iter(dm.val_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) @@ -112,13 +107,13 @@ def test_map_labels(self, tmpdir): # create random dummy data - train_images = [ + images = [ str(tmp_dir / "img1.png"), str(tmp_dir / "img2.png"), str(tmp_dir / "img3.png"), ] - train_labels = [ + targets = [ str(tmp_dir / "labels_img1.png"), str(tmp_dir / "labels_img2.png"), str(tmp_dir / "labels_img3.png"), @@ -131,15 +126,15 @@ def test_map_labels(self, tmpdir): num_classes: int = len(labels_map.keys()) img_size: Tuple[int, int] = (196, 196) - create_random_data(train_images, train_labels, img_size, num_classes) + create_random_data(images, targets, img_size, num_classes) # instantiate the data module - dm = SemanticSegmentationData.from_filepaths( - 
train_filepaths=train_images, - train_labels=train_labels, - val_filepaths=train_images, - val_labels=train_labels, + dm = SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets, + val_files=images, + val_targets=targets, batch_size=2, num_workers=0, ) @@ -157,7 +152,7 @@ def test_map_labels(self, tmpdir): # check training data data = next(iter(dm.train_dataloader())) - imgs, labels = data + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) assert labels.min().item() == 0 diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index 1b4d2d45e9..0ebaa5d956 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -4,6 +4,7 @@ import torch from flash import Trainer +from flash.data.data_source import DefaultDataKeys from flash.vision import SemanticSegmentation # ======== Mock functions ======== @@ -14,8 +15,10 @@ class DummyDataset(torch.utils.data.Dataset): num_classes: int = 8 def __getitem__(self, index): - return torch.rand(3, *self.size), \ - torch.randint(self.num_classes - 1, self.size) + return { + DefaultDataKeys.INPUT: torch.rand(3, *self.size), + DefaultDataKeys.TARGET: torch.randint(self.num_classes - 1, self.size), + } def __len__(self) -> int: return 10 From 859a0ef085b891bf82a822b3fa139962b7ebc2a9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sat, 8 May 2021 13:37:07 +0100 Subject: [PATCH 41/53] Fixes --- .../reference/semantic_segmentation.rst | 33 +++++++------------ .../predict/semantic_segmentation.py | 1 - tests/data/test_callbacks.py | 3 +- 3 files changed, 14 insertions(+), 23 deletions(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index d6b3518c03..b77167ffc6 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -34,7 +34,6 @@ Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inferen .. code-block:: python # import our libraries - from flash import Trainer from flash.data.utils import download_data from flash.vision import SemanticSegmentation from flash.vision.segmentation.serialization import SegmentationLabels @@ -88,10 +87,10 @@ Now all we need is three lines of code to build to train our task! .. code-block:: python - import os import flash from flash.data.utils import download_data from flash.vision import SemanticSegmentation, SemanticSegmentationData + from flash.vision.segmentation.serialization import SegmentationLabels # 1. Download the data download_data( @@ -100,27 +99,19 @@ Now all we need is three lines of code to build to train our task! 
) # 2.1 Load the data - - - def load_data(data_root: str = 'data/'): - images = [] - labels = [] - rgb_path = os.path.join(data_root, "CameraRGB") - seg_path = os.path.join(data_root, "CameraSeg") - for fname in os.listdir(rgb_path): - images.append(os.path.join(rgb_path, fname)) - labels.append(os.path.join(seg_path, fname)) - return images, labels - - - images_filepaths, labels_filepaths = load_data() - - # create the data module - datamodule = SemanticSegmentationData.from_filepaths( - train_filepaths=images_filepaths, - train_labels=labels_filepaths, + datamodule = SemanticSegmentationData.from_folders( + train_folder="data/CameraRGB", + train_target_folder="data/CameraSeg", + batch_size=4, + val_split=0.3, + image_size=(200, 200), # (600, 800) ) + # 2.2 Visualise the samples + labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) + datamodule.set_labels_map(labels_map) + datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) + # 3. Build the model model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21) diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index b338ed95ec..f507f2a6a6 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import flash from flash.data.utils import download_data from flash.vision import SemanticSegmentation from flash.vision.segmentation.serialization import SegmentationLabels diff --git a/tests/data/test_callbacks.py b/tests/data/test_callbacks.py index f4748a5149..b53db09b6e 100644 --- a/tests/data/test_callbacks.py +++ b/tests/data/test_callbacks.py @@ -157,8 +157,9 @@ def configure_data_fetcher(*args, **kwargs) -> CustomBaseVisualization: for _ in range(num_tests): for fcn_name in _CALLBACK_FUNCS: + dm.data_fetcher.reset() fcn = getattr(dm, f"show_{stage}_batch") - fcn(fcn_name, reset=True) + fcn(fcn_name, reset=False) is_predict = stage == "predict" From 268b4c8116bc978d624f7a265188f8e3d2ef07e8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sat, 8 May 2021 13:41:15 +0100 Subject: [PATCH 42/53] Fixes --- docs/source/reference/semantic_segmentation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index b77167ffc6..97063438ea 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -146,6 +146,6 @@ SemanticSegmentationData .. autoclass:: flash.vision.SemanticSegmentationData -.. automethod:: flash.vision.SemanticSegmentationData.from_filepaths +.. automethod:: flash.vision.SemanticSegmentationData.from_folders .. 
autoclass:: flash.vision.SemanticSegmentationPreprocess From 4aa3716124efff52a193ba31b25ea90073545b17 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 08:37:14 +0100 Subject: [PATCH 43/53] Add tests --- flash/vision/segmentation/data.py | 5 +- tests/vision/segmentation/test_data.py | 124 +++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 3 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 4333e7c268..47d154c1a7 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -55,9 +55,8 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - if len(input_files) != len(target_files): rank_zero_warn( - f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. " - f"The following files have been dropped: " - f"{list(set(input_files).difference(set(target_files)))}", + f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. Some files" + " have been dropped.", UserWarning, ) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 3fcf3e9504..b3eedf36cc 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -1,3 +1,4 @@ +import os from pathlib import Path from typing import Dict, List, Tuple @@ -52,6 +53,123 @@ def test_smoke(self): dm = SemanticSegmentationData() assert dm is not None + def test_from_folders(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [ + str(tmp_dir / "images" / "img1.png"), + str(tmp_dir / "images" / "img2.png"), + str(tmp_dir / "images" / "img3.png"), + ] + + targets = [ + str(tmp_dir / "targets" / "img1.png"), + str(tmp_dir / "targets" / "img2.png"), + str(tmp_dir / "targets" / "img3.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + dm = SemanticSegmentationData.from_folders( + train_folder=str(tmp_dir / "images"), + train_target_folder=str(tmp_dir / "targets"), + val_folder=str(tmp_dir / "images"), + val_target_folder=str(tmp_dir / "targets"), + test_folder=str(tmp_dir / "images"), + test_target_folder=str(tmp_dir / "targets"), + batch_size=2, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check val data + data = next(iter(dm.val_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + + def test_from_folders_warning(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + os.makedirs(str(tmp_dir / "images")) + os.makedirs(str(tmp_dir / "targets")) + + images = [ + str(tmp_dir / "images" / "img1.png"), + str(tmp_dir / "images" / "img3.png"), + ] + + targets = [ + 
str(tmp_dir / "targets" / "img1.png"), + str(tmp_dir / "targets" / "img2.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + with pytest.warns(UserWarning, match="Found inconsistent files"): + dm = SemanticSegmentationData.from_folders( + train_folder=str(tmp_dir / "images"), + train_target_folder=str(tmp_dir / "targets"), + val_folder=str(tmp_dir / "images"), + val_target_folder=str(tmp_dir / "targets"), + test_folder=str(tmp_dir / "images"), + test_target_folder=str(tmp_dir / "targets"), + batch_size=1, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + assert dm.val_dataloader() is not None + assert dm.test_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, 196, 196) + + # check val data + data = next(iter(dm.val_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, 196, 196) + + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, 196, 196) + def test_from_files(self, tmpdir): tmp_dir = Path(tmpdir) @@ -102,6 +220,12 @@ def test_from_files(self, tmpdir): assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) + # check test data + data = next(iter(dm.test_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + def test_map_labels(self, tmpdir): tmp_dir = Path(tmpdir) From f50c200059a6f173b861f0f67ea3a73cc4e5a0ee Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 09:15:05 +0100 Subject: [PATCH 44/53] Update docs/source/reference/semantic_segmentation.rst Co-authored-by: thomas chaton --- docs/source/reference/semantic_segmentation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 97063438ea..059813202d 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -8,7 +8,7 @@ Semantinc Segmentation ******** The task ******** -Semantic segmentation, or image segmentation, is the task of clustering parts of an image together which belong to the same object class. It is a form of pixel-level prediction because each pixel in an image is classified according to a category +Semantic segmentation, or image segmentation, is the task of performing classification at a pixel-level, meaning each pixel will associated to a given class. The model output shape is ``(batch_size, num_classes, heigh, width)``. 
See more: https://paperswithcode.com/task/semantic-segmentation From 11ed7c55b764d7aef7ff890066ad9a670b69ed63 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 09:15:20 +0100 Subject: [PATCH 45/53] Update docs/source/reference/semantic_segmentation.rst Co-authored-by: thomas chaton --- docs/source/reference/semantic_segmentation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/reference/semantic_segmentation.rst b/docs/source/reference/semantic_segmentation.rst index 059813202d..0a9e01d8bb 100644 --- a/docs/source/reference/semantic_segmentation.rst +++ b/docs/source/reference/semantic_segmentation.rst @@ -26,7 +26,7 @@ See more: https://paperswithcode.com/task/semantic-segmentation Inference ********* -The :class:`~flash.vision.SemanticSegmentation` is already pre-trained on a generated dataset from `CARLA `_ simulator. +A :class:`~flash.vision.SemanticSegmentation` `fcn_resnet50` pre-trained on `CARLA `_ simulator is provided for the inference example. Use the :class:`~flash.vision.SemanticSegmentation` pretrained model for inference on any string sequence using :func:`~flash.vision.SemanticSegmentation.predict`: From 0fc35819cddc847cb643ced9b11d54698b3f7e62 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 12:54:43 +0100 Subject: [PATCH 46/53] Add a check --- flash/vision/segmentation/data.py | 16 ++++++-- tests/vision/segmentation/test_data.py | 57 ++++++++++++++++++-------- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 47d154c1a7..08b3565f7d 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -62,10 +62,20 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - input_data = input_files target_data = target_files + else: + if not isinstance(input_data, list) and not isinstance(target_data, list): + input_data = [input_data] + target_data = [target_data] - if not isinstance(input_data, list) and not isinstance(target_data, list): - input_data = [input_data] - target_data = [target_data] + if len(input_data) != len(target_data): + rank_zero_warn( + f"The number of input files ({len(input_data)}) and number of target files ({len(target_data)}) are" + " different. 
Some files have been dropped.", + UserWarning, + ) + length = min(len(input_data), len(target_data)) + input_data = input_data[:length] + target_data = target_data[:length] data = filter( lambda sample: ( diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index b3eedf36cc..20620528f9 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -140,17 +140,11 @@ def test_from_folders_warning(self, tmpdir): dm = SemanticSegmentationData.from_folders( train_folder=str(tmp_dir / "images"), train_target_folder=str(tmp_dir / "targets"), - val_folder=str(tmp_dir / "images"), - val_target_folder=str(tmp_dir / "targets"), - test_folder=str(tmp_dir / "images"), - test_target_folder=str(tmp_dir / "targets"), batch_size=1, num_workers=0, ) assert dm is not None assert dm.train_dataloader() is not None - assert dm.val_dataloader() is not None - assert dm.test_dataloader() is not None # check training data data = next(iter(dm.train_dataloader())) @@ -158,18 +152,6 @@ def test_from_folders_warning(self, tmpdir): assert imgs.shape == (1, 3, 196, 196) assert labels.shape == (1, 196, 196) - # check val data - data = next(iter(dm.val_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, 196, 196) - - # check test data - data = next(iter(dm.test_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (1, 3, 196, 196) - assert labels.shape == (1, 196, 196) - def test_from_files(self, tmpdir): tmp_dir = Path(tmpdir) @@ -226,6 +208,45 @@ def test_from_files(self, tmpdir): assert imgs.shape == (2, 3, 196, 196) assert labels.shape == (2, 196, 196) + def test_from_files_warning(self, tmpdir): + tmp_dir = Path(tmpdir) + + # create random dummy data + + images = [ + str(tmp_dir / "img1.png"), + str(tmp_dir / "img2.png"), + str(tmp_dir / "img3.png"), + ] + + targets = [ + str(tmp_dir / "labels_img1.png"), + str(tmp_dir / "labels_img2.png"), + str(tmp_dir / "labels_img3.png"), + ] + + num_classes: int = 2 + img_size: Tuple[int, int] = (196, 196) + create_random_data(images, targets, img_size, num_classes) + + # instantiate the data module + + with pytest.warns(UserWarning, match="The number of input files"): + dm = SemanticSegmentationData.from_files( + train_files=images, + train_targets=targets + [str(tmp_dir / "labels_img4.png")], + batch_size=2, + num_workers=0, + ) + assert dm is not None + assert dm.train_dataloader() is not None + + # check training data + data = next(iter(dm.train_dataloader())) + imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 196, 196) + def test_map_labels(self, tmpdir): tmp_dir = Path(tmpdir) From 18759676df4ac0ef483f3718c5de8d00551f04d5 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 13:00:29 +0100 Subject: [PATCH 47/53] Move KorniaParallelTransforms and add docstring --- flash/data/transforms.py | 23 +++++++++++++++++++++++ flash/vision/segmentation/transforms.py | 25 ++----------------------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/flash/data/transforms.py b/flash/data/transforms.py index 67b457f62b..67b1229ad4 100644 --- a/flash/data/transforms.py +++ b/flash/data/transforms.py @@ -42,3 +42,26 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: result[key] = outputs[i] return result return x + + 
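+# Illustrative usage for the class below (mirroring how flash/vision/segmentation/transforms.py
+# uses it): wrap it in ``ApplyToKeys`` so that the image and its segmentation mask receive
+# exactly the same random augmentation, e.g.
+#
+#   ApplyToKeys(
+#       [DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
+#       KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.75)),
+#   )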
+class KorniaParallelTransforms(nn.Sequential): + """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each + input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when + multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask).""" + + def __init__(self, *args): + super().__init__(*[convert_to_modules(arg) for arg in args]) + + def forward(self, inputs: Any): + result = list(inputs) if isinstance(inputs, Sequence) else [inputs] + for transform in self.children(): + inputs = result + for i, input in enumerate(inputs): + if hasattr(transform, "_params") and bool(transform._params): + params = transform._params + result[i] = transform(input, params) + else: # case for non random transforms + result[i] = transform(input) + if hasattr(transform, "_params") and bool(transform._params): + transform._params = None + return result diff --git a/flash/vision/segmentation/transforms.py b/flash/vision/segmentation/transforms.py index 11bb87cde3..1cf491793f 100644 --- a/flash/vision/segmentation/transforms.py +++ b/flash/vision/segmentation/transforms.py @@ -11,35 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Sequence, Tuple +from typing import Callable, Dict, Tuple import kornia as K import torch import torch.nn as nn from flash.data.data_source import DefaultDataKeys -from flash.data.transforms import ApplyToKeys -from flash.data.utils import convert_to_modules - - -class KorniaParallelTransforms(nn.Sequential): - - def __init__(self, *args): - super().__init__(*[convert_to_modules(arg) for arg in args]) - - def forward(self, inputs: Any): - result = list(inputs) if isinstance(inputs, Sequence) else [inputs] - for transform in self.children(): - inputs = result - for i, input in enumerate(inputs): - if hasattr(transform, "_params") and bool(transform._params): - params = transform._params - result[i] = transform(input, params) - else: # case for non random transforms - result[i] = transform(input) - if hasattr(transform, "_params") and bool(transform._params): - transform._params = None - return result +from flash.data.transforms import ApplyToKeys, KorniaParallelTransforms def prepare_target(tensor: torch.Tensor) -> torch.Tensor: From e9dee3015fdf89fca1f60669ec04dc96b8b2fe49 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Mon, 10 May 2021 15:10:07 +0200 Subject: [PATCH 48/53] implement quick test for segmentation labels --- .../vision/segmentation/test_serialisation.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/vision/segmentation/test_serialisation.py diff --git a/tests/vision/segmentation/test_serialisation.py b/tests/vision/segmentation/test_serialisation.py new file mode 100644 index 0000000000..42a1728950 --- /dev/null +++ b/tests/vision/segmentation/test_serialisation.py @@ -0,0 +1,32 @@ +import pytest +import torch + +from flash.vision.segmentation.serialization import SegmentationLabels + + +class TestSemanticSegmentationLabels: + + def test_smoke(self): + serial = SegmentationLabels() + assert serial is not None + assert serial.labels_map is None + assert serial.visualize is False + + def test_serialize(self): + serial = SegmentationLabels() + + sample = torch.zeros(5, 2, 3) + sample[1, 1, 2] = 1 # add peak in class 2 + 
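+        # serialize() takes an argmax over the class dimension (dim -3), so each peak's
+        # channel index should come back as the predicted class id in the asserts below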
sample[3, 0, 1] = 1 # add peak in class 4 + + classes = serial.serialize(sample) + assert classes[1, 2] == 1 + assert classes[0, 1] == 3 + + # TODO: implement me + def test_create_random_labels(self): + pass + + # TODO: implement me + def test_labels_to_image(self): + pass From 00491975943be402e819c4f63a2de19d270db730 Mon Sep 17 00:00:00 2001 From: Edgar Riba Date: Mon, 10 May 2021 15:15:24 +0200 Subject: [PATCH 49/53] add small assertion tests --- tests/vision/segmentation/test_serialisation.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/vision/segmentation/test_serialisation.py b/tests/vision/segmentation/test_serialisation.py index 42a1728950..a971c91fbf 100644 --- a/tests/vision/segmentation/test_serialisation.py +++ b/tests/vision/segmentation/test_serialisation.py @@ -12,6 +12,17 @@ def test_smoke(self): assert serial.labels_map is None assert serial.visualize is False + def test_exception(self): + serial = SegmentationLabels() + + with pytest.raises(Exception): + sample = torch.zeros(1, 5, 2, 3) + serial.serialize(sample) + + with pytest.raises(Exception): + sample = torch.zeros(2, 3) + serial.serialize(sample) + def test_serialize(self): serial = SegmentationLabels() From 4c7577415398e0167cc70335215783ef2236b8db Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 14:17:22 +0100 Subject: [PATCH 50/53] Rename test_serialisation.py to test_serialization.py --- .../segmentation/{test_serialisation.py => test_serialization.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/vision/segmentation/{test_serialisation.py => test_serialization.py} (100%) diff --git a/tests/vision/segmentation/test_serialisation.py b/tests/vision/segmentation/test_serialization.py similarity index 100% rename from tests/vision/segmentation/test_serialisation.py rename to tests/vision/segmentation/test_serialization.py From 2d76f37b18ed4c86662e0dcae6e32afecba577c7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 14:53:34 +0100 Subject: [PATCH 51/53] Switch to exception --- flash/vision/segmentation/data.py | 22 +++++++++------------- tests/vision/segmentation/test_data.py | 11 ++--------- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 08b3565f7d..abe9a7796d 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -62,20 +62,16 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - input_data = input_files target_data = target_files - else: - if not isinstance(input_data, list) and not isinstance(target_data, list): - input_data = [input_data] - target_data = [target_data] - if len(input_data) != len(target_data): - rank_zero_warn( - f"The number of input files ({len(input_data)}) and number of target files ({len(target_data)}) are" - " different. 
Some files have been dropped.", - UserWarning, - ) - length = min(len(input_data), len(target_data)) - input_data = input_data[:length] - target_data = target_data[:length] + if not isinstance(input_data, list) and not isinstance(target_data, list): + input_data = [input_data] + target_data = [target_data] + + if len(input_data) != len(target_data): + raise MisconfigurationException( + f"The number of input files ({len(input_data)}) and number of target files ({len(target_data)}) must be" + " the same.", + ) data = filter( lambda sample: ( diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 20620528f9..7ac36b017d 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -6,6 +6,7 @@ import pytest import torch from PIL import Image +from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash import Trainer from flash.data.data_source import DefaultDataKeys @@ -231,21 +232,13 @@ def test_from_files_warning(self, tmpdir): # instantiate the data module - with pytest.warns(UserWarning, match="The number of input files"): + with pytest.raises(MisconfigurationException, match="The number of input files"): dm = SemanticSegmentationData.from_files( train_files=images, train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, num_workers=0, ) - assert dm is not None - assert dm.train_dataloader() is not None - - # check training data - data = next(iter(dm.train_dataloader())) - imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET] - assert imgs.shape == (2, 3, 196, 196) - assert labels.shape == (2, 196, 196) def test_map_labels(self, tmpdir): tmp_dir = Path(tmpdir) From 5f254b66b33c2d7c1b53f7e16a4352fae8e4b357 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 14:54:41 +0100 Subject: [PATCH 52/53] Fix --- tests/vision/segmentation/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index 7ac36b017d..bd51f09d21 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -233,7 +233,7 @@ def test_from_files_warning(self, tmpdir): # instantiate the data module with pytest.raises(MisconfigurationException, match="The number of input files"): - dm = SemanticSegmentationData.from_files( + SemanticSegmentationData.from_files( train_files=images, train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, From 8745191af6a56f8ac774f9bb3b70159af24877a8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 10 May 2021 15:14:35 +0100 Subject: [PATCH 53/53] Fixes --- flash/vision/segmentation/data.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index abe9a7796d..d674205786 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -47,21 +47,20 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - input_data, target_data = data if self.isdir(input_data) and self.isdir(target_data): - files = os.listdir(input_data) - input_files = [os.path.join(input_data, file) for file in files] - target_files = [os.path.join(target_data, file) for file in files] + input_files = os.listdir(input_data) + target_files = os.listdir(target_data) - target_files = list(filter(os.path.isfile, target_files)) + all_files = 
set(input_files).intersection(set(target_files))

-        if len(input_files) != len(target_files):
+        if len(all_files) != len(input_files) or len(all_files) != len(target_files):
             rank_zero_warn(
                 f"Found inconsistent files in input_dir: {input_data} and target_dir: {target_data}. Some files"
                 " have been dropped.",
                 UserWarning,
             )

-        input_data = input_files
-        target_data = target_files
+        input_data = [os.path.join(input_data, file) for file in all_files]
+        target_data = [os.path.join(target_data, file) for file in all_files]

         if not isinstance(input_data, list) and not isinstance(target_data, list):
             input_data = [input_data]