Commit 719cf5c

Add semantic segmentation task (#239)
* semantic segmentation skeleton

* expose and add smoke tests for preprocess and datamodule

* data module connections working

* preprocess not crashing (WIP)

* implement segmentation sequential

* implement torchvision backbone model

* model working

* implement labels mapping

* add map labels tests

* from filepaths training test not crashing

* non-working visualiser

* fix visualiser

* training working

* training not crashing

* cleanup example and move serializer to core

* cleanup model code, tests and docs

* move transforms apart

* implement ApplytransformsToKey augmentations

* relative path

* fix load from pretrained and add resnet 101

* create segmentation keys enum

* sync with master and fix val_split

* move apart segmentation backbones

* fix tests

* fix tests

* fix tests

* fix memory leak issues

* undo function filtering

* fix import

* more fixes for memory leaks

* add segmentation to docs

* add inference example

* add image to docs and update with AdamW

* Make pretrained arg kwarg

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Data sources initial commit

* Update transforms

* Updates

* Fixes

* Fix tests

* Fixes

* Fixes

* Add tests

* Update docs/source/reference/semantic_segmentation.rst

Co-authored-by: thomas chaton <[email protected]>

* Update docs/source/reference/semantic_segmentation.rst

Co-authored-by: thomas chaton <[email protected]>

* Add a check

* Move KorniaParallelTransforms and add docstring

* implement quick test for segmentation labels

* add small assertion tests

* Rename test_serialisation.py to test_serialization.py

* Switch to exception

* Fix

* Fixes

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
5 people authored May 10, 2021
1 parent cfb029e commit 719cf5c

Showing 23 changed files with 1,320 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -153,3 +153,5 @@ wmt_en_ro
action_youtube_naudio
kinetics
movie_posters
CameraRGB
CameraSeg
2 changes: 1 addition & 1 deletion docs/source/index.rst
@@ -29,7 +29,7 @@ Lightning Flash
reference/translation
reference/object_detection
reference/video_classification

reference/semantic_segmentation

.. toctree::
:maxdepth: 1
151 changes: 151 additions & 0 deletions docs/source/reference/semantic_segmentation.rst
@@ -0,0 +1,151 @@

.. _semantic_segmentation:

######################
Semantic Segmentation
######################

********
The task
********
Semantic segmentation, or image segmentation, is the task of performing classification at a pixel level, meaning each pixel is associated with a given class. The model output shape is ``(batch_size, num_classes, height, width)``.

See more: https://paperswithcode.com/task/semantic-segmentation
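To make that output convention concrete, here is a minimal sketch (plain ``torch``/``torchvision``, not part of the Flash API) that checks the shape on a random batch:

.. code-block:: python

    import torch
    import torchvision

    # illustrative only: a batch of 2 RGB images and 21 classes
    model = torchvision.models.segmentation.fcn_resnet50(pretrained=False, num_classes=21)
    model.eval()

    with torch.no_grad():
        out = model(torch.rand(2, 3, 128, 128))["out"]

    assert out.shape == (2, 21, 128, 128)  # (batch_size, num_classes, height, width)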

.. raw:: html

<p>
<a href="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png" >
<img src="https://i2.wp.com/syncedreview.com/wp-content/uploads/2019/12/image-9-1.png"/>
</a>
</p>

------

*********
Inference
*********

An :class:`~flash.vision.SemanticSegmentation` model with an ``fcn_resnet50`` backbone, pre-trained on the `CARLA <http://carla.org/>`_ simulator, is provided for the inference example.


Use the pretrained :class:`~flash.vision.SemanticSegmentation` model for inference on a list of image file paths using :func:`~flash.vision.SemanticSegmentation.predict`:

.. code-block:: python

    # import our libraries
    from flash.data.utils import download_data
    from flash.vision import SemanticSegmentation
    from flash.vision.segmentation.serialization import SegmentationLabels

    # 1. Download the data
    download_data(
        "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
        "data/"
    )

    # 2. Load the model from a checkpoint
    model = SemanticSegmentation.load_from_checkpoint(
        "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt"
    )
    model.serializer = SegmentationLabels(visualize=True)

    # 3. Predict what's on a few images and visualize!
    predictions = model.predict([
        'data/CameraRGB/F61-1.png',
        'data/CameraRGB/F62-1.png',
        'data/CameraRGB/F63-1.png',
    ])

For more advanced inference options, see :ref:`predictions`.

------

**********
Finetuning
**********

Say you now want to customise your model with new data from the same dataset. Once we download the data using :func:`~flash.data.download_data`, all we need are the image and mask folders to create the :class:`~flash.vision.SemanticSegmentationData`.

.. note:: The dataset is structured so that each sample (an image and its corresponding mask) is stored in a separate directory but keeps the same filename.

.. code-block::

    data
    ├── CameraRGB
    │   ├── F61-1.png
    │   ├── F61-2.png
    │   ...
    └── CameraSeg
        ├── F61-1.png
        ├── F61-2.png
        ...
Now all we need are a few lines of code to build and train our task!

.. code-block:: python

    import flash
    from flash.data.utils import download_data
    from flash.vision import SemanticSegmentation, SemanticSegmentationData
    from flash.vision.segmentation.serialization import SegmentationLabels

    # 1. Download the data
    download_data(
        "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
        "data/"
    )

    # 2.1 Load the data
    datamodule = SemanticSegmentationData.from_folders(
        train_folder="data/CameraRGB",
        train_target_folder="data/CameraSeg",
        batch_size=4,
        val_split=0.3,
        image_size=(200, 200),  # (600, 800)
    )

    # 2.2 Visualise the samples
    labels_map = SegmentationLabels.create_random_labels_map(num_classes=21)
    datamodule.set_labels_map(labels_map)
    datamodule.show_train_batch(["load_sample", "post_tensor_transform"])

    # 3. Build the model
    model = SemanticSegmentation(backbone="torchvision/fcn_resnet50", num_classes=21)

    # 4. Create the trainer.
    trainer = flash.Trainer(max_epochs=1)

    # 5. Train the model
    trainer.finetune(model, datamodule=datamodule, strategy='freeze')

    # 6. Save it!
    trainer.save_checkpoint("semantic_segmentation_model.pt")

------

*************
API reference
*************

.. _segmentation:

SemanticSegmentation
--------------------

.. autoclass:: flash.vision.SemanticSegmentation
:members:
:exclude-members: forward

.. _segmentation_data:

SemanticSegmentationData
------------------------

.. autoclass:: flash.vision.SemanticSegmentationData

.. automethod:: flash.vision.SemanticSegmentationData.from_folders

.. autoclass:: flash.vision.SemanticSegmentationPreprocess
3 changes: 2 additions & 1 deletion flash/core/classification.py
@@ -55,7 +55,8 @@ def __init__(
def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
if getattr(self.hparams, "multi_label", False):
return torch.sigmoid(x)
return torch.softmax(x, -1)
# we'll assume that the data always comes as `(B, C, ...)`
return torch.softmax(x, dim=1)


class ClassificationSerializer(Serializer):
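The switch to ``dim=1`` matters for dense predictions: segmentation logits come as ``(B, C, H, W)``, so ``dim=-1`` would have normalised over image width rather than over classes. A quick illustrative check (plain ``torch``, not Flash code):

    import torch

    logits = torch.rand(2, 21, 4, 4)  # (B, C, H, W) segmentation logits

    # softmax over the class dimension: per-pixel probabilities sum to 1
    probs = torch.softmax(logits, dim=1)
    assert torch.allclose(probs.sum(dim=1), torch.ones(2, 4, 4))

    # softmax over the last dimension would (wrongly) normalise over width
    bad = torch.softmax(logits, dim=-1)
    assert torch.allclose(bad.sum(dim=-1), torch.ones(2, 21, 4))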
6 changes: 6 additions & 0 deletions flash/data/batch.py
@@ -138,6 +138,12 @@ def __init__(
self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

def forward(self, samples: Sequence[Any]) -> Any:
# we create a new dict to prevent from potential memory leaks
# assuming that the dictionary samples are stored in between and
# potentially modified before the transforms are applied.
if isinstance(samples, dict):
samples = dict(samples.items())

with self._current_stage_context:

if self.apply_per_sample_transform:
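A sketch of why the copy helps (illustrative only, not Flash code): transforms that mutate the copied dict leave the mapping held upstream untouched, so long-lived references cannot accumulate extra entries:

    sample = {"input": [1, 2, 3]}

    copied = dict(sample.items())  # shallow copy, as in ``forward`` above
    copied["preds"] = [0.1, 0.9]   # downstream mutation

    assert "preds" not in sample   # the original mapping is unchanged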
2 changes: 2 additions & 0 deletions flash/data/data_module.py
@@ -190,6 +190,8 @@ def _show_batch(self, stage: str, func_names: Union[str, List[str]], reset: bool
_ = next(iter_dataloader)
data_fetcher: BaseVisualization = self.data_fetcher
data_fetcher._show(stage, func_names)
if reset:
self.data_fetcher.batches[stage] = {}

def show_train_batch(self, hooks_names: Union[str, List[str]] = 'load_sample', reset: bool = True) -> None:
"""This function is used to visualize a batch from the train dataloader."""
34 changes: 30 additions & 4 deletions flash/data/transforms.py
@@ -27,15 +27,41 @@ def __init__(self, keys: Union[str, Sequence[str]], *args):
self.keys = keys

def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
inputs = [x[key] for key in filter(lambda key: key in x, self.keys)]
keys = list(filter(lambda key: key in x, self.keys))
inputs = [x[key] for key in keys]
if len(inputs) > 0:
outputs = super().forward(*inputs)
if not isinstance(outputs, tuple):
if len(inputs) == 1:
inputs = inputs[0]
outputs = super().forward(inputs)
if not isinstance(outputs, Sequence):
outputs = (outputs, )

result = {}
result.update(x)
for i, key in enumerate(self.keys):
for i, key in enumerate(keys):
result[key] = outputs[i]
return result
return x


class KorniaParallelTransforms(nn.Sequential):
"""The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when
multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask)."""

def __init__(self, *args):
super().__init__(*[convert_to_modules(arg) for arg in args])

def forward(self, inputs: Any):
result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
for transform in self.children():
inputs = result
for i, input in enumerate(inputs):
if hasattr(transform, "_params") and bool(transform._params):
params = transform._params
result[i] = transform(input, params)
else: # case for non random transforms
result[i] = transform(input)
if hasattr(transform, "_params") and bool(transform._params):
transform._params = None
return result
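A hedged usage sketch for ``KorniaParallelTransforms`` (it assumes ``kornia`` is installed; the specific transform is arbitrary), applying one shared random flip decision to an image and its mask:

    import torch
    import kornia as K

    from flash.data.transforms import KorniaParallelTransforms

    # one random transform whose state (``._params``) is shared across inputs
    transform = KorniaParallelTransforms(K.augmentation.RandomHorizontalFlip(p=0.5))

    image = torch.rand(1, 3, 32, 32)
    mask = torch.randint(0, 21, (1, 1, 32, 32)).float()

    flipped_image, flipped_mask = transform([image, mask])
    # either both tensors were flipped or neither was, because the first
    # application's random state is reused for every input in the list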
1 change: 1 addition & 0 deletions flash/vision/__init__.py
@@ -2,3 +2,4 @@
from flash.vision.classification import ImageClassificationData, ImageClassificationPreprocess, ImageClassifier
from flash.vision.detection import ObjectDetectionData, ObjectDetector
from flash.vision.embedding import ImageEmbedder
from flash.vision.segmentation import SemanticSegmentation, SemanticSegmentationData, SemanticSegmentationPreprocess
2 changes: 1 addition & 1 deletion flash/vision/classification/data.py
@@ -73,7 +73,7 @@ def collate(self, samples: Sequence[Dict[str, Any]]) -> Any:
for key in sample.keys():
if torch.is_tensor(sample[key]):
sample[key] = sample[key].squeeze(0)
return default_collate(samples)
return super().collate(samples)

@property
def default_train_transforms(self) -> Optional[Dict[str, Callable]]:
2 changes: 2 additions & 0 deletions flash/vision/segmentation/__init__.py
@@ -0,0 +1,2 @@
from flash.vision.segmentation.data import SemanticSegmentationData, SemanticSegmentationPreprocess
from flash.vision.segmentation.model import SemanticSegmentation
36 changes: 36 additions & 0 deletions flash/vision/segmentation/backbones.py
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.utils.imports import _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
import torchvision

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")


@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model


@SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet101")
def load_torchvision_fcn_resnet101(num_classes: int, pretrained: bool = True) -> nn.Module:
model = torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
model.classifier[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
return model
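A short sketch of how a registered backbone might be retrieved (it assumes :class:`FlashRegistry` exposes a ``get`` lookup, as used elsewhere in Flash):

    from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES

    # look up the registered builder by name and instantiate it
    fn = SEMANTIC_SEGMENTATION_BACKBONES.get("torchvision/fcn_resnet50")
    model = fn(num_classes=21, pretrained=False)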