diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml
index d26d8ecee2..21ac8fbd45 100644
--- a/.github/workflows/ci-testing.yml
+++ b/.github/workflows/ci-testing.yml
@@ -61,6 +61,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: ['graph']
+ - os: ubuntu-20.04
+ python-version: 3.8
+ requires: 'latest'
+ topic: ['audio']
# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
@@ -128,6 +132,13 @@ jobs:
run: |
pip install '.[all]' --pre --upgrade
+ - name: Install audio test dependencies
+ if: matrix.topic[0] == 'audio'
+ run: |
+ sudo apt-get install libsndfile1
+ pip install matplotlib
+ pip install '.[image]' --pre --upgrade
+
- name: Cache datasets
uses: actions/cache@v2
with:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97085839cd..cb7c1cb3b8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,12 +20,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))
+- Added `PointCloudObjectDetection` Task ([#600](https://github.com/PyTorchLightning/lightning-flash/pull/600))
+
- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))
- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))
- Added support for `field` parameter for loading JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))
+- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))
+
### Changed
- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index d3312220d7..d050db39c5 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -4,7 +4,7 @@
{% block footer %}
{{ super() }}
{% endblock %}
diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst
new file mode 100644
index 0000000000..79662fea87
--- /dev/null
+++ b/docs/source/api/audio.rst
@@ -0,0 +1,21 @@
+###########
+flash.audio
+###########
+
+.. contents::
+ :depth: 1
+ :local:
+ :backlinks: top
+
+.. currentmodule:: flash.audio
+
+Classification
+______________
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ~classification.data.AudioClassificationData
+ ~classification.data.AudioClassificationPreprocess
diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst
index d29a3d4e32..a98c6124f0 100644
--- a/docs/source/api/pointcloud.rst
+++ b/docs/source/api/pointcloud.rst
@@ -23,3 +23,19 @@ ____________
segmentation.data.PointCloudSegmentationPreprocess
segmentation.data.PointCloudSegmentationFoldersDataSource
segmentation.data.PointCloudSegmentationDatasetDataSource
+
+
+Object Detection
+________________
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ~detection.model.PointCloudObjectDetector
+ ~detection.data.PointCloudObjectDetectorData
+
+ detection.data.PointCloudObjectDetectorPreprocess
+ detection.data.PointCloudObjectDetectorFoldersDataSource
+ detection.data.PointCloudObjectDetectorDatasetDataSource
diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst
index 62ae14c67f..12ef22728b 100644
--- a/docs/source/general/registry.rst
+++ b/docs/source/general/registry.rst
@@ -100,7 +100,7 @@ Example::
from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES
- print(IMAGE_CLASSIFIER_BACKBONES.available_models())
+ print(IMAGE_CLASSIFIER_BACKBONES.available_keys())
""" out:
['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71']
"""
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 34616e011d..d12099d884 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,12 @@ Lightning Flash
reference/style_transfer
reference/video_classification
+.. toctree::
+ :maxdepth: 1
+ :caption: Audio
+
+ reference/audio_classification
+
.. toctree::
:maxdepth: 1
:caption: Tabular
@@ -60,6 +66,7 @@ Lightning Flash
:caption: Point Cloud
reference/pointcloud_segmentation
+ reference/pointcloud_object_detection
.. toctree::
:maxdepth: 1
@@ -82,6 +89,7 @@ Lightning Flash
api/data
api/serve
api/image
+ api/audio
api/pointcloud
api/tabular
api/text
diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst
new file mode 100644
index 0000000000..eb122e6995
--- /dev/null
+++ b/docs/source/reference/audio_classification.rst
@@ -0,0 +1,73 @@
+
+.. _audio_classification:
+
+####################
+Audio Classification
+####################
+
+********
+The Task
+********
+
+The task of identifying what is in an audio file is called audio classification.
+Typically, audio classification is used to identify audio files containing sounds or words.
+The task predicts which ‘class’ the sound or words most likely belong to, with a degree of certainty.
+A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’, etc.
+
+------
+
+*******
+Example
+*******
+
+Let's look at the task of predicting whether an audio file contains sounds of an air conditioner, a car horn, children playing, a dog barking, drilling, an engine idling, a gun shot, a jackhammer, a siren, or street music, using the UrbanSound8k spectrogram images dataset.
+The dataset contains ``train``, ``val`` and ``test`` folders. Each of these folders contains an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and so on for the other classes.
+
+.. code-block::
+
+ urban8k_images
+ ├── train
+ │ ├── air_conditioner
+ │ ├── car_horn
+ │ ├── children_playing
+ │ ├── dog_bark
+ │ ├── drilling
+ │ ├── engine_idling
+ │ ├── gun_shot
+ │ ├── jackhammer
+ │ ├── siren
+ │ └── street_music
+ ├── test
+ │ ├── air_conditioner
+ │ ├── car_horn
+ │ ├── children_playing
+ │ ├── dog_bark
+ │ ├── drilling
+ │ ├── engine_idling
+ │ ├── gun_shot
+ │ ├── jackhammer
+ │ ├── siren
+ │ └── street_music
+ └── val
+ ├── air_conditioner
+ ├── car_horn
+ ├── children_playing
+ ├── dog_bark
+ ├── drilling
+ ├── engine_idling
+ ├── gun_shot
+ ├── jackhammer
+ ├── siren
+ └── street_music
+
+ ...
+
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
+We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data.
+We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/audio_classification.py
+ :language: python
+ :lines: 14-
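+
+Once the checkpoint has been saved, the fine-tuned task can be reloaded for inference without retraining.
+Here is a minimal sketch, assuming the ``audio_classification_model.pt`` checkpoint produced by the example above:
+
+.. code-block:: python
+
+    from flash.image import ImageClassifier
+
+    # Reload the fine-tuned task from the checkpoint saved by the example above.
+    model = ImageClassifier.load_from_checkpoint("audio_classification_model.pt")
+
+    # Classify a spectrogram image from the test folder.
+    predictions = model.predict(["data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg"])
+    print(predictions)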
diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst
new file mode 100644
index 0000000000..36c1b19e6b
--- /dev/null
+++ b/docs/source/reference/pointcloud_object_detection.rst
@@ -0,0 +1,82 @@
+
+.. _pointcloud_object_detection:
+
+############################
+Point Cloud Object Detection
+############################
+
+********
+The Task
+********
+
+A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.
+
+Point cloud object detection is the task of identifying 3D objects in point clouds, along with their associated classes and 3D bounding boxes.
+
+The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.
+
+------
+
+*******
+Example
+*******
+
+Let's look at an example using a dataset generated from the `KITTI Vision Benchmark <http://www.cvlibs.net/datasets/kitti/>`_.
+The data is a tiny subset of the original dataset and contains sequences of point clouds.
+
+The data contains:
+ * one folder for scans
+ * one folder for scan calibrations
+ * one folder for labels
+    * a ``meta.yaml`` file describing the classes and their associated color map (an illustrative example is shown after the folder structure below).
+
+Here's the structure:
+
+.. code-block::
+
+ data
+ ├── meta.yaml
+ ├── train
+ │ ├── scans
+ | | ├── 00000.bin
+ | | ├── 00001.bin
+ | | ...
+ │ ├── calibs
+ | | ├── 00000.txt
+ | | ├── 00001.txt
+ | | ...
+ │ ├── labels
+ | | ├── 00000.txt
+ | | ├── 00001.txt
+ │ ...
+ ├── val
+ │ ...
+ ├── predict
+ ├── scans
+ | ├── 00000.bin
+ | ├── 00001.bin
+ |
+ ├── calibs
+ | ├── 00000.txt
+ | ├── 00001.txt
+ ├── meta.yaml
+
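+The ``meta.yaml`` file is read by the KITTI loader in this integration, which expects at least a ``label_to_names`` field (used to derive the number of classes) and a ``color_map`` field.
+Here is an illustrative sketch; the class names and colors below are made up:
+
+.. code-block::
+
+    label_to_names:
+        0: car
+        1: pedestrian
+        2: cyclist
+    color_map:
+        0: [255, 0, 0]
+        1: [0, 255, 0]
+        2: [0, 0, 255]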
+
+
+Learn more: http://www.semantic-kitti.org/dataset.html
+
+
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.pointcloud.detection.data.PointCloudObjectDetectorData`.
+We select a pre-trained ``pointpillars_kitti`` backbone for our :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` task.
+We then use the trained :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/pointcloud_detection.py
+ :language: python
+ :lines: 14-
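+
+You can also visualize the predicted bounding boxes with the built-in Open3D-ML viewer.
+Here is a minimal sketch, assuming the ``datamodule`` and ``predictions`` objects from the example above:
+
+.. code-block:: python
+
+    from flash.pointcloud.detection import launch_app
+
+    # Launch the Open3D-ML visualizer and display the predicted boxes.
+    app = launch_app(datamodule)
+    app.show_predictions(predictions)
+
+This produces a visualization similar to the image below.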
+
+
+
+.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png
+ :width: 100%
diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py
new file mode 100644
index 0000000000..40eeaae124
--- /dev/null
+++ b/flash/audio/__init__.py
@@ -0,0 +1 @@
+from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py
new file mode 100644
index 0000000000..476a303d49
--- /dev/null
+++ b/flash/audio/classification/__init__.py
@@ -0,0 +1 @@
+from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py
new file mode 100644
index 0000000000..68678b2a1b
--- /dev/null
+++ b/flash/audio/classification/data.py
@@ -0,0 +1,87 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, Dict, Optional, Tuple
+
+from flash.audio.classification.transforms import default_transforms, train_default_transforms
+from flash.core.data.callback import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_source import DefaultDataSources
+from flash.core.data.process import Deserializer, Preprocess
+from flash.core.utilities.imports import requires_extras
+from flash.image.classification.data import MatplotlibVisualization
+from flash.image.data import ImageDeserializer, ImagePathsDataSource
+
+
+class AudioClassificationPreprocess(Preprocess):
+
+ @requires_extras(["audio", "image"])
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]],
+ val_transform: Optional[Dict[str, Callable]],
+ test_transform: Optional[Dict[str, Callable]],
+ predict_transform: Optional[Dict[str, Callable]],
+ spectrogram_size: Tuple[int, int] = (196, 196),
+ time_mask_param: int = 80,
+ freq_mask_param: int = 80,
+ deserializer: Optional['Deserializer'] = None,
+ ):
+ self.spectrogram_size = spectrogram_size
+ self.time_mask_param = time_mask_param
+ self.freq_mask_param = freq_mask_param
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.FILES: ImagePathsDataSource(),
+ DefaultDataSources.FOLDERS: ImagePathsDataSource()
+ },
+ deserializer=deserializer or ImageDeserializer(),
+ default_data_source=DefaultDataSources.FILES,
+ )
+
+ def get_state_dict(self) -> Dict[str, Any]:
+ return {
+ **self.transforms,
+ "spectrogram_size": self.spectrogram_size,
+ "time_mask_param": self.time_mask_param,
+ "freq_mask_param": self.freq_mask_param,
+ }
+
+ @classmethod
+ def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
+ return cls(**state_dict)
+
+ def default_transforms(self) -> Optional[Dict[str, Callable]]:
+ return default_transforms(self.spectrogram_size)
+
+ def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
+ return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)
+
+
+class AudioClassificationData(DataModule):
+ """Data module for audio classification."""
+
+ preprocess_cls = AudioClassificationPreprocess
+
+ def set_block_viz_window(self, value: bool) -> None:
+ """Setter method to switch on/off matplotlib to pop up windows."""
+ self.data_fetcher.block_viz_window = value
+
+ @staticmethod
+ def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
+ return MatplotlibVisualization(*args, **kwargs)
diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py
new file mode 100644
index 0000000000..02a9ed2cbc
--- /dev/null
+++ b/flash/audio/classification/transforms.py
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Dict, Tuple
+
+import torch
+from torch import nn
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
+from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+ import torchvision
+ from torchvision import transforms as T
+
+if _TORCHAUDIO_AVAILABLE:
+ from torchaudio import transforms as TAudio
+
+
+def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
+ """The default transforms for audio classification for spectrograms: resize the spectrogram,
+ convert the spectrogram and target to a tensor, and collate the batch."""
+ return {
+ "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
+ "to_tensor_transform": nn.Sequential(
+ ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+ ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
+ ),
+ "collate": kornia_collate,
+ }
+
+
+def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
+ freq_mask_param: int) -> Dict[str, Callable]:
+    """During training, we apply the default transforms with additional ``TimeMasking`` and ``FrequencyMasking`` transforms."""
+ transforms = {
+ "post_tensor_transform": nn.Sequential(
+ ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
+ ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
+ )
+ }
+
+ return merge_transforms(default_transforms(spectrogram_size), transforms)
diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index d3c7c611ef..c24e937b08 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -176,6 +176,13 @@ def __hash__(self) -> int:
return hash(self.value)
+class BaseDataFormat(LightningEnum):
+ """The base class for creating ``data_format`` for :class:`~flash.core.data.data_source.DataSource`."""
+
+ def __hash__(self) -> int:
+ return hash(self.value)
+
+
class MockDataset:
"""The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to
:meth:`~flash.core.data.data_source.DataSource.load_data` so that attributes can be set on the generated
diff --git a/flash/core/data/states.py b/flash/core/data/states.py
index 5755e7445f..de026f7d73 100644
--- a/flash/core/data/states.py
+++ b/flash/core/data/states.py
@@ -4,6 +4,24 @@
from flash.core.data.properties import ProcessState
+@dataclass(unsafe_hash=True, frozen=True)
+class PreTensorTransform(ProcessState):
+
+ transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class ToTensorTransform(ProcessState):
+
+ transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class PostTensorTransform(ProcessState):
+
+ transform: Optional[Callable] = None
+
+
@dataclass(unsafe_hash=True, frozen=True)
class CollateFn(ProcessState):
diff --git a/flash/core/model.py b/flash/core/model.py
index 1036e45e7f..21fa1a40f3 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -188,21 +188,32 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(output["y_hat"])
+
+ logs = {}
+
for name, metric in metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
logs[name] = metric # log the metric itself if it is of type Metric
else:
logs[name] = metric(y_hat, y)
- logs.update(losses)
+
if len(losses.values()) > 1:
logs["total_loss"] = sum(losses.values())
return logs["total_loss"], logs
- output["loss"] = list(losses.values())[0]
- output["logs"] = logs
+
+ output["loss"] = self.compute_loss(losses)
+ output["logs"] = self.compute_logs(logs, losses)
output["y"] = y
return output
+ def compute_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor:
+ return list(losses.values())[0]
+
+ def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]):
+ logs.update(losses)
+ return logs
+
@staticmethod
def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function is used to filter some labels or predictions which aren't conform."""
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index eda24ef98c..654ae3a165 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -16,6 +16,7 @@
import operator
import types
from importlib.util import find_spec
+from typing import Callable, List, Union
from pkg_resources import DistributionNotFound
@@ -89,6 +90,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter")
_TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse")
_TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric")
+_TORCHAUDIO_AVAILABLE = _module_available("torchaudio")
_ICEVISION_AVAILABLE = _module_available("icevision")
if Version:
@@ -110,6 +112,7 @@ def _compare_version(package: str, op, version) -> bool:
_POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE
_AUDIO_AVAILABLE = all([
_ASTEROID_AVAILABLE,
+ _TORCHAUDIO_AVAILABLE,
])
_GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE
@@ -125,15 +128,22 @@ def _compare_version(package: str, op, version) -> bool:
}
-def _requires(module_path: str, module_available: bool):
+def _requires(
+ module_paths: Union[str, List],
+ module_available: Callable[[str], bool],
+ formatter: Callable[[List[str]], str],
+):
+
+ if not isinstance(module_paths, list):
+ module_paths = [module_paths]
def decorator(func):
- if not module_available:
+ if not all(module_available(module_path) for module_path in module_paths):
@functools.wraps(func)
def wrapper(*args, **kwargs):
raise ModuleNotFoundError(
- f"Required dependencies not available. Please run: pip install '{module_path}'"
+ f"Required dependencies not available. Please run: pip install {formatter(module_paths)}"
)
return wrapper
@@ -143,12 +153,14 @@ def wrapper(*args, **kwargs):
return decorator
-def requires(module_path: str):
- return _requires(module_path, _module_available(module_path))
+def requires(module_paths: Union[str, List]):
+ return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths))
-def requires_extras(extras: str):
- return _requires(f"lightning-flash[{extras}]", _EXTRAS_AVAILABLE[extras])
+def requires_extras(extras: Union[str, List]):
+ return _requires(
+ extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'"
+ )
def lazy_import(module_name, callback=None):
diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py
index 5d10606f79..766f2f2e89 100644
--- a/flash/pointcloud/__init__.py
+++ b/flash/pointcloud/__init__.py
@@ -1,3 +1,2 @@
-from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401
-from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401
-from flash.pointcloud.segmentation.open3d_ml.app import launch_app # noqa: F401
+from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData # noqa: F401
+from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData # noqa: F401
diff --git a/flash/pointcloud/detection/__init__.py b/flash/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..cfe4c690f0
--- /dev/null
+++ b/flash/pointcloud/detection/__init__.py
@@ -0,0 +1,3 @@
+from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401
+from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401
+from flash.pointcloud.detection.open3d_ml.app import launch_app # noqa: F401
diff --git a/flash/pointcloud/detection/backbones.py b/flash/pointcloud/detection/backbones.py
new file mode 100644
index 0000000000..88268dd036
--- /dev/null
+++ b/flash/pointcloud/detection/backbones.py
@@ -0,0 +1,19 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from flash.core.registry import FlashRegistry
+from flash.pointcloud.detection.open3d_ml.backbones import register_open_3d_ml
+
+POINTCLOUD_OBJECT_DETECTION_BACKBONES = FlashRegistry("backbones")
+
+register_open_3d_ml(POINTCLOUD_OBJECT_DETECTION_BACKBONES)
diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py
new file mode 100644
index 0000000000..30c877e70d
--- /dev/null
+++ b/flash/pointcloud/detection/data.py
@@ -0,0 +1,178 @@
+from typing import Any, Callable, Dict, Optional
+
+from torch.utils.data import Sampler
+
+from flash.core.data.base_viz import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_pipeline import Deserializer
+from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Preprocess
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+ from flash.pointcloud.detection.open3d_ml.data_sources import (
+ PointCloudObjectDetectionDataFormat,
+ PointCloudObjectDetectorFoldersDataSource,
+ )
+else:
+ PointCloudObjectDetectorFoldersDataSource = object()
+
+ class PointCloudObjectDetectionDataFormat:
+ KITTI = None
+
+
+class PointCloudObjectDetectorDatasetDataSource(DataSource):
+
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def load_data(
+ self,
+ data: Any,
+ dataset: Optional[Any] = None,
+ ) -> Any:
+
+ dataset.dataset = data
+
+ return range(len(data))
+
+ def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
+ sample = dataset.dataset[index]
+
+ return {
+ DefaultDataKeys.INPUT: sample['data'],
+ DefaultDataKeys.METADATA: sample["attr"],
+ }
+
+
+class PointCloudObjectDetectorPreprocess(Preprocess):
+
+ def __init__(
+ self,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ deserializer: Optional[Deserializer] = None,
+ **data_source_kwargs,
+ ):
+
+ super().__init__(
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_sources={
+ DefaultDataSources.DATASET: PointCloudObjectDetectorDatasetDataSource(**data_source_kwargs),
+ DefaultDataSources.FOLDERS: PointCloudObjectDetectorFoldersDataSource(**data_source_kwargs),
+ },
+ deserializer=deserializer,
+ default_data_source=DefaultDataSources.FOLDERS,
+ )
+
+ def get_state_dict(self):
+ return {}
+
+ def state_dict(self):
+ return {}
+
+ @classmethod
+ def load_state_dict(cls, state_dict, strict: bool = False):
+ pass
+
+
+class PointCloudObjectDetectorData(DataModule):
+
+ preprocess_cls = PointCloudObjectDetectorPreprocess
+
+ @classmethod
+ def from_folders(
+ cls,
+ train_folder: Optional[str] = None,
+ val_folder: Optional[str] = None,
+ test_folder: Optional[str] = None,
+ predict_folder: Optional[str] = None,
+ train_transform: Optional[Dict[str, Callable]] = None,
+ val_transform: Optional[Dict[str, Callable]] = None,
+ test_transform: Optional[Dict[str, Callable]] = None,
+ predict_transform: Optional[Dict[str, Callable]] = None,
+ data_fetcher: Optional[BaseDataFetcher] = None,
+ preprocess: Optional[Preprocess] = None,
+ val_split: Optional[float] = None,
+ batch_size: int = 4,
+ num_workers: Optional[int] = None,
+ sampler: Optional[Sampler] = None,
+ scans_folder_name: Optional[str] = "scans",
+ labels_folder_name: Optional[str] = "labels",
+ calibrations_folder_name: Optional[str] = "calibs",
+ data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI,
+ **preprocess_kwargs: Any,
+ ) -> 'DataModule':
+ """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
+ :class:`~flash.core.data.data_source.DataSource` of name
+ :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
+ from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
+
+ Args:
+ train_folder: The folder containing the train data.
+ val_folder: The folder containing the validation data.
+ test_folder: The folder containing the test data.
+ predict_folder: The folder containing the predict data.
+ train_transform: The dictionary of transforms to use during training which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ val_transform: The dictionary of transforms to use during validation which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ test_transform: The dictionary of transforms to use during testing which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ predict_transform: The dictionary of transforms to use during predicting which maps
+ :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+ data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`.
+ preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
+ :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+ will be constructed and used.
+ val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+ preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+ if ``preprocess = None``.
+ scans_folder_name: The name of the pointcloud scan folder
+ labels_folder_name: The name of the pointcloud scan labels folder
+ calibrations_folder_name: The name of the pointcloud scan calibration folder
+ data_format: Format in which the data are stored.
+
+ Returns:
+ The constructed data module.
+
+ Examples::
+
+ data_module = DataModule.from_folders(
+ train_folder="train_folder",
+ train_transform={
+ "to_tensor_transform": torch.as_tensor,
+ },
+ )
+ """
+ return cls.from_data_source(
+ DefaultDataSources.FOLDERS,
+ train_folder,
+ val_folder,
+ test_folder,
+ predict_folder,
+ train_transform=train_transform,
+ val_transform=val_transform,
+ test_transform=test_transform,
+ predict_transform=predict_transform,
+ data_fetcher=data_fetcher,
+ preprocess=preprocess,
+ val_split=val_split,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler=sampler,
+ scans_folder_name=scans_folder_name,
+ labels_folder_name=labels_folder_name,
+ calibrations_folder_name=calibrations_folder_name,
+ data_format=data_format,
+ **preprocess_kwargs,
+ )
diff --git a/flash/pointcloud/detection/datasets.py b/flash/pointcloud/detection/datasets.py
new file mode 100644
index 0000000000..4860da1363
--- /dev/null
+++ b/flash/pointcloud/detection/datasets.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.segmentation.datasets import executor
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d.ml.datasets import KITTI
+
+_OBJECT_DETECTION_DATASET = FlashRegistry("dataset")
+
+
+@_OBJECT_DETECTION_DATASET
+def kitti(dataset_path, download, **kwargs):
+ name = "KITTI"
+ download_path = os.path.join(dataset_path, name, "Kitti")
+ if not os.path.exists(download_path):
+ executor(
+ "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_kitti.sh", # noqa E501
+ None,
+ dataset_path,
+ name
+ )
+ return KITTI(download_path, **kwargs)
+
+
+def KITTIDataset(dataset_path, download: bool = True, **kwargs):
+ return _OBJECT_DETECTION_DATASET.get("kitti")(dataset_path, download, **kwargs)
diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py
new file mode 100644
index 0000000000..ff1e718484
--- /dev/null
+++ b/flash/pointcloud/detection/model.py
@@ -0,0 +1,187 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torchmetrics
+from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+
+import flash
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.apply_func import get_callable_dict
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES
+
+__FILE_EXAMPLE__ = "pointcloud_detection"
+
+
+class PointCloudObjectDetectorSerializer(Serializer):
+ pass
+
+
+class PointCloudObjectDetector(flash.Task):
+    """The ``PointCloudObjectDetector`` is a :class:`~flash.core.model.Task` for detecting objects in point
+    clouds.
+
+ Args:
+ num_features: The number of features (elements) in the input data.
+ num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`.
+ backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use.
+ backbone_kwargs: Any additional kwargs to pass to the backbone constructor.
+ loss_fn: The loss function to use. If ``None``, a default will be selected by the
+ :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ optimizer: The optimizer or optimizer class to use.
+ optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+ scheduler: The scheduler or scheduler class to use.
+ scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
+ metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected
+ by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument.
+ learning_rate: The learning rate for the optimizer.
+ multi_label: If ``True``, this will be treated as a multi-label classification problem.
+ serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs.
+        lambda_loss_cls: The value to scale the classification loss.
+        lambda_loss_bbox: The value to scale the bounding box loss.
+        lambda_loss_dir: The value to scale the bounding box direction loss.
+ """
+
+ backbones: FlashRegistry = POINTCLOUD_OBJECT_DETECTION_BACKBONES
+ required_extras: str = "pointcloud"
+
+ def __init__(
+ self,
+ num_classes: int,
+ backbone: Union[str, Tuple[nn.Module, int]] = "pointpillars_kitti",
+ backbone_kwargs: Optional[Dict] = None,
+ head: Optional[nn.Module] = None,
+ loss_fn: Optional[Callable] = None,
+ optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam,
+ optimizer_kwargs: Optional[Dict[str, Any]] = None,
+ scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None,
+ scheduler_kwargs: Optional[Dict[str, Any]] = None,
+ metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
+ learning_rate: float = 1e-2,
+ serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(),
+ lambda_loss_cls: float = 1.,
+ lambda_loss_bbox: float = 1.,
+ lambda_loss_dir: float = 1.,
+ ):
+
+ super().__init__(
+ model=None,
+ loss_fn=loss_fn,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
+ scheduler=scheduler,
+ scheduler_kwargs=scheduler_kwargs,
+ metrics=metrics,
+ learning_rate=learning_rate,
+ serializer=serializer,
+ )
+
+ self.save_hyperparameters()
+
+ if backbone_kwargs is None:
+ backbone_kwargs = {}
+
+ if isinstance(backbone, tuple):
+ self.backbone, out_features = backbone
+ else:
+ self.model, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs)
+ self.backbone = self.model.backbone
+ self.neck = self.model.neck
+            self.set_state(CollateFn(collate_fn))
+ self.loss_fn = get_callable_dict(self.model.loss)
+
+ if __FILE_EXAMPLE__ not in sys.argv[0]:
+ self.model.bbox_head.conv_cls = self.head = nn.Conv2d(
+ out_features, num_classes, kernel_size=(1, 1), stride=(1, 1)
+ )
+
+ def compute_loss(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ losses = losses["loss"]
+ return (
+ self.hparams.lambda_loss_cls * losses["loss_cls"] + self.hparams.lambda_loss_bbox * losses["loss_bbox"] +
+ self.hparams.lambda_loss_dir * losses["loss_dir"]
+ )
+
+ def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]):
+ logs.update({"loss": self.compute_loss(losses)})
+ return logs
+
+ def training_step(self, batch: Any, batch_idx: int) -> Any:
+ return super().training_step((batch, batch), batch_idx)
+
+ def validation_step(self, batch: Any, batch_idx: int) -> Any:
+ super().validation_step((batch, batch), batch_idx)
+
+ def test_step(self, batch: Any, batch_idx: int) -> Any:
+ super().validation_step((batch, batch), batch_idx)
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ results = self.model(batch)
+ boxes = self.model.inference_end(results, batch)
+ return {
+ DefaultDataKeys.INPUT: getattr(batch, "point", None),
+ DefaultDataKeys.PREDS: boxes,
+ DefaultDataKeys.METADATA: [a["name"] for a in batch.attr]
+ }
+
+ def forward(self, x) -> torch.Tensor:
+ """First call the backbone, then the model head."""
+ # hack to enable backbone to work properly.
+ self.model.device = self.device
+ return self.model(x)
+
+ def _process_dataset(
+ self,
+ dataset: BaseAutoDataset,
+ batch_size: int,
+ num_workers: int,
+ pin_memory: bool,
+ collate_fn: Callable,
+ shuffle: bool = False,
+ drop_last: bool = True,
+ sampler: Optional[Sampler] = None,
+ convert_to_dataloader: bool = True,
+ ) -> Union[DataLoader, BaseAutoDataset]:
+
+ if not _POINTCLOUD_AVAILABLE:
+            raise ModuleNotFoundError("Please run: pip install 'lightning-flash[pointcloud]'")
+
+ dataset.preprocess_fn = self.model.preprocess
+ dataset.transform_fn = self.model.transform
+
+ if convert_to_dataloader:
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_memory,
+ collate_fn=collate_fn,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ sampler=sampler,
+ )
+
+ else:
+ return dataset
diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py
new file mode 100644
index 0000000000..5578955d8a
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/app.py
@@ -0,0 +1,171 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+
+import flash
+from flash import DataModule
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+ from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer
+ from open3d.visualization import gui
+
+ class Visualizer(Visualizer):
+
+ def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
+ """Visualize a dataset.
+
+ Example:
+ Minimal example for visualizing a dataset::
+ import open3d.ml.torch as ml3d # or open3d.ml.tf as ml3d
+
+ dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
+ vis = ml3d.vis.Visualizer()
+ vis.visualize_dataset(dataset, 'all', indices=range(100))
+
+ Args:
+ dataset: The dataset to use for visualization.
+ split: The dataset split to be used, such as 'training'
+ indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
+ width: The width of the visualization window.
+ height: The height of the visualization window.
+ """
+ # Setup the labels
+ lut = LabelLUT()
+ for id, color in dataset.color_map.items():
+ lut.add_label(id, id, color=color)
+ self.set_lut("label", lut)
+
+ self._consolidate_bounding_boxes = True
+ self._init_dataset(dataset, split, indices)
+
+ self._visualize("Open3D - " + dataset.name, width, height)
+
+ def _visualize(self, title, width, height):
+ gui.Application.instance.initialize()
+ self._init_user_interface(title, width, height)
+
+ # override just to set background color to back :)
+ bgcolor = gui.ColorEdit()
+ bgcolor.color_value = gui.Color(0, 0, 0)
+ self._on_bgcolor_changed(bgcolor.color_value)
+
+ self._3d.scene.downsample_threshold = 400000
+
+ # Turn all the objects off except the first one
+ for name, node in self._name2treenode.items():
+ node.checkbox.checked = False
+ self._3d.scene.show_geometry(name, False)
+ for name in [self._objects.data_names[0]]:
+ self._name2treenode[name].checkbox.checked = True
+ self._3d.scene.show_geometry(name, True)
+
+ def on_done_ui():
+ # Add bounding boxes here: bounding boxes belonging to the dataset
+ # will not be loaded until now.
+ self._update_bounding_boxes()
+
+ self._update_datasource_combobox()
+ self._update_shaders_combobox()
+
+ # Display "colors" by default if available, "points" if not
+ available_attrs = self._get_available_attrs()
+ self._set_shader(self.SOLID_NAME, force_update=True)
+ if "colors" in available_attrs:
+ self._datasource_combobox.selected_text = "colors"
+ elif "points" in available_attrs:
+ self._datasource_combobox.selected_text = "points"
+
+ self._dont_update_geometry = True
+ self._on_datasource_changed(
+ self._datasource_combobox.selected_text, self._datasource_combobox.selected_index
+ )
+ self._update_geometry_colors()
+ self._dont_update_geometry = False
+ # _datasource_combobox was empty, now isn't, re-layout.
+ self.window.set_needs_layout()
+
+ self._update_geometry()
+ self.setup_camera()
+
+ self._load_geometries(self._objects.data_names, on_done_ui)
+ gui.Application.instance.run()
+
+ class VizDataset(Dataset):
+
+ name = "VizDataset"
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+ self.label_to_names = getattr(dataset, "label_to_names", {})
+ self.path_list = getattr(dataset, "path_list", [])
+ self.color_map = getattr(dataset, "color_map", {})
+
+ def get_data(self, index):
+ data = self.dataset[index]["data"]
+ data["bounding_boxes"] = data["bbox_objs"]
+ data["color"] = np.ones_like(data["point"])
+ return data
+
+ def get_attr(self, index):
+ return self.dataset[index]["attr"]
+
+ def get_split(self, *_) -> 'VizDataset':
+ return self
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ class App:
+
+ def __init__(self, datamodule: DataModule):
+ self.datamodule = datamodule
+ self._enabled = not flash._IS_TESTING
+
+ def get_dataset(self, stage: str = "train"):
+ dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
+ return VizDataset(dataloader.dataset)
+
+ def show_train_dataset(self, indices=None):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+ viz = Visualizer()
+ viz.visualize_dataset(dataset, 'all', indices=indices)
+
+ def show_predictions(self, predictions):
+ if self._enabled:
+ dataset = self.get_dataset("train")
+
+ viz = Visualizer()
+ lut = LabelLUT()
+ for id, color in dataset.color_map.items():
+ lut.add_label(id, id, color=color)
+ viz.set_lut("label", lut)
+
+ for pred in predictions:
+ data = {
+ "points": torch.stack(pred[DefaultDataKeys.INPUT])[:, :3],
+ "name": pred[DefaultDataKeys.METADATA],
+ }
+ bounding_box = pred[DefaultDataKeys.PREDS]
+
+ viz.visualize([data], bounding_boxes=bounding_box)
+
+
+def launch_app(datamodule: DataModule) -> 'App':
+ return App(datamodule)
diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py
new file mode 100644
index 0000000000..622971299e
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/backbones.py
@@ -0,0 +1,81 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+from abc import ABC
+from typing import Callable
+
+import torch
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"
+
+if _POINTCLOUD_AVAILABLE:
+ import open3d
+ import open3d.ml as _ml3d
+ from open3d._ml3d.torch.dataloaders.concat_batcher import ConcatBatcher, ObjectDetectBatch
+ from open3d._ml3d.torch.models.point_pillars import PointPillars
+ from open3d.ml.torch.dataloaders import DefaultBatcher
+else:
+ ObjectDetectBatch = ABC
+ PointPillars = ABC
+
+
+class ObjectDetectBatchCollator(ObjectDetectBatch):
+
+ def __init__(self, batches):
+ self.num_batches = len(batches)
+ super().__init__(batches)
+
+ def to(self, device):
+ super().to(device)
+ return self
+
+ def __len__(self):
+ return self.num_batches
+
+
+def register_open_3d_ml(register: FlashRegistry):
+
+ if _POINTCLOUD_AVAILABLE:
+
+ CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")
+
+ def get_collate_fn(model) -> Callable:
+ batcher_name = model.cfg.batcher
+ if batcher_name == 'DefaultBatcher':
+ batcher = DefaultBatcher()
+ elif batcher_name == 'ConcatBatcher':
+ batcher = ConcatBatcher(torch, model.__class__.__name__)
+ elif batcher_name == 'ObjectDetectBatchCollator':
+ return ObjectDetectBatchCollator
+ return batcher.collate_fn
+
+ @register(parameters=PointPillars.__init__)
+ def pointpillars_kitti(*args, **kwargs) -> PointPillars:
+ cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml"))
+ cfg.model.device = "cpu"
+ model = PointPillars(**cfg.model)
+ weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth")
+ model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'], )
+ model.cfg.batcher = "ObjectDetectBatchCollator"
+ return model, 384, get_collate_fn(model)
+
+ @register(parameters=PointPillars.__init__)
+ def pointpillars(*args, **kwargs) -> PointPillars:
+ model = PointPillars(*args, **kwargs)
+ model.cfg.batcher = "ObjectDetectBatch"
+ return model, get_collate_fn(model)
diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py
new file mode 100644
index 0000000000..bd594ebe2f
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/data_sources.py
@@ -0,0 +1,244 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from os.path import basename, dirname, exists, isdir, isfile, join
+from posix import listdir
+from typing import Any, Dict, List, Optional, Union
+
+import yaml
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import BaseDataFormat, DataSource
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+ from open3d._ml3d.datasets.kitti import DataProcessing, KITTI
+
+
+class PointCloudObjectDetectionDataFormat(BaseDataFormat):
+ KITTI = "kitti"
+
+
+class BasePointCloudObjectDetectorLoader:
+
+ pass
+
+
+class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader):
+
+ def __init__(
+ self,
+ image_size: tuple = (375, 1242),
+ scans_folder_name: Optional[str] = "scans",
+ labels_folder_name: Optional[str] = "labels",
+ calibrations_folder_name: Optional[str] = "calibs",
+ **kwargs,
+ ):
+
+ self.image_size = image_size
+ self.scans_folder_name = scans_folder_name
+ self.labels_folder_name = labels_folder_name
+ self.calibrations_folder_name = calibrations_folder_name
+
+ def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]):
+ meta_file = join(root_dir, "meta.yaml")
+ if not exists(meta_file):
+ raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file about the classes.")
+
+ with open(meta_file, 'r') as f:
+ self.meta = yaml.safe_load(f)
+
+ if "label_to_names" not in self.meta:
+ raise MisconfigurationException(
+ f"The {root_dir} should contain a `meta.yaml` file about the classes with the field `label_to_names`."
+ )
+
+ dataset.num_classes = len(self.meta["label_to_names"])
+ dataset.label_to_names = self.meta["label_to_names"]
+ dataset.color_map = self.meta["color_map"]
+
+ def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]):
+ sub_directories = listdir(folder)
+ if len(sub_directories) != 3:
+ raise MisconfigurationException(
+                f"Using KITTI format, the {folder} should contain 3 directories "
+ "for ``calibrations``, ``labels`` and ``scans``."
+ )
+
+ assert self.scans_folder_name in sub_directories
+ assert self.labels_folder_name in sub_directories
+ assert self.calibrations_folder_name in sub_directories
+
+ scans_dir = join(folder, self.scans_folder_name)
+ labels_dir = join(folder, self.labels_folder_name)
+ calibrations_dir = join(folder, self.calibrations_folder_name)
+
+ scan_paths = [join(scans_dir, f) for f in listdir(scans_dir)]
+ label_paths = [join(labels_dir, f) for f in listdir(labels_dir)]
+ calibration_paths = [join(calibrations_dir, f) for f in listdir(calibrations_dir)]
+
+ assert len(scan_paths) == len(label_paths) == len(calibration_paths)
+
+ self.load_meta(dirname(folder), dataset)
+
+ dataset.path_list = scan_paths
+
+ return [{
+ "scan_path": scan_path,
+ "label_path": label_path,
+ "calibration_path": calibration_path
+ } for scan_path, label_path, calibration_path, in zip(scan_paths, label_paths, calibration_paths)]
+
+ def load_sample(
+ self, sample: Dict[str, str], dataset: Optional[BaseAutoDataset] = None, has_label: bool = True
+ ) -> Any:
+ pc = KITTI.read_lidar(sample["scan_path"])
+ calib = KITTI.read_calib(sample["calibration_path"])
+ label = None
+ if has_label:
+ label = KITTI.read_label(sample["label_path"], calib)
+
+ reduced_pc = DataProcessing.remove_outside_points(pc, calib['world_cam'], calib['cam_img'], self.image_size)
+
+ attr = {
+ "name": basename(sample["scan_path"]),
+ "path": sample["scan_path"],
+ "calibration_path": sample["calibration_path"],
+ "label_path": sample["label_path"] if has_label else None,
+ "split": "val",
+ }
+
+ data = {
+ 'point': reduced_pc,
+ 'full_point': pc,
+ 'feat': None,
+ 'calib': calib,
+ 'bounding_boxes': label if has_label else None,
+ 'attr': attr
+ }
+ return data, attr
+
+ def load_files(self, scan_paths: Union[str, List[str]], dataset: Optional[BaseAutoDataset] = None):
+ if isinstance(scan_paths, str):
+ scan_paths = [scan_paths]
+
+ def clean_fn(path: str) -> str:
+ return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt")
+
+ dataset.path_list = scan_paths
+
+ return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths]
+
+ def predict_load_data(self, data, dataset: Optional[BaseAutoDataset] = None):
+ if (isinstance(data, str) and isfile(data)) or (isinstance(data, list) and all(isfile(p) for p in data)):
+ return self.load_files(data, dataset)
+ elif isinstance(data, str) and isdir(data):
+ raise NotImplementedError
+
+ def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None):
+ data, attr = self.load_sample(data, dataset, has_label=False)
+ # hack to prevent manipulation of labels
+ attr["split"] = "test"
+ return data, attr
+
+
+class PointCloudObjectDetectorFoldersDataSource(DataSource):
+
+ def __init__(
+ self,
+ data_format: Optional[BaseDataFormat] = None,
+ image_size: tuple = (375, 1242),
+ **loader_kwargs,
+ ):
+ super().__init__()
+
+ self.loaders = {
+ PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader(
+ **loader_kwargs, image_size=image_size
+ )
+ }
+
+ self.data_format = data_format or PointCloudObjectDetectionDataFormat.KITTI
+        self.loader = self.loaders[self.data_format]
+
+ def _validate_data(self, folder: str) -> None:
+ msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}."
+ if not isinstance(folder, str):
+ raise MisconfigurationException(msg)
+
+ if isinstance(folder, str) and not isdir(folder):
+ raise MisconfigurationException(msg)
+
+ def load_data(
+ self,
+ data: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ self._validate_data(data)
+
+ return self.loader.load_data(data, dataset)
+
+ def load_sample(self, metadata: Dict[str, str], dataset: Optional[BaseAutoDataset] = None) -> Any:
+
+ data, metadata = self.loader.load_sample(metadata, dataset)
+
+ preprocess_fn = getattr(dataset, "preprocess_fn", None)
+ if preprocess_fn:
+ data = preprocess_fn(data, metadata)
+
+ transform_fn = getattr(dataset, "transform_fn", None)
+ if transform_fn:
+ data = transform_fn(data, metadata)
+
+ return {"data": data, "attr": metadata}
+
+ def _validate_predict_data(self, data: Union[str, List[str]]) -> None:
+        msg = f"The provided predict data should be either a folder or a single/list of scan path(s). Found {data}."
+ if not isinstance(data, str) and not isinstance(data, list):
+ raise MisconfigurationException(msg)
+
+        if isinstance(data, str) and not isfile(data) and not isdir(data):
+ raise MisconfigurationException(msg)
+
+ if isinstance(data, list) and not all(isfile(p) for p in data):
+ raise MisconfigurationException(msg)
+
+ def predict_load_data(
+ self,
+ data: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ self._validate_predict_data(data)
+
+ return self.loader.predict_load_data(data, dataset)
+
+ def predict_load_sample(
+ self,
+ metadata: Any,
+ dataset: Optional[BaseAutoDataset] = None,
+ ) -> Any:
+
+ data, metadata = self.loader.predict_load_sample(metadata, dataset)
+
+ preprocess_fn = getattr(dataset, "preprocess_fn", None)
+ if preprocess_fn:
+ data = preprocess_fn(data, metadata)
+
+ transform_fn = getattr(dataset, "transform_fn", None)
+ if transform_fn:
+ data = transform_fn(data, metadata)
+
+ return {"data": data, "attr": metadata}
diff --git a/flash/pointcloud/segmentation/__init__.py b/flash/pointcloud/segmentation/__init__.py
index bf7f46a89c..5d10606f79 100644
--- a/flash/pointcloud/segmentation/__init__.py
+++ b/flash/pointcloud/segmentation/__init__.py
@@ -1,2 +1,3 @@
from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401
from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401
+from flash.pointcloud.segmentation.open3d_ml.app import launch_app # noqa: F401
diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py
index a226d6f5b2..879f45570e 100644
--- a/flash/pointcloud/segmentation/open3d_ml/app.py
+++ b/flash/pointcloud/segmentation/open3d_ml/app.py
@@ -13,7 +13,6 @@
# limitations under the License.
import torch
-import flash
from flash import DataModule
from flash.core.data.data_source import DefaultDataKeys
from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
@@ -58,7 +57,7 @@ class App:
def __init__(self, datamodule: DataModule):
self.datamodule = datamodule
- self._enabled = not flash._IS_TESTING
+ self._enabled = True # not flash._IS_TESTING
def get_dataset(self, stage: str = "train"):
dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py
index 0fe44a72ce..aec3aa0123 100644
--- a/flash/pointcloud/segmentation/open3d_ml/backbones.py
+++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py
@@ -27,8 +27,8 @@ def register_open_3d_ml(register: FlashRegistry):
if _POINTCLOUD_AVAILABLE:
import open3d
import open3d.ml as _ml3d
- from open3d.ml.torch.dataloaders import ConcatBatcher, DefaultBatcher
- from open3d.ml.torch.models import RandLANet
+ from open3d._ml3d.torch.dataloaders import ConcatBatcher, DefaultBatcher
+ from open3d._ml3d.torch.models import RandLANet
CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")
diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py
new file mode 100644
index 0000000000..b8f0f8a312
--- /dev/null
+++ b/flash_examples/audio_classification.py
@@ -0,0 +1,45 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.audio import AudioClassificationData
+from flash.core.data.utils import download_data
+from flash.core.finetuning import FreezeUnfreeze
+from flash.image import ImageClassifier
+
+# 1. Create the DataModule
+download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")
+
+datamodule = AudioClassificationData.from_folders(
+ train_folder="data/urban8k_images/train",
+ val_folder="data/urban8k_images/val",
+ spectrogram_size=(64, 64),
+)
+
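+# The spectrograms are stored as ordinary images, so a standard vision backbone can be finetuned on them.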
+# 2. Build the model.
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=3)
+trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+# 4. Predict what's on a few images! air_conditioner, children_playing, siren, etc.
+predictions = model.predict([
+ "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
+ "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
+ "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
+])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("audio_classification_model.pt")
diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py
new file mode 100644
index 0000000000..6cd0409893
--- /dev/null
+++ b/flash_examples/pointcloud_detection.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+ train_folder="data/KITTI_Tiny/Kitti/train",
+ val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
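+# The "pointpillars_kitti" backbone comes from the Open3D-ML integration.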
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and train the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds!
+predictions = model.predict([
+ "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
+ "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
+])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_detection_model.pt")
diff --git a/flash_examples/visualizations/pointcloud_detection.py b/flash_examples/visualizations/pointcloud_detection.py
new file mode 100644
index 0000000000..ebfb0eb5a0
--- /dev/null
+++ b/flash_examples/visualizations/pointcloud_detection.py
@@ -0,0 +1,43 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import launch_app, PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+ train_folder="data/KITTI_Tiny/Kitti/train",
+ val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and train the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Predict what's within a few PointClouds!
+predictions = model.predict(["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_detection_model.pt")
+
+# 6. Optional Visualize
+app = launch_app(datamodule)
+# app.show_train_dataset()
+app.show_predictions(predictions)
diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py
index e4859a8d90..85565a7027 100644
--- a/flash_examples/visualizations/pointcloud_segmentation.py
+++ b/flash_examples/visualizations/pointcloud_segmentation.py
@@ -13,7 +13,7 @@
# limitations under the License.
import flash
from flash.core.data.utils import download_data
-from flash.pointcloud import launch_app, PointCloudSegmentation, PointCloudSegmentationData
+from flash.pointcloud.segmentation import launch_app, PointCloudSegmentation, PointCloudSegmentationData
# 1. Create the DataModule
# Dataset Credit: http://www.semantic-kitti.org/
@@ -42,4 +42,5 @@
# 6. Optional Visualize
app = launch_app(datamodule)
+# app.show_train_dataset()
app.show_predictions(predictions)
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index 03c90d99ec..e608a13b78 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1 +1,2 @@
asteroid>=0.5.1
+torchaudio
diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py
new file mode 100644
index 0000000000..a1c0ba0677
--- /dev/null
+++ b/tests/audio/classification/test_data.py
@@ -0,0 +1,340 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+from typing import Any, List, Tuple
+
+import numpy as np
+import pytest
+import torch
+import torch.nn as nn
+
+from flash.audio import AudioClassificationData
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE
+from tests.helpers.utils import _AUDIO_TESTING
+
+if _TORCHVISION_AVAILABLE:
+ import torchvision
+
+if _PIL_AVAILABLE:
+ from PIL import Image
+
+
+def _rand_image(size: Tuple[int, int] = None):
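+    # create a small random RGB image which stands in for a spectrogram during the tests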
+ if size is None:
+ _size = np.random.choice([196, 244])
+ size = (_size, _size)
+ return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8"))
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_smoke(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ (tmpdir / "b").mkdir()
+ _rand_image().save(tmpdir / "a_1.png")
+ _rand_image().save(tmpdir / "b_1.png")
+
+ train_images = [
+ str(tmpdir / "a_1.png"),
+ str(tmpdir / "b_1.png"),
+ ]
+
+ spectrograms_data = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[1, 2],
+ batch_size=2,
+ num_workers=0,
+ )
+ assert spectrograms_data.train_dataloader() is not None
+ assert spectrograms_data.val_dataloader() is None
+ assert spectrograms_data.test_dataloader() is None
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert sorted(list(labels.numpy())) == [1, 2]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_list_image_paths(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "e").mkdir()
+ _rand_image().save(tmpdir / "e_1.png")
+
+ train_images = [
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ ]
+
+ spectrograms_data = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[0, 3, 6],
+ val_files=train_images,
+ val_targets=[1, 4, 7],
+ test_files=train_images,
+ test_targets=[2, 5, 8],
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # check training data
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here
+ assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here
+
+ # check validation data
+ data = next(iter(spectrograms_data.val_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert list(labels.numpy()) == [1, 4]
+
+ # check test data
+ data = next(iter(spectrograms_data.test_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert list(labels.numpy()) == [2, 5]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "e").mkdir()
+ _rand_image().save(tmpdir / "e_1.png")
+
+ train_images = [
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ str(tmpdir / "e_1.png"),
+ ]
+
+ dm = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=[0, 3, 6],
+ val_files=train_images,
+ val_targets=[1, 4, 7],
+ test_files=train_images,
+ test_targets=[2, 5, 8],
+ batch_size=2,
+ num_workers=0,
+ )
+
+ # disable visualisation for testing
+ assert dm.data_fetcher.block_viz_window is True
+ dm.set_block_viz_window(False)
+ assert dm.data_fetcher.block_viz_window is False
+
+ # call show functions
+ # dm.show_train_batch()
+ dm.show_train_batch("pre_tensor_transform")
+ dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
+def test_from_filepaths_visualise_multilabel(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ (tmpdir / "b").mkdir()
+
+ image_a = str(tmpdir / "a" / "a_1.png")
+ image_b = str(tmpdir / "b" / "b_1.png")
+
+ _rand_image().save(image_a)
+ _rand_image().save(image_b)
+
+ dm = AudioClassificationData.from_files(
+ train_files=[image_a, image_b],
+ train_targets=[[0, 1, 0], [0, 1, 1]],
+ val_files=[image_b, image_a],
+ val_targets=[[1, 1, 0], [0, 0, 1]],
+ test_files=[image_b, image_b],
+ test_targets=[[0, 0, 1], [1, 1, 0]],
+ batch_size=2,
+ spectrogram_size=(64, 64),
+ )
+ # disable visualisation for testing
+ assert dm.data_fetcher.block_viz_window is True
+ dm.set_block_viz_window(False)
+ assert dm.data_fetcher.block_viz_window is False
+
+ # call show functions
+ dm.show_train_batch()
+ dm.show_train_batch("pre_tensor_transform")
+ dm.show_train_batch("to_tensor_transform")
+ dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
+ dm.show_val_batch("per_batch_transform")
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_splits(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ B, _, H, W = 2, 3, 224, 224
+ img_size: Tuple[int, int] = (H, W)
+
+ (tmpdir / "splits").mkdir()
+ _rand_image(img_size).save(tmpdir / "s.png")
+
+ num_samples: int = 10
+ val_split: float = .3
+
+ train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)]
+
+ train_labels: List[int] = list(range(num_samples))
+
+ assert len(train_filepaths) == len(train_labels)
+
+ _to_tensor = {
+ "to_tensor_transform": nn.Sequential(
+ ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+ ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor)
+ ),
+ }
+
+ def run(transform: Any = None):
+ dm = AudioClassificationData.from_files(
+ train_files=train_filepaths,
+ train_targets=train_labels,
+ train_transform=transform,
+ val_transform=transform,
+ batch_size=B,
+ num_workers=0,
+ val_split=val_split,
+ spectrogram_size=img_size,
+ )
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (B, 3, H, W)
+ assert labels.shape == (B, )
+
+ run(_to_tensor)
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_only_train(tmpdir):
+ train_dir = Path(tmpdir / "train")
+ train_dir.mkdir()
+
+ (train_dir / "a").mkdir()
+ _rand_image().save(train_dir / "a" / "1.png")
+ _rand_image().save(train_dir / "a" / "2.png")
+
+ (train_dir / "b").mkdir()
+ _rand_image().save(train_dir / "b" / "1.png")
+ _rand_image().save(train_dir / "b" / "2.png")
+
+ spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1)
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (1, 3, 196, 196)
+ assert labels.shape == (1, )
+
+ assert spectrograms_data.val_dataloader() is None
+ assert spectrograms_data.test_dataloader() is None
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_folders_train_val(tmpdir):
+
+ train_dir = Path(tmpdir / "train")
+ train_dir.mkdir()
+
+ (train_dir / "a").mkdir()
+ _rand_image().save(train_dir / "a" / "1.png")
+ _rand_image().save(train_dir / "a" / "2.png")
+
+ (train_dir / "b").mkdir()
+ _rand_image().save(train_dir / "b" / "1.png")
+ _rand_image().save(train_dir / "b" / "2.png")
+ spectrograms_data = AudioClassificationData.from_folders(
+ train_dir,
+ val_folder=train_dir,
+ test_folder=train_dir,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ data = next(iter(spectrograms_data.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+
+ data = next(iter(spectrograms_data.val_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert list(labels.numpy()) == [0, 0]
+
+ data = next(iter(spectrograms_data.test_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, )
+ assert list(labels.numpy()) == [0, 0]
+
+
+@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.")
+def test_from_filepaths_multilabel(tmpdir):
+ tmpdir = Path(tmpdir)
+
+ (tmpdir / "a").mkdir()
+ _rand_image().save(tmpdir / "a1.png")
+ _rand_image().save(tmpdir / "a2.png")
+
+ train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")]
+ train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]]
+ valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]]
+ test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]]
+
+ dm = AudioClassificationData.from_files(
+ train_files=train_images,
+ train_targets=train_labels,
+ val_files=train_images,
+ val_targets=valid_labels,
+ test_files=train_images,
+ test_targets=test_labels,
+ batch_size=2,
+ num_workers=0,
+ )
+
+ data = next(iter(dm.train_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 4)
+
+ data = next(iter(dm.val_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 4)
+ torch.testing.assert_allclose(labels, torch.tensor(valid_labels))
+
+ data = next(iter(dm.test_dataloader()))
+ imgs, labels = data['input'], data['target']
+ assert imgs.shape == (2, 3, 196, 196)
+ assert labels.shape == (2, 4)
+ torch.testing.assert_allclose(labels, torch.tensor(test_labels))
diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 68252601e5..56b729e36e 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -21,6 +21,7 @@
from flash.core.utilities.imports import _SKLEARN_AVAILABLE
from tests.examples.utils import run_test
from tests.helpers.utils import (
+ _AUDIO_TESTING,
_GRAPH_TESTING,
_IMAGE_TESTING,
_POINTCLOUD_TESTING,
@@ -37,6 +38,10 @@
pytest.param(
"custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
+ pytest.param(
+ "audio_classification.py",
+ marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed")
+ ),
pytest.param(
"image_classification.py",
marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed")
@@ -81,6 +86,10 @@
"pointcloud_segmentation.py",
marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
),
+ pytest.param(
+ "pointcloud_detection.py",
+ marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+ ),
pytest.param(
"graph_classification.py",
marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed")
@@ -89,3 +98,16 @@
)
def test_example(tmpdir, file):
run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))
+
+
+@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
+@pytest.mark.parametrize(
+ "file", [
+ pytest.param(
+ "pointcloud_detection.py",
+ marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+ ),
+ ]
+)
+def test_example_2(tmpdir, file):
+ run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 5bb699b664..bd57cf570d 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -14,6 +14,7 @@
import os
from flash.core.utilities.imports import (
+ _AUDIO_AVAILABLE,
_GRAPH_AVAILABLE,
_IMAGE_AVAILABLE,
_POINTCLOUD_AVAILABLE,
@@ -30,6 +31,7 @@
_SERVE_TESTING = _SERVE_AVAILABLE
_POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE
_GRAPH_TESTING = _GRAPH_AVAILABLE
+_AUDIO_TESTING = _AUDIO_AVAILABLE
if "FLASH_TEST_TOPIC" in os.environ:
topic = os.environ["FLASH_TEST_TOPIC"]
@@ -40,3 +42,4 @@
_SERVE_TESTING = topic == "serve"
_POINTCLOUD_TESTING = topic == "pointcloud"
_GRAPH_TESTING = topic == "graph"
+ _AUDIO_TESTING = topic == "audio"
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 6a80b5774a..87cb183504 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -168,7 +168,7 @@ def test_from_filepaths_visualise(tmpdir):
dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"])
-@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.")
@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.")
def test_from_filepaths_visualise_multilabel(tmpdir):
tmpdir = Path(tmpdir)
diff --git a/tests/pointcloud/detection/__init__.py b/tests/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py
new file mode 100644
index 0000000000..26484f476e
--- /dev/null
+++ b/tests/pointcloud/detection/test_data.py
@@ -0,0 +1,60 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from os.path import join
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from flash import Trainer
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+if _POINTCLOUD_TESTING:
+ from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_pointcloud_object_detection_data(tmpdir):
+
+ seed_everything(52)
+
+ download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir)
+
+    dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"))
+
+ class MockModel(PointCloudObjectDetector):
+
+ def training_step(self, batch, batch_idx: int):
+ assert isinstance(batch, ObjectDetectBatchCollator)
+ assert len(batch.point) == 2
+ assert batch.point[0][1].shape == torch.Size([4])
+ assert len(batch.bboxes) > 1
+ assert batch.attr[0]["name"] == '000000.bin'
+ assert batch.attr[1]["name"] == '000001.bin'
+
+ num_classes = 19
+ model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes)
+ trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
+ trainer.fit(model, dm)
+
+ predict_path = join(tmpdir, "KITTI_Micro", "Kitti", "predict")
+ model.eval()
+
+ predictions = model.predict([join(predict_path, "scans/000000.bin")])
+ assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4
+ assert len(predictions[0][DefaultDataKeys.PREDS]) == 158
+ assert predictions[0][DefaultDataKeys.PREDS][0].__dict__["identifier"] == 'box:1'
diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py
new file mode 100644
index 0000000000..b7d807c837
--- /dev/null
+++ b/tests/pointcloud/detection/test_model.py
@@ -0,0 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from flash.pointcloud.detection import PointCloudObjectDetector
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_backbones():
+
+ backbones = PointCloudObjectDetector.available_backbones()
+ assert backbones == ['pointpillars', 'pointpillars_kitti']