Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
[Feat] Add PointCloud Segmentation (#566)
Browse files Browse the repository at this point in the history
* update

* wip

* update

* update

* update

* resolve issues

* update

* update

* add doc

* update

* add tests

* update

* update tests

* update on comments

* update

* update

* resolve some bugs

* remove breakpoint

* Update docs/source/api/pointcloud.rst

* update

Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
tchaton and ethanwharris authored Jul 14, 2021
1 parent a340464 commit 9c42528
Show file tree
Hide file tree
Showing 33 changed files with 1,311 additions and 22 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: ['text']
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: ['pointcloud']
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,5 @@ CameraRGB
CameraSeg
jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
logs/cache/*
flash_examples/data
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

### Changed
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ For help or questions, join our huge community on [Slack](https://join.slack.com
## Citations
We’re excited to continue the strong legacy of opensource software and have been inspired over the years by Caffee, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.

Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).
Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML) for pointcloud, [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).

## License
Please observe the Apache 2.0 license that is listed in this repository. In addition
Expand Down
25 changes: 25 additions & 0 deletions docs/source/api/pointcloud.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
################
flash.pointcloud
################

.. contents::
:depth: 1
:local:
:backlinks: top

.. currentmodule:: flash.pointcloud

Segmentation
____________

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

~segmentation.model.PointCloudSegmentation
~segmentation.data.PointCloudSegmentationData

segmentation.data.PointCloudSegmentationPreprocess
segmentation.data.PointCloudSegmentationFoldersDataSource
segmentation.data.PointCloudSegmentationDatasetDataSource
7 changes: 7 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ Lightning Flash
reference/summarization
reference/translation

.. toctree::
:maxdepth: 1
:caption: PointCloud

reference/pointcloud_segmentation

.. toctree::
:maxdepth: 1
:caption: Graph
Expand All @@ -76,6 +82,7 @@ Lightning Flash
api/data
api/serve
api/image
api/pointcloud
api/tabular
api/text
api/video
Expand Down
73 changes: 73 additions & 0 deletions docs/source/reference/pointcloud_segmentation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@

.. _pointcloud_segmentation:

#######################
PointCloud Segmentation
#######################

********
The Task
********

A Point Cloud is a set of data points in space, usually describes by ``x``, ``y`` and ``z`` coordinates.

PointCloud Segmentation is the task of performing classification at a point-level, meaning each point will associated to a given class.
The current integration builds on top `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.

------

*******
Example
*******

Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
The data are a tiny subset of the original dataset and contains sequences of point clouds.
The data contains multiple folder, one for each sequence and a meta.yaml file describing the classes and their official associated color map.
A sequence should contain one folder for scans and one folder for labels, plus a ``pose.txt`` to re-align the sequence if required.
Here's the structure:

.. code-block::
data
├── meta.yaml
├── 00
│ ├── scans
| | ├── 00000.bin
| | ├── 00001.bin
| | ...
│ ├── labels
| | ├── 00000.label
| | ├── 00001.label
| | ...
| ├── pose.txt
│ ...
|
└── XX
├── scans
| ├── 00000.bin
| ├── 00001.bin
| ...
├── labels
| ├── 00000.label
| ├── 00001.label
| ...
├── pose.txt
Learn more: http://www.semantic-kitti.org/dataset.html


Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.segmentation.data.PointCloudSegmentationData`.
We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.segmentation.model.PointCloudSegmentation` task.
We then use the trained :class:`~flash.image.segmentation.model.PointCloudSegmentation` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py
:language: python
:lines: 14-



.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif
:width: 100%
8 changes: 4 additions & 4 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,10 @@ def __init__(

@staticmethod
def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
if isinstance(batch, Mapping):
return batch, batch.get(DefaultDataKeys.METADATA, None)
return batch, None
metadata = None
if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch:
metadata = batch.pop(DefaultDataKeys.METADATA, None)
return batch, metadata

def forward(self, batch: Sequence[Any]):
batch, metadata = self._extract_metadata(batch)
Expand Down Expand Up @@ -331,7 +332,6 @@ def __str__(self) -> str:
def default_uncollate(batch: Any):
"""
This function is used to uncollate a batch into samples.
Examples:
>>> a, b = default_uncollate(torch.rand((2,1)))
"""
Expand Down
74 changes: 62 additions & 12 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,37 +275,78 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) ->
def _train_dataloader(self) -> DataLoader:
train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
shuffle: bool = False
collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
drop_last = False
pin_memory = True

if self.sampler is None:
shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_train_dataset(
train_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
sampler=self.sampler
)

return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=self.sampler,
num_workers=self.num_workers,
pin_memory=True,
drop_last=True,
collate_fn=self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
pin_memory=pin_memory,
drop_last=drop_last,
collate_fn=collate_fn
)

def _val_dataloader(self) -> DataLoader:
val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_val_dataset(
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn
)

return DataLoader(
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
pin_memory=pin_memory,
collate_fn=collate_fn
)

def _test_dataloader(self) -> DataLoader:
test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn
)

return DataLoader(
test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(test_ds, RunningStage.TESTING)
pin_memory=pin_memory,
collate_fn=collate_fn
)

def _predict_dataloader(self) -> DataLoader:
Expand All @@ -314,12 +355,21 @@ def _predict_dataloader(self) -> DataLoader:
batch_size = self.batch_size
else:
batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)

collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
pin_memory = True

if isinstance(getattr(self, "trainer", None), pl.Trainer):
return self.trainer.lightning_module.process_test_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=pin_memory,
collate_fn=collate_fn
)

return DataLoader(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
predict_ds, batch_size=batch_size, num_workers=self.num_workers, pin_memory=True, collate_fn=collate_fn
)

@property
Expand Down
7 changes: 7 additions & 0 deletions flash/core/data/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from flash.core.data.callback import FlashCallback
from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
from flash.core.data.properties import Properties
from flash.core.data.states import CollateFn
from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext


Expand Down Expand Up @@ -361,6 +362,12 @@ def per_batch_transform(self, batch: Any) -> Any:

def collate(self, samples: Sequence) -> Any:
""" Transform to convert a sequence of samples to a collated batch. """

# the model can provide a custom ``collate_fn``.
collate_fn = self.get_state(CollateFn)
if collate_fn is not None:
return collate_fn.collate_fn(samples)

current_transform = self.current_transform
if current_transform is self._identity:
return self._default_collate(samples)
Expand Down
10 changes: 10 additions & 0 deletions flash/core/data/states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass
from typing import Callable, Optional

from flash.core.data.properties import ProcessState


@dataclass(unsafe_hash=True, frozen=True)
class CollateFn(ProcessState):

collate_fn: Optional[Callable] = None
Loading

0 comments on commit 9c42528

Please sign in to comment.