From 9c42528b68d2f31b2a5dbbfd372238f66f536684 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 14 Jul 2021 21:14:16 +0200 Subject: [PATCH] [Feat] Add PointCloud Segmentation (#566) * update * wip * update * update * update * resolve issues * update * update * add doc * update * add tests * update * update tests * update on comments * update * update * resolve some bugs * remove breakpoint * Update docs/source/api/pointcloud.rst * update Co-authored-by: Ethan Harris --- .github/workflows/ci-testing.yml | 4 + .gitignore | 1 + CHANGELOG.md | 2 + README.md | 2 +- docs/source/api/pointcloud.rst | 25 ++ docs/source/index.rst | 7 + .../reference/pointcloud_segmentation.rst | 73 ++++++ flash/core/data/batch.py | 8 +- flash/core/data/data_module.py | 74 +++++- flash/core/data/process.py | 7 + flash/core/data/states.py | 10 + flash/core/model.py | 148 +++++++++++- flash/core/utilities/imports.py | 3 + flash/image/classification/data.py | 5 +- flash/pointcloud/__init__.py | 3 + flash/pointcloud/segmentation/__init__.py | 2 + flash/pointcloud/segmentation/backbones.py | 19 ++ flash/pointcloud/segmentation/data.py | 103 ++++++++ flash/pointcloud/segmentation/datasets.py | 47 ++++ flash/pointcloud/segmentation/model.py | 226 ++++++++++++++++++ .../segmentation/open3d_ml/__init__.py | 0 .../pointcloud/segmentation/open3d_ml/app.py | 101 ++++++++ .../segmentation/open3d_ml/backbones.py | 79 ++++++ .../open3d_ml/sequences_dataset.py | 181 ++++++++++++++ flash_examples/pointcloud_segmentation.py | 41 ++++ .../visualizations/pointcloud_segmentation.py | 45 ++++ requirements.txt | 2 +- requirements/datatype_pointcloud.txt | 4 + setup.py | 5 +- tests/examples/test_scripts.py | 13 +- tests/helpers/utils.py | 3 + tests/pointcloud/segmentation/test_data.py | 57 +++++ tests/pointcloud/segmentation/test_model.py | 33 +++ 33 files changed, 1311 insertions(+), 22 deletions(-) create mode 100644 docs/source/api/pointcloud.rst create mode 100644 docs/source/reference/pointcloud_segmentation.rst create mode 100644 flash/core/data/states.py create mode 100644 flash/pointcloud/__init__.py create mode 100644 flash/pointcloud/segmentation/__init__.py create mode 100644 flash/pointcloud/segmentation/backbones.py create mode 100644 flash/pointcloud/segmentation/data.py create mode 100644 flash/pointcloud/segmentation/datasets.py create mode 100644 flash/pointcloud/segmentation/model.py create mode 100644 flash/pointcloud/segmentation/open3d_ml/__init__.py create mode 100644 flash/pointcloud/segmentation/open3d_ml/app.py create mode 100644 flash/pointcloud/segmentation/open3d_ml/backbones.py create mode 100644 flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py create mode 100644 flash_examples/pointcloud_segmentation.py create mode 100644 flash_examples/visualizations/pointcloud_segmentation.py create mode 100644 requirements/datatype_pointcloud.txt create mode 100644 tests/pointcloud/segmentation/test_data.py create mode 100644 tests/pointcloud/segmentation/test_model.py diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 6a5e2a67b7..d26d8ecee2 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -49,6 +49,10 @@ jobs: python-version: 3.8 requires: 'latest' topic: ['text'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['pointcloud'] - os: ubuntu-20.04 python-version: 3.8 requires: 'latest' diff --git a/.gitignore b/.gitignore index 22806ac066..48be6f46a7 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ 
CameraRGB CameraSeg jigsaw_toxic_comments flash_examples/serve/tabular_classification/data
+logs/cache/*
flash_examples/data
diff --git a/CHANGELOG.md b/CHANGELOG.md
index afdf24e5da..966e910304 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575))

+- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))
+
 - Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

 ### Changed
diff --git a/README.md b/README.md
index 2fea03b506..b5d9a59187 100644
--- a/README.md
+++ b/README.md
@@ -605,7 +605,7 @@ For help or questions, join our huge community on [Slack](https://join.slack.com
 ## Citations
 We’re excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.

-Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).
+Flash leverages models from [torchvision](https://pytorch.org/vision/stable/index.html), [huggingface/transformers](https://huggingface.co/transformers/), [timm](https://github.com/rwightman/pytorch-image-models), [open3d-ml](https://github.com/intel-isl/Open3D-ML), [pytorch-tabnet](https://dreamquark-ai.github.io/tabnet/), and [asteroid](https://github.com/asteroid-team/asteroid) for the `vision`, `text`, `pointcloud`, `tabular`, and `audio` tasks respectively. Also supports self-supervised backbones from [bolts](https://github.com/PyTorchLightning/lightning-bolts).

 ## License
 Please observe the Apache 2.0 license that is listed in this repository. In addition
diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst
new file mode 100644
index 0000000000..d29a3d4e32
--- /dev/null
+++ b/docs/source/api/pointcloud.rst
@@ -0,0 +1,25 @@
+################
+flash.pointcloud
+################
+
+.. contents::
+    :depth: 1
+    :local:
+    :backlinks: top
+
+.. currentmodule:: flash.pointcloud
+
+Segmentation
+____________
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+    :template: classtemplate.rst
+
+    ~segmentation.model.PointCloudSegmentation
+    ~segmentation.data.PointCloudSegmentationData
+
+    segmentation.data.PointCloudSegmentationPreprocess
+    segmentation.data.PointCloudSegmentationFoldersDataSource
+    segmentation.data.PointCloudSegmentationDatasetDataSource
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 0718b4d4fb..9630e55e23 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -55,6 +55,12 @@ Lightning Flash
    reference/summarization
    reference/translation

+.. toctree::
+   :maxdepth: 1
+   :caption: PointCloud
+
+   reference/pointcloud_segmentation
+
 .. toctree::
    :maxdepth: 1
    :caption: Graph

@@ -76,6 +82,7 @@ Lightning Flash
    api/data
    api/serve
    api/image
+   api/pointcloud
    api/tabular
    api/text
    api/video
diff --git a/docs/source/reference/pointcloud_segmentation.rst b/docs/source/reference/pointcloud_segmentation.rst
new file mode 100644
index 0000000000..eb4a576492
--- /dev/null
+++ b/docs/source/reference/pointcloud_segmentation.rst
@@ -0,0 +1,73 @@
+
+.. _pointcloud_segmentation:
+
+#######################
+PointCloud Segmentation
+#######################
+
+********
+The Task
+********
+
+A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.
+
+PointCloud Segmentation is the task of performing classification at the point level, meaning each point will be associated with a given class.
+The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.
+
+------
+
+*******
+Example
+*******
+
+Let's look at an example using a data set generated from the `KITTI Vision Benchmark `_.
+The data is a tiny subset of the original dataset and contains sequences of point clouds.
+The data contains multiple folders, one for each sequence, and a ``meta.yaml`` file describing the classes and their official color map.
+A sequence should contain one folder for scans and one folder for labels, plus a ``pose.txt`` to re-align the sequence if required.
+Here's the structure:
+
+.. code-block::
+
+    data
+    ├── meta.yaml
+    ├── 00
+    │   ├── scans
+    |   |    ├── 00000.bin
+    |   |    ├── 00001.bin
+    |   |    ...
+    │   ├── labels
+    |   |    ├── 00000.label
+    |   |    ├── 00001.label
+    |   |    ...
+    |   ├── pose.txt
+    │   ...
+    |
+    └── XX
+       ├── scans
+       |    ├── 00000.bin
+       |    ├── 00001.bin
+       |    ...
+       ├── labels
+       |    ├── 00000.label
+       |    ├── 00001.label
+       |    ...
+       ├── pose.txt
+
+
+Learn more: http://www.semantic-kitti.org/dataset.html
+
+
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.pointcloud.segmentation.data.PointCloudSegmentationData`.
+We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` task.
+We then use the trained :class:`~flash.pointcloud.segmentation.model.PointCloudSegmentation` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/pointcloud_segmentation.py
+    :language: python
+    :lines: 14-
+
+
+
+.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/getting_started_ml_visualizer.gif
+    :width: 100%
diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py
index 12505bf181..51d28d2a22 100644
--- a/flash/core/data/batch.py
+++ b/flash/core/data/batch.py
@@ -289,9 +289,10 @@ def __init__(
     @staticmethod
     def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
-        if isinstance(batch, Mapping):
-            return batch, batch.get(DefaultDataKeys.METADATA, None)
-        return batch, None
+        metadata = None
+        if isinstance(batch, Mapping) and DefaultDataKeys.METADATA in batch:
+            metadata = batch.pop(DefaultDataKeys.METADATA, None)
+        return batch, metadata

     def forward(self, batch: Sequence[Any]):
         batch, metadata = self._extract_metadata(batch)
@@ -331,7 +332,6 @@ def __str__(self) -> str:
 def default_uncollate(batch: Any):
     """
     This function is used to uncollate a batch into samples.
-
     Examples:
         >>> a, b = default_uncollate(torch.rand((2,1)))
     """
diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py
index ce25412418..0cdfc99ed3 100644
--- a/flash/core/data/data_module.py
+++ b/flash/core/data/data_module.py
@@ -275,37 +275,78 @@ def _resolve_collate_fn(self, dataset: Dataset, running_stage: RunningStage) ->
     def _train_dataloader(self) -> DataLoader:
         train_ds: Dataset = self._train_ds() if isinstance(self._train_ds, Callable) else self._train_ds
         shuffle: bool = False
+        collate_fn = self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
+        drop_last = False
+        pin_memory = True
+
         if self.sampler is None:
             shuffle = not isinstance(train_ds, (IterableDataset, IterableAutoDataset))
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_train_dataset(
+                train_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                collate_fn=collate_fn,
+                sampler=self.sampler
+            )
+
         return DataLoader(
             train_ds,
             batch_size=self.batch_size,
             shuffle=shuffle,
             sampler=self.sampler,
             num_workers=self.num_workers,
-            pin_memory=True,
-            drop_last=True,
-            collate_fn=self._resolve_collate_fn(train_ds, RunningStage.TRAINING)
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            collate_fn=collate_fn
         )

     def _val_dataloader(self) -> DataLoader:
         val_ds: Dataset = self._val_ds() if isinstance(self._val_ds, Callable) else self._val_ds
+        collate_fn = self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_val_dataset(
+                val_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
             val_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
-            collate_fn=self._resolve_collate_fn(val_ds, RunningStage.VALIDATING)
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
         )

     def _test_dataloader(self) -> DataLoader:
         test_ds: Dataset = self._test_ds() if isinstance(self._test_ds, Callable) else self._test_ds
+        collate_fn = self._resolve_collate_fn(test_ds, RunningStage.TESTING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_test_dataset(
+                test_ds,
+                batch_size=self.batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
             test_ds,
             batch_size=self.batch_size,
             num_workers=self.num_workers,
-            pin_memory=True,
-            collate_fn=self._resolve_collate_fn(test_ds, RunningStage.TESTING)
+            pin_memory=pin_memory,
+            collate_fn=collate_fn
         )

     def _predict_dataloader(self) -> DataLoader:
@@ -314,12 +355,21 @@ def _predict_dataloader(self) -> DataLoader:
             batch_size = self.batch_size
         else:
             batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)
+
+        collate_fn = self._resolve_collate_fn(predict_ds, RunningStage.PREDICTING)
+        pin_memory = True
+
+        if isinstance(getattr(self, "trainer", None), pl.Trainer):
+            return self.trainer.lightning_module.process_predict_dataset(
+                predict_ds,
+                batch_size=batch_size,
+                num_workers=self.num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn
+            )
+
         return DataLoader(
-            predict_ds,
            batch_size=batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=collate_fn
        )

    @property
diff --git a/flash/core/data/process.py b/flash/core/data/process.py
index d3a767d161..7020e32d36 100644
--- a/flash/core/data/process.py
+++ b/flash/core/data/process.py
@@ -26,6 +26,7 @@ from flash.core.data.callback import FlashCallback
 from flash.core.data.data_source import DatasetDataSource, DataSource, DefaultDataSources
 from flash.core.data.properties import Properties
+from flash.core.data.states import CollateFn
 from flash.core.data.utils import _PREPROCESS_FUNCS, _STAGES_PREFIX, convert_to_modules, CurrentRunningStageFuncContext
@@ -361,6 +362,12 @@ def per_batch_transform(self, batch: Any) -> Any:

     def collate(self, samples: Sequence) -> Any:
         """ Transform to convert a sequence of samples to a collated batch. """
+
+        # the model can provide a custom ``collate_fn``.
+        collate_fn = self.get_state(CollateFn)
+        if collate_fn is not None:
+            return collate_fn.collate_fn(samples)
+
         current_transform = self.current_transform
         if current_transform is self._identity:
             return self._default_collate(samples)
diff --git a/flash/core/data/states.py b/flash/core/data/states.py
new file mode 100644
index 0000000000..5755e7445f
--- /dev/null
+++ b/flash/core/data/states.py
@@ -0,0 +1,10 @@
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+from flash.core.data.properties import ProcessState
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class CollateFn(ProcessState):
+
+    collate_fn: Optional[Callable] = None
diff --git a/flash/core/model.py b/flash/core/model.py
index 31abeb3b94..8bf0be76ac 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -28,8 +28,10 @@ from torch import nn
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
+from torch.utils.data import DataLoader, Sampler

 import flash
+from flash.core.data.auto_dataset import BaseAutoDataset
 from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
 from flash.core.data.data_source import DataSource
 from flash.core.data.process import (
@@ -40,6 +42,7 @@
     Serializer,
     SerializerMapping,
 )
+from flash.core.data.properties import ProcessState
 from flash.core.registry import FlashRegistry
 from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.serve import Composition
@@ -154,6 +157,9 @@ def __init__(
         # TODO: create enum values to define what are the exact states
         self._data_pipeline_state: Optional[DataPipelineState] = None

+        # The model's own internal state, shared with the data pipeline.
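+        # For example, ``PointCloudSegmentation`` calls ``self.set_state(CollateFn(collate_fn))``
+        # so that the ``Preprocess.collate`` hook above picks up the backbone's batcher.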
+        self._state: Dict[Type[ProcessState], ProcessState] = {}
+
         # Explicitly set the serializer to call the setter
         self.deserializer = deserializer
         self.serializer = serializer
@@ -176,6 +182,7 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         """
         x, y = batch
         y_hat = self(x)
+        y, y_hat = self.apply_filtering(y, y_hat)
         output = {"y_hat": y_hat}
         y_hat = self.to_loss_format(output["y_hat"])
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
@@ -196,6 +203,11 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         output["y"] = y
         return output

+    @staticmethod
+    def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """This function is used to filter out labels or predictions which don't conform."""
+        return y, y_hat
+
     @staticmethod
     def to_loss_format(x: torch.Tensor) -> torch.Tensor:
         return x
@@ -242,7 +254,8 @@ def predict(
         running_stage = RunningStage.PREDICTING

         data_pipeline = self.build_data_pipeline(data_source or "default", deserializer, data_pipeline)
-        x = list(data_pipeline.data_source.generate_dataset(x, running_stage))
+        dataset = data_pipeline.data_source.generate_dataset(x, running_stage)
+        x = list(self.process_predict_dataset(dataset, convert_to_dataloader=False))
         x = data_pipeline.worker_preprocessor(running_stage)(x)
         # todo (tchaton): Remove this once synced with Lightning master.
         if len(inspect.signature(self.transfer_batch_to_device).parameters) == 3:
@@ -428,6 +441,8 @@ def build_data_pipeline(
             deserializer = getattr(preprocess, "deserializer", deserializer)

         data_pipeline = DataPipeline(data_source, preprocess, postprocess, deserializer, serializer)
+        self._data_pipeline_state = self._data_pipeline_state or DataPipelineState()
+        self.attach_data_pipeline_state(self._data_pipeline_state)
         self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state)
         return data_pipeline
@@ -456,6 +471,7 @@ def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None:
             getattr(data_pipeline, '_postprocess_pipeline', None),
             getattr(data_pipeline, '_serializer', None),
         )
+        # self._preprocess.state_dict()
         if getattr(self._preprocess, "_ddp_params_and_buffers_to_ignore", None):
             self._ddp_params_and_buffers_to_ignore = self._preprocess._ddp_params_and_buffers_to_ignore
@@ -667,3 +683,133 @@ def serve(self, host: str = "127.0.0.1", port: int = 8000, sanity_check: bool =
         composition = Composition(predict=comp, TESTING=flash._IS_TESTING)
         composition.serve(host=host, port=port)
         return composition
+
+    def get_state(self, state_type):
+        if state_type in self._state:
+            return self._state[state_type]
+        if self._data_pipeline_state is not None:
+            return self._data_pipeline_state.get_state(state_type)
+        return None
+
+    def set_state(self, state: ProcessState):
+        self._state[type(state)] = state
+        if self._data_pipeline_state is not None:
+            self._data_pipeline_state.set_state(state)
+
+    def attach_data_pipeline_state(self, data_pipeline_state: 'DataPipelineState'):
+        for state in self._state.values():
+            data_pipeline_state.set_state(state)
+
+    def _process_dataset(
+        self,
+        dataset: BaseAutoDataset,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool,
+        collate_fn: Callable,
+        shuffle: bool = False,
+        drop_last: bool = True,
+        sampler: Optional[Sampler] = None,
+        convert_to_dataloader: bool = True,
+    ) -> DataLoader:
+        if convert_to_dataloader:
+            return DataLoader(
+                dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                pin_memory=pin_memory,
shuffle=shuffle, + drop_last=drop_last, + collate_fn=collate_fn + ) + return dataset + + def process_train_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_val_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_test_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int, + num_workers: int, + pin_memory: bool, + collate_fn: Callable, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None + ) -> DataLoader: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler + ) + + def process_predict_dataset( + self, + dataset: BaseAutoDataset, + batch_size: int = 1, + num_workers: int = 0, + pin_memory: bool = False, + collate_fn: Callable = lambda x: x, + shuffle: bool = False, + drop_last: bool = True, + sampler: Optional[Sampler] = None, + convert_to_dataloader: bool = True + ) -> Union[DataLoader, BaseAutoDataset]: + return self._process_dataset( + dataset, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + convert_to_dataloader=convert_to_dataloader + ) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index fe319b93d5..9922f49eba 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -83,6 +83,7 @@ def _compare_version(package: str, op, version) -> bool: _CYTOOLZ_AVAILABLE = _module_available("cytoolz") _UVICORN_AVAILABLE = _module_available("uvicorn") _PIL_AVAILABLE = _module_available("PIL") +_OPEN3D_AVAILABLE = _module_available("open3d") _ASTEROID_AVAILABLE = _module_available("asteroid") _SEGMENTATION_MODELS_AVAILABLE = _module_available("segmentation_models_pytorch") _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") @@ -104,6 +105,7 @@ def _compare_version(package: str, op, version) -> bool: _SEGMENTATION_MODELS_AVAILABLE, ]) _SERVE_AVAILABLE = _FASTAPI_AVAILABLE and _PYDANTIC_AVAILABLE and _CYTOOLZ_AVAILABLE and _UVICORN_AVAILABLE +_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE _AUDIO_AVAILABLE = all([ _ASTEROID_AVAILABLE, ]) @@ -114,6 +116,7 @@ def _compare_version(package: str, op, version) -> bool: 'tabular': _TABULAR_AVAILABLE, 'text': _TEXT_AVAILABLE, 'video': _VIDEO_AVAILABLE, + 'pointcloud': _POINTCLOUD_AVAILABLE, 'serve': _SERVE_AVAILABLE, 'audio': _AUDIO_AVAILABLE, 'graph': _GRAPH_AVAILABLE, diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 891a02c50f..d61c8bc8d0 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -427,7 +427,10 @@ 
def _show_images_and_labels(self, data: List[Any], num_samples: int, title: str)
         fig, axs = plt.subplots(rows, cols)
         fig.suptitle(title)

-        for i, ax in enumerate(axs.ravel()):
+        axs = axs.ravel() if isinstance(axs, np.ndarray) else [axs]
+
+        for i, ax in enumerate(axs):
             # unpack images and labels
             if isinstance(data, list):
                 _img, _label = data[i][DefaultDataKeys.INPUT], data[i].get(DefaultDataKeys.TARGET, "")
diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py
new file mode 100644
index 0000000000..5d10606f79
--- /dev/null
+++ b/flash/pointcloud/__init__.py
@@ -0,0 +1,3 @@
+from flash.pointcloud.segmentation.data import PointCloudSegmentationData  # noqa: F401
+from flash.pointcloud.segmentation.model import PointCloudSegmentation  # noqa: F401
+from flash.pointcloud.segmentation.open3d_ml.app import launch_app  # noqa: F401
diff --git a/flash/pointcloud/segmentation/__init__.py b/flash/pointcloud/segmentation/__init__.py
new file mode 100644
index 0000000000..bf7f46a89c
--- /dev/null
+++ b/flash/pointcloud/segmentation/__init__.py
@@ -0,0 +1,2 @@
+from flash.pointcloud.segmentation.data import PointCloudSegmentationData  # noqa: F401
+from flash.pointcloud.segmentation.model import PointCloudSegmentation  # noqa: F401
diff --git a/flash/pointcloud/segmentation/backbones.py b/flash/pointcloud/segmentation/backbones.py
new file mode 100644
index 0000000000..023daa9ac0
--- /dev/null
+++ b/flash/pointcloud/segmentation/backbones.py
@@ -0,0 +1,19 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
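+# Registry of pointcloud segmentation backbones; populated below from Open3D-ML.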
+from flash.core.registry import FlashRegistry +from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml + +POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones") + +register_open_3d_ml(POINTCLOUD_SEGMENTATION_BACKBONES) diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py new file mode 100644 index 0000000000..940092438d --- /dev/null +++ b/flash/pointcloud/segmentation/data.py @@ -0,0 +1,103 @@ +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.core.data.data_module import DataModule +from flash.core.data.data_pipeline import Deserializer +from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources +from flash.core.data.process import Preprocess +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, requires_extras + +if _POINTCLOUD_AVAILABLE: + from flash.pointcloud.segmentation.open3d_ml.sequences_dataset import SequencesDataset + + +class PointCloudSegmentationDatasetDataSource(DataSource): + + def load_data( + self, + data: Any, + dataset: Optional[Any] = None, + ) -> Any: + if self.training: + dataset.num_classes = len(data.dataset.label_to_names) + + dataset.dataset = data + + return range(len(data)) + + def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: + + sample = dataset.dataset[index] + + return { + DefaultDataKeys.INPUT: sample['data'], + DefaultDataKeys.METADATA: sample["attr"], + } + + +class PointCloudSegmentationFoldersDataSource(DataSource): + + @requires_extras("pointcloud") + def load_data( + self, + folder: Any, + dataset: Optional[Any] = None, + ) -> Any: + + sequence_dataset = SequencesDataset(folder, use_cache=True, predicting=self.predicting) + dataset.dataset = sequence_dataset + if self.training: + dataset.num_classes = sequence_dataset.num_classes + + return range(len(sequence_dataset)) + + def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any: + + sample = dataset.dataset[index] + + return { + DefaultDataKeys.INPUT: sample['data'], + DefaultDataKeys.METADATA: sample["attr"], + } + + +class PointCloudSegmentationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + image_size: Tuple[int, int] = (196, 196), + deserializer: Optional[Deserializer] = None, + **data_source_kwargs: Any, + ): + self.image_size = image_size + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.DATASET: PointCloudSegmentationDatasetDataSource(**data_source_kwargs), + DefaultDataSources.FOLDERS: PointCloudSegmentationFoldersDataSource(**data_source_kwargs), + }, + deserializer=deserializer, + default_data_source=DefaultDataSources.FOLDERS, + ) + + def get_state_dict(self): + return {} + + def state_dict(self): + return {} + + @classmethod + def load_state_dict(cls, state_dict, strict: bool): + pass + + +class PointCloudSegmentationData(DataModule): + + preprocess_cls = PointCloudSegmentationPreprocess diff --git a/flash/pointcloud/segmentation/datasets.py b/flash/pointcloud/segmentation/datasets.py new file mode 100644 index 0000000000..92048e2612 --- /dev/null +++ b/flash/pointcloud/segmentation/datasets.py @@ -0,0 +1,47 @@ +import os + +from 
flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +if _POINTCLOUD_AVAILABLE: + from open3d.ml.datasets import Lyft, SemanticKITTI + +_SEGMENTATION_DATASET = FlashRegistry("dataset") + + +def executor(download_script, preprocess_script, dataset_path, name): + if not os.path.exists(os.path.join(dataset_path, name)): + os.system(f'bash -c "bash <(curl -s {download_script}) {dataset_path}"') + if preprocess_script: + os.system(f'bash -c "bash <(curl -s {preprocess_script}) {dataset_path}"') + + +@_SEGMENTATION_DATASET +def lyft(dataset_path): + name = "Lyft" + executor( + "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_lyft.sh", + "https://github.com/intel-isl/Open3D-ML/blob/master/scripts/preprocess_lyft.py", dataset_path, name + ) + return Lyft(os.path.join(dataset_path, name)) + + +def LyftDataset(dataset_path): + return _SEGMENTATION_DATASET.get("lyft")(dataset_path) + + +@_SEGMENTATION_DATASET +def semantickitti(dataset_path, download, **kwargs): + name = "SemanticKitti" + if download: + executor( + "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_semantickitti.sh", # noqa E501 + None, + dataset_path, + name + ) + return SemanticKITTI(os.path.join(dataset_path, name), **kwargs) + + +def SemanticKITTIDataset(dataset_path, download: bool = True, **kwargs): + return _SEGMENTATION_DATASET.get("semantickitti")(dataset_path, download, **kwargs) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py new file mode 100644 index 0000000000..b3936acc21 --- /dev/null +++ b/flash/pointcloud/segmentation/model.py @@ -0,0 +1,226 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torchmetrics
+from pytorch_lightning import Callback, LightningModule
+from torch import nn
+from torch.nn import functional as F
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+from torchmetrics import IoU
+
+import flash
+from flash.core.classification import ClassificationTask
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.finetuning import BaseFinetuning
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES
+
+if _POINTCLOUD_AVAILABLE:
+    from open3d._ml3d.torch.modules.losses.semseg_loss import filter_valid_label
+    from open3d.ml.torch.dataloaders import TorchDataloader
+
+
+class PointCloudSegmentationFinetuning(BaseFinetuning):
+
+    def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1):
+        super().__init__()
+        self.num_layers = num_layers
+        self.train_bn = train_bn
+        self.unfreeze_epoch = unfreeze_epoch
+
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
+        self.freeze(modules=list(pl_module.backbone.children())[:-self.num_layers], train_bn=self.train_bn)
+
+    def finetune_function(
+        self,
+        pl_module: LightningModule,
+        epoch: int,
+        optimizer: Optimizer,
+        opt_idx: int,
+    ) -> None:
+        if epoch != self.unfreeze_epoch:
+            return
+        self.unfreeze_and_add_param_group(
+            modules=list(pl_module.backbone.children())[-self.num_layers:],
+            optimizer=optimizer,
+            train_bn=self.train_bn,
+        )
+
+
+class PointCloudSegmentationSerializer(Serializer):
+    pass
+
+
+class PointCloudSegmentation(ClassificationTask):
+    """The ``PointCloudSegmentation`` task is a :class:`~flash.core.classification.ClassificationTask` that
+    performs per-point classification (segmentation) of pointcloud data.
+
+    Args:
+        num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`.
+        backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use.
+        backbone_kwargs: Any additional kwargs to pass to the backbone constructor.
+        head: The model head; if ``None``, defaults to a single ``nn.Linear`` layer from the backbone
+            output size to ``num_classes``.
+        loss_fn: The loss function to use. If ``None``, a default will be selected by the
+            :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+        optimizer: The optimizer or optimizer class to use.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
+        metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected
+            by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+        learning_rate: The learning rate for the optimizer.
+        multi_label: If ``True``, this will be treated as a multi-label classification problem.
+        serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs.
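+
+    Example::
+
+        model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=19)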
+ """ + + backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES + + required_extras: str = "pointcloud" + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "RandLANet", + backbone_kwargs: Optional[Dict] = None, + head: Optional[nn.Module] = None, + loss_fn: Optional[Callable] = torch.nn.functional.cross_entropy, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + multi_label: bool = False, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudSegmentationSerializer(), + ): + if metrics is None: + metrics = IoU(num_classes=num_classes) + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + multi_label=multi_label, + serializer=serializer, + ) + + self.save_hyperparameters() + + if not backbone_kwargs: + backbone_kwargs = {"num_classes": num_classes} + + if isinstance(backbone, tuple): + self.backbone, out_features = backbone + else: + self.backbone, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + # replace latest layer + if not flash._IS_TESTING: + self.backbone.fc = nn.Identity() + self.set_state(CollateFn(collate_fn)) + + self.head = nn.Identity() if flash._IS_TESTING else (head or nn.Linear(out_features, num_classes)) + + def apply_filtering(self, labels, scores): + scores, labels = filter_valid_label(scores, labels, self.hparams.num_classes, [0], self.device) + return labels, scores + + def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: + return F.softmax(self.to_loss_format(x)) + + def to_loss_format(self, x: torch.Tensor) -> torch.Tensor: + return x.reshape(-1, x.shape[-1]) + + def training_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().validation_step(batch, batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + batch = (batch[DefaultDataKeys.INPUT], batch[DefaultDataKeys.INPUT]["labels"].view(-1)) + return super().test_step(batch, batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + batch[DefaultDataKeys.PREDS] = self(batch[DefaultDataKeys.INPUT]) + batch[DefaultDataKeys.TARGET] = batch[DefaultDataKeys.INPUT]['labels'] + # drop sub-sampled pointclouds + batch[DefaultDataKeys.INPUT] = batch[DefaultDataKeys.INPUT]['xyz'][0] + return batch + + def forward(self, x) -> torch.Tensor: + """First call the backbone, then the model head.""" + # hack to enable backbone to work properly. 
+        self.backbone.device = self.device
+        x = self.backbone(x)
+        if self.head is not None:
+            x = self.head(x)
+        return x
+
+    def _process_dataset(
+        self,
+        dataset: BaseAutoDataset,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool,
+        collate_fn: Callable,
+        shuffle: bool = False,
+        drop_last: bool = True,
+        sampler: Optional[Sampler] = None,
+        convert_to_dataloader: bool = True,
+    ) -> Union[DataLoader, BaseAutoDataset]:
+
+        if not _POINTCLOUD_AVAILABLE:
+            raise ModuleNotFoundError("Please run `pip install 'lightning-flash[pointcloud]'`.")
+
+        if not isinstance(dataset.dataset, TorchDataloader):
+
+            dataset.dataset = TorchDataloader(
+                dataset.dataset,
+                preprocess=self.backbone.preprocess,
+                transform=self.backbone.transform,
+                use_cache=False,
+            )
+
+        if convert_to_dataloader:
+            return DataLoader(
+                dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                sampler=sampler,
+            )
+
+        else:
+            return dataset
+
+    def configure_finetune_callback(self) -> List[Callback]:
+        return [PointCloudSegmentationFinetuning()]
diff --git a/flash/pointcloud/segmentation/open3d_ml/__init__.py b/flash/pointcloud/segmentation/open3d_ml/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py
new file mode 100644
index 0000000000..a226d6f5b2
--- /dev/null
+++ b/flash/pointcloud/segmentation/open3d_ml/app.py
@@ -0,0 +1,101 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import flash
+from flash import DataModule
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+    from open3d._ml3d.torch.dataloaders import TorchDataloader
+    from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer
+
+    class Visualizer(Visualizer):
+
+        def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
+            """Visualize a dataset.
+
+            Example:
+                Minimal example for visualizing a dataset::
+
+                    import open3d.ml.torch as ml3d  # or open3d.ml.tf as ml3d
+
+                    dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
+                    vis = ml3d.vis.Visualizer()
+                    vis.visualize_dataset(dataset, 'all', indices=range(100))
+
+            Args:
+                dataset: The dataset to use for visualization.
+                split: The dataset split to be used, such as 'training'.
+                indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
+                width: The width of the visualization window.
+                height: The height of the visualization window.
+ """ + # Setup the labels + lut = LabelLUT() + color_map = dataset.color_map + for id, val in dataset.label_to_names.items(): + lut.add_label(val, id, color=color_map[id]) + self.set_lut("labels", lut) + + self._consolidate_bounding_boxes = True + self._init_dataset(dataset, split, indices) + self._visualize("Open3D - " + dataset.name, width, height) + + class App: + + def __init__(self, datamodule: DataModule): + self.datamodule = datamodule + self._enabled = not flash._IS_TESTING + + def get_dataset(self, stage: str = "train"): + dataloader = getattr(self.datamodule, f"{stage}_dataloader")() + dataset = dataloader.dataset.dataset + if isinstance(dataset, TorchDataloader): + return dataset.dataset + return dataset + + def show_train_dataset(self, indices=None): + if self._enabled: + dataset = self.get_dataset("train") + viz = Visualizer() + viz.visualize_dataset(dataset, 'all', indices=indices) + + def show_predictions(self, predictions): + if self._enabled: + dataset = self.get_dataset("train") + color_map = dataset.color_map + + predictions_visualizations = [] + for pred in predictions: + predictions_visualizations.append({ + "points": torch.stack(pred[DefaultDataKeys.INPUT]), + "labels": torch.stack(pred[DefaultDataKeys.TARGET]), + "predictions": torch.argmax(torch.stack(pred[DefaultDataKeys.PREDS]), axis=-1) + 1, + "name": pred[DefaultDataKeys.METADATA]["name"], + }) + + viz = Visualizer() + lut = LabelLUT() + color_map = dataset.color_map + for id, val in dataset.label_to_names.items(): + lut.add_label(val, id, color=color_map[id]) + viz.set_lut("labels", lut) + viz.set_lut("predictions", lut) + viz.visualize(predictions_visualizations) + + +def launch_app(datamodule: DataModule) -> 'App': + return App(datamodule) diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py new file mode 100644 index 0000000000..0fe44a72ce --- /dev/null +++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py @@ -0,0 +1,79 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from typing import Callable + +import torch +from pytorch_lightning.utilities.cloud_io import load as pl_load + +from flash.core.registry import FlashRegistry +from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE + +ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/" + + +def register_open_3d_ml(register: FlashRegistry): + if _POINTCLOUD_AVAILABLE: + import open3d + import open3d.ml as _ml3d + from open3d.ml.torch.dataloaders import ConcatBatcher, DefaultBatcher + from open3d.ml.torch.models import RandLANet + + CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") + + def get_collate_fn(model) -> Callable: + batcher_name = model.cfg.batcher + if batcher_name == 'DefaultBatcher': + batcher = DefaultBatcher() + elif batcher_name == 'ConcatBatcher': + batcher = ConcatBatcher(torch, model.__class__.__name__) + else: + batcher = None + return batcher.collate_fn + + @register + def randlanet_s3dis(*args, use_fold_5: bool = True, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_s3dis.yml")) + model = RandLANet(**cfg.model) + if use_fold_5: + weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_area5_202010091333utc.pth") + else: + weight_url = os.path.join(ROOT_URL, "randlanet_s3dis_202010091238.pth") + model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict']) + return model, 32, get_collate_fn(model) + + @register + def randlanet_toronto3d(*args, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_toronto3d.yml")) + model = RandLANet(**cfg.model) + model.load_state_dict( + pl_load(os.path.join(ROOT_URL, "randlanet_toronto3d_202010091306utc.pth"), + map_location='cpu')['model_state_dict'], + ) + return model, 32, get_collate_fn(model) + + @register + def randlanet_semantic_kitti(*args, **kwargs) -> RandLANet: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "randlanet_semantickitti.yml")) + model = RandLANet(**cfg.model) + model.load_state_dict( + pl_load(os.path.join(ROOT_URL, "randlanet_semantickitti_202009090354utc.pth"), + map_location='cpu')['model_state_dict'], + ) + return model, 32, get_collate_fn(model) + + @register + def randlanet(*args, **kwargs) -> RandLANet: + model = RandLANet(*args, **kwargs) + return model, 32, get_collate_fn(model) diff --git a/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py new file mode 100644 index 0000000000..0609e2e098 --- /dev/null +++ b/flash/pointcloud/segmentation/open3d_ml/sequences_dataset.py @@ -0,0 +1,181 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
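+# Loads KITTI-style sequence folders: each sequence holds a directory of ``.bin``
+# scans and a ``labels`` directory of ``.label`` files, with a top-level
+# ``meta.yaml`` describing ``label_to_names``, ``color_map`` and ``learning_map``.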
+import os
+from os.path import basename, dirname, exists, isdir, isfile, join, split
+
+import numpy as np
+import yaml
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from torch.utils.data import Dataset
+
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+    from open3d._ml3d.datasets.utils import DataProcessing
+    from open3d._ml3d.utils.config import Config
+
+    class SequencesDataset(Dataset):
+
+        def __init__(
+            self,
+            data,
+            cache_dir='./logs/cache',
+            use_cache=False,
+            num_points=65536,
+            ignored_label_inds=[0],
+            predicting=False,
+            **kwargs
+        ):
+
+            super().__init__()
+
+            self.name = "Dataset"
+            self.ignored_label_inds = ignored_label_inds
+
+            kwargs["cache_dir"] = cache_dir
+            kwargs["use_cache"] = use_cache
+            kwargs["num_points"] = num_points
+            kwargs["ignored_label_inds"] = ignored_label_inds
+
+            self.cfg = Config(kwargs)
+            self.predicting = predicting
+
+            if not predicting:
+                self.on_fit(data)
+            else:
+                self.on_predict(data)
+
+        @property
+        def color_map(self):
+            return self.meta["color_map"]
+
+        def on_fit(self, dataset_path):
+            self.split = basename(dataset_path)
+
+            self.load_meta(dirname(dataset_path))
+            self.dataset_path = dataset_path
+            self.label_to_names = self.get_label_to_names()
+            self.num_classes = len(self.label_to_names) - len(self.ignored_label_inds)
+            self.make_datasets()
+
+        def load_meta(self, root_dir):
+            meta_file = join(root_dir, "meta.yaml")
+            if not exists(meta_file):
+                raise MisconfigurationException(
+                    f"The {root_dir} should contain a `meta.yaml` file about the pointcloud sequences."
+                )
+
+            with open(meta_file, 'r') as f:
+                self.meta = yaml.safe_load(f)
+
+            self.label_to_names = self.get_label_to_names()
+            self.num_classes = len(self.label_to_names)
+
+            remap_dict_val = self.meta["learning_map"]
+            max_key = max(remap_dict_val.keys())
+            remap_lut_val = np.zeros((max_key + 100), dtype=np.int32)
+            remap_lut_val[list(remap_dict_val.keys())] = list(remap_dict_val.values())
+
+            self.remap_lut_val = remap_lut_val
+
+        def make_datasets(self):
+            self.path_list = []
+            for seq in os.listdir(self.dataset_path):
+                sequence_path = join(self.dataset_path, seq)
+                directories = [f for f in os.listdir(sequence_path) if isdir(join(sequence_path, f)) and f != "labels"]
+                assert len(directories) == 1
+                scan_dir = join(sequence_path, directories[0])
+                for scan_name in os.listdir(scan_dir):
+                    self.path_list.append(join(scan_dir, scan_name))
+
+        def on_predict(self, data):
+            if isinstance(data, list):
+                if not all(isfile(p) for p in data):
+                    raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+                root_dir = split(data[0])[0]
+            elif isinstance(data, str):
+                if not isdir(data) and not isfile(data):
+                    raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+                if isdir(data):
+                    root_dir = data
+                    data = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if ".bin" in f]
+                elif isfile(data):
+                    root_dir = dirname(data)
+                    data = [data]
+                else:
+                    raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+            else:
+                raise MisconfigurationException("The predict input data takes only a list of paths or a directory.")
+
+            self.path_list = data
+            self.split = "predict"
+            self.load_meta(root_dir)
+
+        def get_label_to_names(self):
+            """Returns a label to names dictionary object.
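+
+            For example, a SemanticKITTI-style ``meta.yaml`` could map
+            ``{0: 'unlabeled', 1: 'car', ...}`` (illustrative values).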
+
+            Returns:
+                A dict where keys are label numbers and
+                values are the corresponding names.
+            """
+            return self.meta["label_to_names"]
+
+        def __getitem__(self, index):
+            data = self.get_data(index)
+            data['attr'] = self.get_attr(index)
+            return data
+
+        def get_data(self, idx):
+            pc_path = self.path_list[idx]
+            points = DataProcessing.load_pc_kitti(pc_path)
+
+            dir, file = split(pc_path)
+            if self.predicting:
+                label_path = join(dir, file[:-4] + '.label')
+            else:
+                label_path = join(dir, '../labels', file[:-4] + '.label')
+            if not exists(label_path):
+                labels = np.zeros(np.shape(points)[0], dtype=np.int32)
+                if self.split not in ['test', 'all']:
+                    raise FileNotFoundError(f'Label file {label_path} not found')
+
+            else:
+                labels = DataProcessing.load_label_kitti(label_path, self.remap_lut_val).astype(np.int32)
+
+            data = {
+                'point': points[:, 0:3],
+                'feat': None,
+                'label': labels,
+            }
+
+            return data
+
+        def get_attr(self, idx):
+            pc_path = self.path_list[idx]
+            dir, file = split(pc_path)
+            _, seq = split(split(dir)[0])
+            name = '{}_{}'.format(seq, file[:-4])
+
+            pc_path = str(pc_path)
+            attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': self.split}
+            return attr
+
+        def __len__(self):
+            return len(self.path_list)
+
+        def get_split(self, *_):
+            return self
diff --git a/flash_examples/pointcloud_segmentation.py b/flash_examples/pointcloud_segmentation.py
new file mode 100644
index 0000000000..f316cc9108
--- /dev/null
+++ b/flash_examples/pointcloud_segmentation.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import PointCloudSegmentation, PointCloudSegmentationData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
+
+datamodule = PointCloudSegmentationData.from_folders(
+    train_folder="data/SemanticKittiTiny/train",
+    val_folder='data/SemanticKittiTiny/val',
+)
+
+# 2. Build the task
+model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and train the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few point clouds
+predictions = model.predict([
+    "data/SemanticKittiTiny/predict/000000.bin",
+    "data/SemanticKittiTiny/predict/000001.bin",
+])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_segmentation_model.pt")
diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py
new file mode 100644
index 0000000000..e4859a8d90
--- /dev/null
+++ b/flash_examples/visualizations/pointcloud_segmentation.py
@@ -0,0 +1,45 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import launch_app, PointCloudSegmentation, PointCloudSegmentationData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiTiny.zip", "data/")
+
+datamodule = PointCloudSegmentationData.from_folders(
+    train_folder="data/SemanticKittiTiny/train",
+    val_folder='data/SemanticKittiTiny/val',
+)
+
+# 2. Build the task
+model = PointCloudSegmentation(backbone="randlanet_semantic_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and train the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=0, limit_val_batches=0, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few point clouds
+predictions = model.predict([
+    "data/SemanticKittiTiny/predict/000000.bin",
+    "data/SemanticKittiTiny/predict/000001.bin",
+])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_segmentation_model.pt")
+
+# 6. Optionally, visualize the predictions
+app = launch_app(datamodule)
+app.show_predictions(predictions)
diff --git a/requirements.txt b/requirements.txt
index 01330917d4..b85542e0b1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch>=1.8
+torch
 torchmetrics
 pytorch-lightning>=1.3.1
 pyDeprecate
diff --git a/requirements/datatype_pointcloud.txt b/requirements/datatype_pointcloud.txt
new file mode 100644
index 0000000000..544ab6061b
--- /dev/null
+++ b/requirements/datatype_pointcloud.txt
@@ -0,0 +1,4 @@
+open3d
+torch==1.7.1
+torchvision
+tensorboard
diff --git a/setup.py b/setup.py
index c83ec4b354..14e0c34dc6 100644
--- a/setup.py
+++ b/setup.py
@@ -51,6 +51,7 @@ def _load_py_module(fname, pkg="flash"):
     "image": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image.txt"),
     "image_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_image_extras.txt"),
     "video": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video.txt"),
+    "pointcloud": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_pointcloud.txt"),
     "video_extras": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_video_extras.txt"),
     "serve": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="serve.txt"),
     "audio": setup_tools._load_requirements(path_dir=_PATH_REQUIRE, file_name="datatype_audio.txt"),
@@ -58,7 +59,9 @@ def _load_py_module(fname, pkg="flash"):
 }
 extras["vision"] = list(set(extras["image"] + extras["video"]))
-extras["all"] = list(set(extras["vision"] + extras["tabular"] + extras["text"]))
+extras["all"] = list(
+    set(extras["vision"] + extras["tabular"] + extras["text"])
+)  # extras["pointcloud"] is excluded as its pinned dependencies conflict
 extras["dev"] = list(set(extras["all"] + extras["test"] + extras["docs"]))
 # https://packaging.python.org/discussions/install-requires-vs-requirements /
diff --git 
a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ec3dc48ce1..68252601e5 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -20,7 +20,14 @@ import flash from flash.core.utilities.imports import _SKLEARN_AVAILABLE from tests.examples.utils import run_test -from tests.helpers.utils import _GRAPH_TESTING, _IMAGE_TESTING, _TABULAR_TESTING, _TEXT_TESTING, _VIDEO_TESTING +from tests.helpers.utils import ( + _GRAPH_TESTING, + _IMAGE_TESTING, + _POINTCLOUD_TESTING, + _TABULAR_TESTING, + _TEXT_TESTING, + _VIDEO_TESTING, +) @mock.patch.dict(os.environ, {"FLASH_TESTING": "1"}) @@ -70,6 +77,10 @@ "video_classification.py", marks=pytest.mark.skipif(not _VIDEO_TESTING, reason="video libraries aren't installed") ), + pytest.param( + "pointcloud_segmentation.py", + marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") + ), pytest.param( "graph_classification.py", marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed") diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 2f1f2c9c80..5bb699b664 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -16,6 +16,7 @@ from flash.core.utilities.imports import ( _GRAPH_AVAILABLE, _IMAGE_AVAILABLE, + _POINTCLOUD_AVAILABLE, _SERVE_AVAILABLE, _TABULAR_AVAILABLE, _TEXT_AVAILABLE, @@ -27,6 +28,7 @@ _TABULAR_TESTING = _TABULAR_AVAILABLE _TEXT_TESTING = _TEXT_AVAILABLE _SERVE_TESTING = _SERVE_AVAILABLE +_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE _GRAPH_TESTING = _GRAPH_AVAILABLE if "FLASH_TEST_TOPIC" in os.environ: @@ -36,4 +38,5 @@ _TABULAR_TESTING = topic == "tabular" _TEXT_TESTING = topic == "text" _SERVE_TESTING = topic == "serve" + _POINTCLOUD_TESTING = topic == "pointcloud" _GRAPH_TESTING = topic == "graph" diff --git a/tests/pointcloud/segmentation/test_data.py b/tests/pointcloud/segmentation/test_data.py new file mode 100644 index 0000000000..00fa47c208 --- /dev/null +++ b/tests/pointcloud/segmentation/test_data.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from os.path import join + +import pytest +import torch +from pytorch_lightning import seed_everything + +from flash import Trainer +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.utils import download_data +from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_pointcloud_segmentation_data(tmpdir): + + seed_everything(52) + + download_data("https://pl-flash-data.s3.amazonaws.com/SemanticKittiMicro.zip", tmpdir) + + dm = PointCloudSegmentationData.from_folders(train_folder=join(tmpdir, "SemanticKittiMicro", "train"), ) + + class MockModel(PointCloudSegmentation): + + def training_step(self, batch, batch_idx: int): + assert batch[DefaultDataKeys.INPUT]["xyz"][0].shape == torch.Size([2, 45056, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][1].shape == torch.Size([2, 11264, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][2].shape == torch.Size([2, 2816, 3]) + assert batch[DefaultDataKeys.INPUT]["xyz"][3].shape == torch.Size([2, 704, 3]) + assert batch[DefaultDataKeys.INPUT]["labels"].shape == torch.Size([2, 45056]) + assert batch[DefaultDataKeys.INPUT]["labels"].max() == 19 + assert batch[DefaultDataKeys.INPUT]["labels"].min() == 0 + assert batch[DefaultDataKeys.METADATA][0]["name"] == '00_000000' + assert batch[DefaultDataKeys.METADATA][1]["name"] == '00_000001' + + num_classes = 19 + model = MockModel(backbone="randlanet", num_classes=num_classes) + trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0) + trainer.fit(model, dm) + + predictions = model.predict(join(tmpdir, "SemanticKittiMicro", "predict")) + assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape == torch.Size([45056, 3]) + assert torch.stack(predictions[0][DefaultDataKeys.PREDS]).shape == torch.Size([45056, 19]) + assert torch.stack(predictions[0][DefaultDataKeys.TARGET]).shape == torch.Size([45056]) diff --git a/tests/pointcloud/segmentation/test_model.py b/tests/pointcloud/segmentation/test_model.py new file mode 100644 index 0000000000..06eabc2c31 --- /dev/null +++ b/tests/pointcloud/segmentation/test_model.py @@ -0,0 +1,33 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +from flash.pointcloud.segmentation import PointCloudSegmentation +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_backbones(): + + backbones = PointCloudSegmentation.available_backbones() + assert backbones == ['randlanet', 'randlanet_s3dis', 'randlanet_semantic_kitti', 'randlanet_toronto3d'] + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_models(): + + num_classes = 13 + model = PointCloudSegmentation(backbone="randlanet", num_classes=num_classes) + assert model.head.weight.shape == torch.Size([13, 32])
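
A short sketch of how the backbone registry added above can be extended, mirroring the ``@register`` pattern from flash/pointcloud/segmentation/open3d_ml/backbones.py. ``my_backbone`` and its stand-in module are hypothetical; a real backbone must also expose the ``preprocess`` and ``transform`` attributes that ``PointCloudSegmentation._process_dataset`` forwards to Open3D-ML's ``TorchDataloader``:

    from torch import nn
    from torch.utils.data.dataloader import default_collate

    from flash.pointcloud import PointCloudSegmentation
    from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES


    @POINTCLOUD_SEGMENTATION_BACKBONES
    def my_backbone(*args, num_classes: int = 19, **kwargs):
        # Hypothetical stand-in: a real implementation would build a point cloud
        # network here (and provide ``preprocess``/``transform`` for dataloading).
        model = nn.Linear(3, 32)
        out_features = 32
        # Backbones return (model, out_features, collate_fn); the collate_fn is
        # shared with the Preprocess through the ``CollateFn`` process state.
        return model, out_features, default_collate


    model = PointCloudSegmentation(backbone="my_backbone", num_classes=19)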