This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[Feat] Add PointCloud Segmentation #566

Merged: 25 commits (Jul 14, 2021)
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -49,6 +49,10 @@ jobs:
          python-version: 3.8
          requires: 'latest'
          topic: ['text']
        - os: ubuntu-20.04
          python-version: 3.8
          requires: 'latest'
          topic: ['pointcloud']
        - os: ubuntu-20.04
          python-version: 3.8
          requires: 'latest'
1 change: 1 addition & 0 deletions .gitignore
@@ -159,4 +159,5 @@ CameraRGB
CameraSeg
jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

### Changed
2 changes: 1 addition & 1 deletion README.md
@@ -605,7 +605,7 @@ For help or questions, join our huge community on [Slack](https://join.slack.com
## Citations
We’re excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.

Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).
Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `pointcloud`, `tabular`, and `audio` tasks respectively. It also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).

## License
Please observe the Apache 2.0 license that is listed in this repository. In addition
25 changes: 25 additions & 0 deletions docs/source/api/pointcloud.rst
@@ -0,0 +1,25 @@
################
flash.pointcloud
################

.. contents::
    :depth: 1
    :local:
    :backlinks: top

.. currentmodule:: flash.pointcloud

Segmentation
____________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~segmentation.model.PointCloudSegmentation
    ~segmentation.data.PointCloudSegmentationData

    segmentation.data.PointCloudSegmentationPreprocess
    segmentation.data.PointCloudSegmentationFoldersDataSource
    segmentation.data.PointCloudSegmentationDatasetDataSource
7 changes: 7 additions & 0 deletions docs/source/index.rst
@@ -55,6 +55,12 @@ Lightning Flash
    reference/summarization
    reference/translation

.. toctree::
    :maxdepth: 1
    :caption: PointCloud

    reference/pointcloud_segmentation

.. toctree::
    :maxdepth: 1
    :caption: Graph
@@ -76,6 +82,7 @@ Lightning Flash
    api/data
    api/serve
    api/image
    api/pointcloud
    api/tabular
    api/text
    api/video
73 changes: 73 additions & 0 deletions docs/source/reference/pointcloud_segmentation.rst
@@ -0,0 +1,73 @@

.. _pointcloud_segmentation:

#######################
PointCloud Segmentation
#######################

********
The Task
********

A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.

PointCloud Segmentation is the task of performing classification at the point level, meaning each point is associated with a given class.
The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.
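
To make the point-level labelling concrete, here is a minimal illustration that is independent of the Flash API: a cloud of ``N`` points is stored as an ``(N, 3)`` array of coordinates, and a segmentation model predicts one class index per point.

.. code-block:: python

    import torch

    num_points, num_classes = 1000, 19
    points = torch.rand(num_points, 3)  # x, y, z coordinates for each point
    logits = torch.rand(num_points, num_classes)  # per-point class scores from a model
    labels = logits.argmax(dim=-1)  # one class index per point, shape (num_points,)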

------

*******
Example
*******

Let's look at an example using a dataset generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
The data is a tiny subset of the original dataset and contains sequences of point clouds.
It is organised into multiple folders, one for each sequence, plus a ``meta.yaml`` file describing the classes and their official associated color map.
Each sequence folder should contain one folder for scans and one folder for labels, plus a ``pose.txt`` file to re-align the sequence if required.
Here's the structure:

.. code-block::

    data
    ├── meta.yaml
    ├── 00
    │   ├── scans
    │   |    ├── 00000.bin
    │   |    ├── 00001.bin
    │   |    ...
    │   ├── labels
    │   |    ├── 00000.label
    │   |    ├── 00001.label
    │   |    ...
    │   ├── pose.txt
    │   ...
    |
    └── XX
        ├── scans
        |    ├── 00000.bin
        |    ├── 00001.bin
        |    ...
        ├── labels
        |    ├── 00000.label
        |    ├── 00001.label
        |    ...
        ├── pose.txt


Learn more: http://www.semantic-kitti.org/dataset.html


Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.pointcloud.segmentation.data.PointCloudSegmentationData`.
We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` task.
We then use the trained :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py
    :language: python
    :lines: 14-
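
The example file itself is not included in this diff, so the snippet below is only a minimal sketch of what such a script could look like. It assumes the classes are re-exported from ``flash.pointcloud``; the download URL, folder names, and argument names (``from_folders``, ``num_classes``, ``backbone``, ``predict``) follow Flash's usual conventions and are assumptions here, not lines taken from the pull request.

.. code-block:: python

    import flash
    from flash.core.data.utils import download_data
    from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData

    # 1. Download a tiny SemanticKITTI-style dataset (hypothetical URL and paths)
    # and create the DataModule from the folder structure shown above.
    download_data("https://example.com/SemanticKittiTiny.zip", "data/")
    datamodule = PointCloudSegmentationData.from_folders(
        train_folder="data/SemanticKittiTiny/train",
        val_folder="data/SemanticKittiTiny/val",
    )

    # 2. Build the task with the pre-trained backbone mentioned above.
    model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)

    # 3. Finetune on the tiny dataset.
    trainer = flash.Trainer(max_epochs=1)
    trainer.fit(model, datamodule)

    # 4. Predict on raw scan files, then save the checkpoint.
    predictions = model.predict(["data/SemanticKittiTiny/predict/00000.bin"])
    trainer.save_checkpoint("pointcloud_segmentation_model.pt")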



.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif
:width: 100%
8 changes: 4 additions & 4 deletions flash/core/data/batch.py
@@ -289,9 +289,10 @@ def __init__(

    @staticmethod
    def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
        if isinstance(batch, Mapping):
            return batch, batch.get(DefaultDataKeys.METADATA, None)
        return batch, None
        metadata = None
        if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch:
            metadata = batch.pop(DefaultDataKeys.METADATA, None)
        return batch, metadata

    def forward(self, batch: Sequence[Any]):
        batch, metadata = self._extract_metadata(batch)
@@ -331,7 +332,6 @@ def __str__(self) -> str:
def default_uncollate(batch: Any):
    """
    This function is used to uncollate a batch into samples.

    Examples:
        >>> a, b = default_uncollate(torch.rand((2,1)))
    """
74 changes: 62 additions & 12 deletions flash/core/data/data_module.py
@@ -275,37 +275,78 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) ->
    def _train_dataloader(self) -> DataLoader:
        train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
        shuffle: bool = False
        collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
        drop_last = False
        pin_memory = True

        if self.sampler is None:
            shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))

        if isinstance(getattr(self, "trainer", None), pl.Trainer):
            return self.trainer.lightning_module.process_train_dataset(
                train_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=pin_memory,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=collate_fn,
                sampler=self.sampler
            )

        return DataLoader(
            train_ds,
            batch_size=self.batch_size,
            shuffle=shuffle,
            sampler=self.sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            drop_last=True,
            collate_fn=self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
            pin_memory=pin_memory,
            drop_last=drop_last,
            collate_fn=collate_fn
        )

    def _val_dataloader(self) -> DataLoader:
        val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
        collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
        pin_memory = True

        if isinstance(getattr(self, "trainer", None), pl.Trainer):
            return self.trainer.lightning_module.process_val_dataset(
                val_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=pin_memory,
                collate_fn=collate_fn
            )

        return DataLoader(
            val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
            pin_memory=pin_memory,
            collate_fn=collate_fn
        )

    def _test_dataloader(self) -> DataLoader:
        test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
        collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
        pin_memory = True

        if isinstance(getattr(self, "trainer", None), pl.Trainer):
            return self.trainer.lightning_module.process_test_dataset(
                test_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=pin_memory,
                collate_fn=collate_fn
            )

        return DataLoader(
            test_ds,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._resolve_collate_fn(test_ds, RunningStage.TESTING)
            pin_memory=pin_memory,
            collate_fn=collate_fn
        )

    def _predict_dataloader(self) -> DataLoader:
@@ -314,12 +355,21 @@ def _predict_dataloader(self) -> DataLoader:
            batch_size = self.batch_size
        else:
            batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)

        collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
        pin_memory = True

        if isinstance(getattr(self, "trainer", None), pl.Trainer):
            return self.trainer.lightning_module.process_test_dataset(
                predict_ds,
                batch_size=batch_size,
                num_workers=self.num_workers,
                pin_memory=pin_memory,
                collate_fn=collate_fn
            )

        return DataLoader(
            predict_ds,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
            predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn
        )

@property
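
The dataloader methods above now defer to ``process_*_dataset`` hooks on the attached LightningModule whenever a trainer is present, which is what allows the Open3D-ML integration to build its own dataloaders. Those hooks are defined outside this file, so the sketch below only illustrates the expected shape of such an override; every name in it is an assumption rather than code from this pull request.

    from typing import Callable, Optional

    import flash
    from torch.utils.data import DataLoader, Dataset, Sampler

    class MyPointCloudTask(flash.Task):  # flash.Task is assumed to provide default hooks
        def process_train_dataset(
            self,
            dataset: Dataset,
            batch_size: int,
            num_workers: int,
            pin_memory: bool,
            shuffle: bool,
            drop_last: bool,
            collate_fn: Callable,
            sampler: Optional[Sampler] = None,
        ) -> DataLoader:
            # Swap in a backend-specific collate function before building
            # an otherwise ordinary PyTorch DataLoader.
            return DataLoader(
                dataset,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=shuffle,
                drop_last=drop_last,
                collate_fn=self.custom_collate_fn,  # hypothetical attribute
                sampler=sampler,
            )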
7 changes: 7 additions & 0 deletions flash/core/data/process.py
@@ -26,6 +26,7 @@
from flash.core.data.callback import FlashCallback
from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
from flash.core.data.properties import Properties
from flash.core.data.states import CollateFn
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext


@@ -361,6 +362,12 @@ def per_batch_transform(self, batch: Any) -> Any:

    def collate(self, samples: Sequence) -> Any:
        """ Transform to convert a sequence of samples to a collated batch. """

        # the model can provide a custom ``collate_fn``.
        collate_fn = self.get_state(CollateFn)
        if collate_fn is not None:
            return collate_fn.collate_fn(samples)

        current_transform = self.current_transform
        if current_transform is self._identity:
            return self._default_collate(samples)
10 changes: 10 additions & 0 deletions flash/core/data/states.py
@@ -0,0 +1,10 @@
from dataclasses import dataclass
from typing import Callable, Optional

from flash.core.data.properties import ProcessState


@dataclass(unsafe_hash=True, frozen=True)
class CollateFn(ProcessState):

    collate_fn: Optional[Callable] = None
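
Together with the ``collate`` change in ``process.py`` above, this state lets a model override batching without subclassing the ``Preprocess``: whoever holds the process state registers a ``CollateFn``, and ``Preprocess.collate`` then calls it instead of the default collate. A rough sketch of the idea, assuming a ``set_state`` helper is available on the object that owns the state (that helper is not shown in this diff):

    from flash.core.data.states import CollateFn

    def ragged_collate(samples):
        # Keep variable-size point clouds as a plain list instead of stacking tensors.
        return {"batch": list(samples)}

    # Hypothetical registration: Preprocess.collate will now pick this up
    # via self.get_state(CollateFn).
    model.set_state(CollateFn(collate_fn=ragged_collate))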