This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[Feat] Add PointCloud Segmentation #566

Merged
merged 25 commits into from
Jul 14, 2021
Changes from 14 commits
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -49,6 +49,10 @@ jobs:
             python-version: 3.8
             requires: 'latest'
             topic: ['text']
+          - os: ubuntu-20.04
+            python-version: 3.8
+            requires: 'latest'
+            topic: ['pointcloud']
           - os: ubuntu-20.04
             python-version: 3.8
             requires: 'latest'
1 change: 1 addition & 0 deletions .gitignore
@@ -159,3 +159,4 @@ CameraRGB
 CameraSeg
 jigsaw_toxic_comments
 flash_examples/serve/tabular_classification/data
+logs/cache/*
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

+- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))
+
+
 ### Changed

 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
2 changes: 1 addition & 1 deletion README.md
@@ -605,7 +605,7 @@ For help or questions, join our huge community on [Slack](https://join.slack.com
 ## Citations
 We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffee, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.

-Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).
+Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML) for pointcloud, [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).

 ## License
 Please observe the Apache 2.0 license that is listed in this repository. In addition
24 changes: 24 additions & 0 deletions docs/source/api/pointcloud.rst
@@ -0,0 +1,24 @@
+################
+flash.pointcloud
+################
+
+.. contents::
+    :depth: 1
+    :local:
+    :backlinks: top
+
+.. currentmodule:: flash.pointcloud
+
+Segmentation
+____________
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+    :template: classtemplate.rst
+
+    ~segmentation.data.PointCloudSegmentationData
+    ~segmentation.data.PointCloudSegmentationPreprocess
+    ~segmentation.data.PointCloudSegmentationFoldersDataSource
+    ~segmentation.data.PointCloudSegmentationDatasetDataSource
+    ~segmentation.model.PointCloudSegmentation
7 changes: 7 additions & 0 deletions docs/source/index.rst
@@ -55,6 +55,12 @@ Lightning Flash
    reference/summarization
    reference/translation

+.. toctree::
+   :maxdepth: 1
+   :caption: PointCloud
+
+   reference/pointcloud_segmentation
+
 .. toctree::
    :maxdepth: 1
    :caption: Integrations
@@ -70,6 +76,7 @@ Lightning Flash
    api/data
    api/serve
    api/image
+   api/pointcloud
    api/tabular
    api/text
    api/video
73 changes: 73 additions & 0 deletions docs/source/reference/pointcloud_segmentation.rst
@@ -0,0 +1,73 @@
+
+.. _pointcloud_segmentation:
+
+#######################
+PointCloud Segmentation
+#######################
+
+********
+The Task
+********
+
+A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.
+
+PointCloud Segmentation is the task of performing classification at a point-level, meaning each point will be associated with a given class.
+The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.
+
+------
+
+*******
+Example
+*******
+
+Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
+The data is a tiny subset of the original dataset and contains sequences of point clouds.
+The data contains multiple folders, one for each sequence, and a ``meta.yaml`` file describing the classes and their official associated color map.
+A sequence should contain one folder for scans and one folder for labels, plus a ``pose.txt`` to re-align the sequence if required.
+Here's the structure:
+
+.. code-block::
+
+    data
+    ├── meta.yaml
+    ├── 00
+    │   ├── scans
+    │   │   ├── 00000.bin
+    │   │   ├── 00001.bin
+    │   │   ...
+    │   ├── labels
+    │   │   ├── 00000.label
+    │   │   ├── 00001.label
+    │   │   ...
+    │   ├── pose.txt
+    │   ...
+    │
+    └── XX
+        ├── scans
+        │   ├── 00000.bin
+        │   ├── 00001.bin
+        │   ...
+        ├── labels
+        │   ├── 00000.label
+        │   ├── 00001.label
+        │   ...
+        ├── pose.txt
+
+
+Learn more: http://www.semantic-kitti.org/dataset.html
+
+
+Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.pointcloud.segmentation.data.PointCloudSegmentationData`.
+We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` task.
+We then use the trained :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py
+    :language: python
+    :lines: 14-
+
+
+
+.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif
+    :width: 100%
6 changes: 3 additions & 3 deletions docs/source/reference/semantic_segmentation.rst
@@ -34,9 +34,9 @@ Here's the structure:
 ├── F61-2.png
 ...

-Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.SemanticSegmentationData`.
-We select a pre-trained ``mobilenet_v3_large`` backbone with an ``fpn`` head to use for our :class:`~flash.image.segmentation.model.SemanticSegmentation` task and fine-tune on the CARLA data.
-We then use the trained :class:`~flash.image.segmentation.model.SemanticSegmentation` for inference.
+Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.PointCloudSegmentationData`.
+We select a pre-trained ``randlanet_PointCloud_kitti`` backbone for our :class:`~flash.image.segmentation.model.PointCloudSegmentation` task and fine-tune on the CARLA data.
+We then use the trained :class:`~flash.image.segmentation.model.PointCloudSegmentation` for inference.
 Finally, we save the model.
 Here's the full example:

3 changes: 1 addition & 2 deletions flash/core/data/batch.py
@@ -290,7 +290,7 @@ def __init__(
     @staticmethod
     def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
         if isinstance(batch, Mapping):
-            return batch, batch.get(DefaultDataKeys.METADATA, None)
+            return batch, batch.pop(DefaultDataKeys.METADATA, None)
         return batch, None

     def forward(self, batch: Sequence[Any]):
@@ -331,7 +331,6 @@ def __str__(self) -> str:
 def default_uncollate(batch: Any):
     """
     This function is used to uncollate a batch into samples.
-
     Examples:
         >>> a, b = default_uncollate(torch.rand((2,1)))
     """
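The switch from `get` to `pop` in `_extract_metadata` changes more than the returned value: `pop` also removes the metadata entry from the batch mapping, so later stages no longer see it. A minimal sketch of the difference, assuming a plain `dict` batch and a hypothetical `"metadata"` key standing in for `DefaultDataKeys.METADATA`:

```python
from collections.abc import Mapping
from typing import Any, Optional, Tuple

METADATA = "metadata"  # hypothetical stand-in for DefaultDataKeys.METADATA


def extract_with_get(batch: Any) -> Tuple[Any, Optional[Any]]:
    # Old behaviour: metadata is returned but stays inside the batch.
    if isinstance(batch, Mapping):
        return batch, batch.get(METADATA, None)
    return batch, None


def extract_with_pop(batch: Any) -> Tuple[Any, Optional[Any]]:
    # New behaviour: metadata is returned and removed from the batch.
    if isinstance(batch, Mapping):
        return batch, batch.pop(METADATA, None)
    return batch, None


batch_a = {"input": [1, 2], METADATA: {"size": 2}}
batch_b = {"input": [1, 2], METADATA: {"size": 2}}

_, meta_a = extract_with_get(batch_a)
_, meta_b = extract_with_pop(batch_b)
```

Both helpers return the same metadata; only the `pop` variant mutates the caller's batch, which matters if the same mapping is reused elsewhere.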
77 changes: 64 additions & 13 deletions flash/core/data/data_module.py
@@ -275,37 +275,79 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) ->
     def _train_dataloader(self) -> DataLoader:
         train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
         shuffle: bool = False
+        collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
+        sampler = self.sampler
+        drop_last = False
+        pin_memory = True
+
         if self.sampler is None:
             shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_train_dataset(
+                train_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                collate_fn=collate_fn,
+                sampler=sampler
+            )
+
         return DataLoader(
             train_ds,
             batch_size=self.batch_size,
             shuffle=shuffle,
-            sampler=self.sampler,
+            sampler=sampler,
             num_workers=self.num_workers,
-            pin_memory=True,
-            drop_last=True,
-            collate_fn=self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            collate_fn=collate_fn
         )

     def _val_dataloader(self) -> DataLoader:
         val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
+        collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_val_dataset(
+                val_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
             val_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
-            collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
         )

     def _test_dataloader(self) -> DataLoader:
         test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
+        collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_test_dataset(
+                test_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
             test_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
-            collate_fn=self._resolve_collate_fn(test_ds, RunningStage.TESTING)
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
         )

     def _predict_dataloader(self) -> DataLoader:
@@ -314,12 +356,21 @@ def _predict_dataloader(self) -> DataLoader:
             batch_size = self.batch_size
         else:
             batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
+
+        collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_test_dataset(
+                predict_ds,
+                batch_size=batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
-            predict_ds,
-            batch_size=batch_size,
-            num_workers=self.num_workers,
-            pin_memory=True,
-            collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
+            predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn
         )

     @property
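The pattern introduced in `data_module.py` is a delegation hook: when the data module is attached to a trainer, the model's `process_*_dataset` method builds the loader; otherwise a default `DataLoader` is returned. The sketch below illustrates only that control flow. Every class in it (`PlainLoader`, `TaskWithHook`, `TinyDataModule`) is a hypothetical stand-in, not part of Flash or PyTorch, and the `task` attribute plays the role of `trainer.lightning_module`.

```python
from typing import Any, Callable, List, Optional


class PlainLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, dataset: List[Any], batch_size: int, collate_fn: Callable) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.built_by = "datamodule"


class TaskWithHook:
    """Hypothetical task that customises how its dataset is loaded."""

    def process_train_dataset(self, dataset, batch_size, collate_fn) -> PlainLoader:
        loader = PlainLoader(dataset, batch_size, collate_fn)
        loader.built_by = "task"
        return loader


class TinyDataModule:
    def __init__(self, dataset: List[Any], batch_size: int) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.task: Optional[TaskWithHook] = None  # role of trainer.lightning_module

    def train_dataloader(self) -> PlainLoader:
        collate_fn = list
        if self.task is not None:
            # Attached: let the task decide how to build the loader.
            return self.task.process_train_dataset(self.dataset, self.batch_size, collate_fn)
        # Detached: fall back to the default loader.
        return PlainLoader(self.dataset, self.batch_size, collate_fn)


dm = TinyDataModule(dataset=[1, 2, 3, 4], batch_size=2)
detached = dm.train_dataloader().built_by
dm.task = TaskWithHook()
attached = dm.train_dataloader().built_by
```

Resolving `collate_fn` once, before the branch, mirrors the diff: both code paths receive the same arguments, so a task that does not need special handling can simply forward them to a regular loader.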
6 changes: 6 additions & 0 deletions flash/core/data/process.py
@@ -26,6 +26,7 @@
 from flash.core.data.callback import FlashCallback
 from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
 from flash.core.data.properties import Properties
+from flash.core.data.states import CollateFn
 from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext


@@ -361,6 +362,11 @@ def per_batch_transform(self, batch: Any) -> Any:

     def collate(self, samples: Sequence) -> Any:
         """ Transform to convert a sequence of samples to a collated batch. """
+
+        collate_fn = self.get_state(CollateFn)
+        if collate_fn is not None:
+            return collate_fn.collate_fn(samples)
+
         current_transform = self.current_transform
         if current_transform is self._identity:
             return self._default_collate(samples)
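The `collate` change gives a registered `CollateFn` state priority over the transform-based collation. A minimal sketch of that lookup order follows, assuming a simplified state store keyed by state type. `CollateFnState` and `TinyPreprocess` are hypothetical stand-ins, and the real `get_state` / `set_state` machinery in Flash may differ.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Sequence


@dataclass(unsafe_hash=True, frozen=True)
class CollateFnState:
    # Hypothetical stand-in for flash.core.data.states.CollateFn.
    collate_fn: Optional[Callable] = None


class TinyPreprocess:
    """Hypothetical, simplified holder of state keyed by state type."""

    def __init__(self) -> None:
        self._state: Dict[type, Any] = {}

    def set_state(self, state: Any) -> None:
        self._state[type(state)] = state

    def get_state(self, state_type: type) -> Optional[Any]:
        return self._state.get(state_type)

    def collate(self, samples: Sequence) -> Any:
        # A registered collate function takes priority over the default.
        state = self.get_state(CollateFnState)
        if state is not None:
            return state.collate_fn(samples)
        return list(samples)  # simplified default: keep samples as a list


default_pre = TinyPreprocess()
custom_pre = TinyPreprocess()
custom_pre.set_state(CollateFnState(collate_fn=lambda samples: sum(samples)))

default_batch = default_pre.collate([1, 2, 3])
custom_batch = custom_pre.collate([1, 2, 3])
```

As in the diff, the check is on the state object rather than on its `collate_fn` field, so a state created with the default `None` would fail when called.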
10 changes: 10 additions & 0 deletions flash/core/data/states.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+from flash.core.data.properties import ProcessState
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class CollateFn(ProcessState):
+
+    collate_fn: Optional[Callable] = None
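The decorator arguments on `CollateFn` make the state immutable and hashable. A short sketch of what `frozen=True` and `unsafe_hash=True` mean in practice, using a hypothetical copy of the dataclass without the `ProcessState` base class:

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Callable, Optional


@dataclass(unsafe_hash=True, frozen=True)
class CollateFnSketch:
    # Hypothetical copy of the dataclass above, without the ProcessState base.
    collate_fn: Optional[Callable] = None


state = CollateFnSketch(collate_fn=len)

# frozen=True: attributes cannot be reassigned after construction.
try:
    state.collate_fn = max
    mutated = True
except FrozenInstanceError:
    mutated = False

# A hash is generated from the fields, so equal states hash equally.
same_hash = hash(state) == hash(CollateFnSketch(collate_fn=len))
```

To change the collate function, a new state object has to be created and registered in place of the old one.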