From 00fd908c31637d457cb5032da6214d8e194393b8 Mon Sep 17 00:00:00 2001 From: Suman Michael Date: Fri, 16 Jul 2021 20:52:54 +0530 Subject: [PATCH 1/4] Replaced available_models in docs (#602) Replaced available_models in docs/source/general/registry.rst with available_keys --- docs/source/general/registry.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/general/registry.rst b/docs/source/general/registry.rst index 62ae14c67f..12ef22728b 100644 --- a/docs/source/general/registry.rst +++ b/docs/source/general/registry.rst @@ -100,7 +100,7 @@ Example:: from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES - print(IMAGE_CLASSIFIER_BACKBONES.available_models()) + print(IMAGE_CLASSIFIER_BACKBONES.available_keys()) """ out: ['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71'] """ From 5b853c2b47e4db2ed6c006abeba3f546980165d4 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 16 Jul 2021 19:22:54 +0200 Subject: [PATCH 2/4] [Feat] Add PointCloud ObjectDetection (#600) * wip * wip * wip * add tests * add docs * update changelog * update * update * update * update * update * update * update * update * update * update * update * Update tests/pointcloud/detection/test_data.py * Apply suggestions from code review * Update tests/pointcloud/detection/test_data.py * Update tests/pointcloud/detection/test_data.py * Update tests/pointcloud/detection/test_data.py * Update tests/pointcloud/detection/test_data.py * resolve test * Update tests/pointcloud/detection/test_data.py Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris --- CHANGELOG.md | 2 + docs/source/api/pointcloud.rst | 16 ++ docs/source/index.rst | 1 + .../reference/pointcloud_object_detection.rst | 82 ++++++ flash/core/data/data_source.py | 7 + flash/core/data/states.py | 18 ++ flash/core/model.py | 17 +- flash/pointcloud/__init__.py | 3 +- flash/pointcloud/detection/__init__.py | 3 + flash/pointcloud/detection/backbones.py | 19 ++ flash/pointcloud/detection/data.py | 178 +++++++++++++ flash/pointcloud/detection/datasets.py | 41 +++ flash/pointcloud/detection/model.py | 187 ++++++++++++++ flash/pointcloud/detection/open3d_ml/app.py | 171 ++++++++++++ .../detection/open3d_ml/backbones.py | 81 ++++++ .../detection/open3d_ml/data_sources.py | 244 ++++++++++++++++++ flash/pointcloud/segmentation/__init__.py | 1 + .../pointcloud/segmentation/open3d_ml/app.py | 3 +- .../segmentation/open3d_ml/backbones.py | 4 +- flash_examples/pointcloud_detection.py | 41 +++ .../visualizations/pointcloud_detection.py | 43 +++ .../visualizations/pointcloud_segmentation.py | 3 +- tests/examples/test_scripts.py | 17 ++ tests/pointcloud/detection/__init__.py | 0 tests/pointcloud/detection/test_data.py | 60 +++++ tests/pointcloud/detection/test_model.py | 24 ++ 26 files changed, 1257 insertions(+), 9 deletions(-) create mode 100644 docs/source/reference/pointcloud_object_detection.rst create mode 100644 flash/pointcloud/detection/__init__.py create mode 100644 flash/pointcloud/detection/backbones.py create mode 100644 flash/pointcloud/detection/data.py create mode 100644 flash/pointcloud/detection/datasets.py create mode 100644 flash/pointcloud/detection/model.py create mode 100644 flash/pointcloud/detection/open3d_ml/app.py create mode 100644 flash/pointcloud/detection/open3d_ml/backbones.py create mode 100644 flash/pointcloud/detection/open3d_ml/data_sources.py create mode 100644 flash_examples/pointcloud_detection.py 
create mode 100644 flash_examples/visualizations/pointcloud_detection.py
create mode 100644 tests/pointcloud/detection/__init__.py
create mode 100644 tests/pointcloud/detection/test_data.py
create mode 100644 tests/pointcloud/detection/test_model.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 97085839cd..54851b160e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))
 
+- Added `PointCloudObjectDetection` Task ([#600](https://github.com/PyTorchLightning/lightning-flash/pull/600))
+
 - Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))
 
 - Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))

diff --git a/docs/source/api/pointcloud.rst b/docs/source/api/pointcloud.rst
index d29a3d4e32..a98c6124f0 100644
--- a/docs/source/api/pointcloud.rst
+++ b/docs/source/api/pointcloud.rst
@@ -23,3 +23,19 @@ ____________
     segmentation.data.PointCloudSegmentationPreprocess
     segmentation.data.PointCloudSegmentationFoldersDataSource
     segmentation.data.PointCloudSegmentationDatasetDataSource
+
+
+Object Detection
+________________
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+    :template: classtemplate.rst
+
+    ~detection.model.PointCloudObjectDetector
+    ~detection.data.PointCloudObjectDetectorData
+
+    detection.data.PointCloudObjectDetectorPreprocess
+    detection.data.PointCloudObjectDetectorFoldersDataSource
+    detection.data.PointCloudObjectDetectorDatasetDataSource

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 34616e011d..cf3917f11d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -60,6 +60,7 @@ Lightning Flash
    :caption: Point Cloud
 
    reference/pointcloud_segmentation
+   reference/pointcloud_object_detection
 
 .. toctree::
    :maxdepth: 1

diff --git a/docs/source/reference/pointcloud_object_detection.rst b/docs/source/reference/pointcloud_object_detection.rst
new file mode 100644
index 0000000000..36c1b19e6b
--- /dev/null
+++ b/docs/source/reference/pointcloud_object_detection.rst
@@ -0,0 +1,82 @@
+
+.. _pointcloud_object_detection:
+
+############################
+Point Cloud Object Detection
+############################
+
+********
+The Task
+********
+
+A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.
+
+Point Cloud Object Detection is the task of identifying 3D objects in point clouds, together with their associated classes and 3D bounding boxes.
+
+The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.
+
+------
+
+*******
+Example
+*******
+
+Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.cvlibs.net/datasets/kitti/>`_.
+The data is a tiny subset of the original dataset and contains sequences of point clouds.
+
+The data contains:
+    * one folder for scans
+    * one folder for scan calibrations
+    * one folder for labels
+    * a meta.yaml file describing the classes and their official color map.
+
+Here's the structure:
+
+.. code-block::
+
+    data
+    ├── meta.yaml
+    ├── train
+    │   ├── scans
+    |   |    ├── 00000.bin
+    |   |    ├── 00001.bin
+    |   |    ...
+    │   ├── calibs
+    |   |    ├── 00000.txt
+    |   |    ├── 00001.txt
+    |   |    ...
+    │   ├── labels
+    |   |    ├── 00000.txt
+    |   |    ├── 00001.txt
+    │   ...
+    ├── val
+    │   ...
+    ├── predict
+        ├── scans
+        |   ├── 00000.bin
+        |   ├── 00001.bin
+        |
+        ├── calibs
+        |   ├── 00000.txt
+        |   ├── 00001.txt
+        ├── meta.yaml
+
+
+
+Learn more: http://www.semantic-kitti.org/dataset.html
+
+
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.pointcloud.detection.data.PointCloudObjectDetectorData`.
+We select a pre-trained ``pointpillars_kitti`` backbone for our :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` task.
+We then use the trained :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` for inference.
+Finally, we save the model.
+Here's the full example:
+
+.. literalinclude:: ../../../flash_examples/pointcloud_detection.py
+    :language: python
+    :lines: 14-
+
+
+
+.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png
+    :width: 100%

diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index d3c7c611ef..c24e937b08 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -176,6 +176,13 @@ def __hash__(self) -> int:
         return hash(self.value)
 
 
+class BaseDataFormat(LightningEnum):
+    """The base class for creating ``data_format`` for :class:`~flash.core.data.data_source.DataSource`."""
+
+    def __hash__(self) -> int:
+        return hash(self.value)
+
+
 class MockDataset:
     """The ``MockDataset`` catches any metadata that is attached through ``__setattr__``. This is passed to
     :meth:`~flash.core.data.data_source.DataSource.load_data` so that attributes can be set on the generated

diff --git a/flash/core/data/states.py b/flash/core/data/states.py
index 5755e7445f..de026f7d73 100644
--- a/flash/core/data/states.py
+++ b/flash/core/data/states.py
@@ -4,6 +4,24 @@
 from flash.core.data.properties import ProcessState
 
 
+@dataclass(unsafe_hash=True, frozen=True)
+class PreTensorTransform(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class ToTensorTransform(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
+@dataclass(unsafe_hash=True, frozen=True)
+class PostTensorTransform(ProcessState):
+
+    transform: Optional[Callable] = None
+
+
 @dataclass(unsafe_hash=True, frozen=True)
 class CollateFn(ProcessState):

diff --git a/flash/core/model.py b/flash/core/model.py
index 1036e45e7f..21fa1a40f3 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -188,21 +188,32 @@ def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any:
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
         logs = {}
         y_hat = self.to_metrics_format(output["y_hat"])
+
+        logs = {}
+
         for name, metric in metrics.items():
             if isinstance(metric, torchmetrics.metric.Metric):
                 metric(y_hat, y)
                 logs[name] = metric  # log the metric itself if it is of type Metric
             else:
                 logs[name] = metric(y_hat, y)
-        logs.update(losses)
+
         if len(losses.values()) > 1:
             logs["total_loss"] = sum(losses.values())
             return logs["total_loss"], logs
-        output["loss"] = list(losses.values())[0]
-        output["logs"] = logs
+
+        output["loss"] = self.compute_loss(losses)
+        output["logs"] = self.compute_logs(logs, losses)
         output["y"] = y
         return output
 
+    def compute_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor:
+        return list(losses.values())[0]
+
+    def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]):
+        logs.update(losses)
+        return logs
+
     @staticmethod
     def apply_filtering(y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """This function is used to filter out some labels or predictions which don't conform."""

diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py
index 5d10606f79..8ad5b88538 100644
--- a/flash/pointcloud/__init__.py
+++ b/flash/pointcloud/__init__.py
@@ -1,3 +1,4 @@
+from flash.pointcloud.detection.data import PointCloudObjectDetectorData  # noqa: F401
+from flash.pointcloud.detection.model import PointCloudObjectDetector  # noqa: F401
 from flash.pointcloud.segmentation.data import PointCloudSegmentationData  # noqa: F401
 from flash.pointcloud.segmentation.model import PointCloudSegmentation  # noqa: F401
-from flash.pointcloud.segmentation.open3d_ml.app import launch_app  # noqa: F401

diff --git a/flash/pointcloud/detection/__init__.py b/flash/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..cfe4c690f0
--- /dev/null
+++ b/flash/pointcloud/detection/__init__.py
@@ -0,0 +1,3 @@
+from flash.pointcloud.detection.data import PointCloudObjectDetectorData  # noqa: F401
+from flash.pointcloud.detection.model import PointCloudObjectDetector  # noqa: F401
+from flash.pointcloud.detection.open3d_ml.app import launch_app  # noqa: F401

diff --git a/flash/pointcloud/detection/backbones.py b/flash/pointcloud/detection/backbones.py
new file mode 100644
index 0000000000..88268dd036
--- /dev/null
+++ b/flash/pointcloud/detection/backbones.py
@@ -0,0 +1,19 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
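+# ``register_open_3d_ml`` populates this registry with the Open3D-ML detection
+# backbones (such as ``pointpillars_kitti``) when the pointcloud extras are
+# installed; otherwise the registry is left empty. Each registered entry
+# returns a ``(model, out_features, collate_fn)`` triple, e.g. (a sketch of how
+# the task consumes it):
+#
+#     model, out_features, collate_fn = POINTCLOUD_OBJECT_DETECTION_BACKBONES.get("pointpillars_kitti")()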
+from flash.core.registry import FlashRegistry
+from flash.pointcloud.detection.open3d_ml.backbones import register_open_3d_ml
+
+POINTCLOUD_OBJECT_DETECTION_BACKBONES = FlashRegistry("backbones")
+
+register_open_3d_ml(POINTCLOUD_OBJECT_DETECTION_BACKBONES)

diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py
new file mode 100644
index 0000000000..30c877e70d
--- /dev/null
+++ b/flash/pointcloud/detection/data.py
@@ -0,0 +1,178 @@
+from typing import Any, Callable, Dict, Optional
+
+from torch.utils.data import Sampler
+
+from flash.core.data.base_viz import BaseDataFetcher
+from flash.core.data.data_module import DataModule
+from flash.core.data.data_pipeline import Deserializer
+from flash.core.data.data_source import BaseDataFormat, DataSource, DefaultDataKeys, DefaultDataSources
+from flash.core.data.process import Preprocess
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+    from flash.pointcloud.detection.open3d_ml.data_sources import (
+        PointCloudObjectDetectionDataFormat,
+        PointCloudObjectDetectorFoldersDataSource,
+    )
+else:
+    # placeholder so this module can still be imported without the pointcloud extras
+    PointCloudObjectDetectorFoldersDataSource = object
+
+    class PointCloudObjectDetectionDataFormat:
+        KITTI = None
+
+
+class PointCloudObjectDetectorDatasetDataSource(DataSource):
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def load_data(
+        self,
+        data: Any,
+        dataset: Optional[Any] = None,
+    ) -> Any:
+
+        dataset.dataset = data
+
+        return range(len(data))
+
+    def load_sample(self, index: int, dataset: Optional[Any] = None) -> Any:
+        sample = dataset.dataset[index]
+
+        return {
+            DefaultDataKeys.INPUT: sample['data'],
+            DefaultDataKeys.METADATA: sample["attr"],
+        }
+
+
+class PointCloudObjectDetectorPreprocess(Preprocess):
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        deserializer: Optional[Deserializer] = None,
+        **data_source_kwargs,
+    ):
+
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_sources={
+                DefaultDataSources.DATASET: PointCloudObjectDetectorDatasetDataSource(**data_source_kwargs),
+                DefaultDataSources.FOLDERS: PointCloudObjectDetectorFoldersDataSource(**data_source_kwargs),
+            },
+            deserializer=deserializer,
+            default_data_source=DefaultDataSources.FOLDERS,
+        )
+
+    def get_state_dict(self):
+        return {}
+
+    def state_dict(self):
+        return {}
+
+    @classmethod
+    def load_state_dict(cls, state_dict, strict: bool = False):
+        pass
+
+
+class PointCloudObjectDetectorData(DataModule):
+
+    preprocess_cls = PointCloudObjectDetectorPreprocess
+
+    @classmethod
+    def from_folders(
+        cls,
+        train_folder: Optional[str] = None,
+        val_folder: Optional[str] = None,
+        test_folder: Optional[str] = None,
+        predict_folder: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        sampler: Optional[Sampler] = None,
+        scans_folder_name: Optional[str] = "scans",
+        labels_folder_name: Optional[str] = "labels",
+        calibrations_folder_name: Optional[str] = "calibs",
+        data_format: Optional[BaseDataFormat] = PointCloudObjectDetectionDataFormat.KITTI,
+        **preprocess_kwargs: Any,
+    ) -> 'DataModule':
+        """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given folders using the
+        :class:`~flash.core.data.data_source.DataSource` of name
+        :attr:`~flash.core.data.data_source.DefaultDataSources.FOLDERS`
+        from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
+
+        Args:
+            train_folder: The folder containing the train data.
+            val_folder: The folder containing the validation data.
+            test_folder: The folder containing the test data.
+            predict_folder: The folder containing the predict data.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+            scans_folder_name: The name of the pointcloud scan folder.
+            labels_folder_name: The name of the pointcloud scan labels folder.
+            calibrations_folder_name: The name of the pointcloud scan calibration folder.
+            data_format: Format in which the data are stored.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = DataModule.from_folders(
+                train_folder="train_folder",
+                train_transform={
+                    "to_tensor_transform": torch.as_tensor,
+                },
+            )
+        """
+        return cls.from_data_source(
+            DefaultDataSources.FOLDERS,
+            train_folder,
+            val_folder,
+            test_folder,
+            predict_folder,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            sampler=sampler,
+            scans_folder_name=scans_folder_name,
+            labels_folder_name=labels_folder_name,
+            calibrations_folder_name=calibrations_folder_name,
+            data_format=data_format,
+            **preprocess_kwargs,
+        )

diff --git a/flash/pointcloud/detection/datasets.py b/flash/pointcloud/detection/datasets.py
new file mode 100644
index 0000000000..4860da1363
--- /dev/null
+++ b/flash/pointcloud/detection/datasets.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.segmentation.datasets import executor
+
+if _POINTCLOUD_AVAILABLE:
+    from open3d.ml.datasets import KITTI
+
+_OBJECT_DETECTION_DATASET = FlashRegistry("dataset")
+
+
+@_OBJECT_DETECTION_DATASET
+def kitti(dataset_path, download, **kwargs):
+    name = "KITTI"
+    download_path = os.path.join(dataset_path, name, "Kitti")
+    # NOTE: the dataset is fetched on first use; the ``download`` flag is currently unused.
+    if not os.path.exists(download_path):
+        executor(
+            "https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/scripts/download_datasets/download_kitti.sh",  # noqa E501
+            None,
+            dataset_path,
+            name
+        )
+    return KITTI(download_path, **kwargs)
+
+
+def KITTIDataset(dataset_path, download: bool = True, **kwargs):
+    return _OBJECT_DETECTION_DATASET.get("kitti")(dataset_path, download, **kwargs)

diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py
new file mode 100644
index 0000000000..ff1e718484
--- /dev/null
+++ b/flash/pointcloud/detection/model.py
@@ -0,0 +1,187 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
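+# The ``PointCloudObjectDetector`` task wraps an Open3D-ML detection model
+# (PointPillars by default) so that it can be finetuned and used for inference
+# like any other Flash :class:`~flash.core.model.Task`.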
+import sys
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import torch
+import torchmetrics
+from torch import nn
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader, Sampler
+
+import flash
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.process import Serializer
+from flash.core.data.states import CollateFn
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.apply_func import get_callable_dict
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES
+
+__FILE_EXAMPLE__ = "pointcloud_detection"
+
+
+class PointCloudObjectDetectorSerializer(Serializer):
+    pass
+
+
+class PointCloudObjectDetector(flash.Task):
+    """The ``PointCloudObjectDetector`` is a :class:`~flash.core.model.Task` for detecting 3D objects in point
+    clouds, predicting a class and a 3D bounding box for each.
+
+    Args:
+        num_classes: The number of classes (outputs) for this :class:`~flash.core.model.Task`.
+        backbone: The backbone name (or a tuple of ``nn.Module``, output size) to use.
+        backbone_kwargs: Any additional kwargs to pass to the backbone constructor.
+        loss_fn: The loss function to use. By default, the loss defined by the Open3D-ML backbone is used.
+        optimizer: The optimizer or optimizer class to use.
+        optimizer_kwargs: Additional kwargs to use when creating the optimizer (if not passed as an instance).
+        scheduler: The scheduler or scheduler class to use.
+        scheduler_kwargs: Additional kwargs to use when creating the scheduler (if not passed as an instance).
+        metrics: Any metrics to use with this :class:`~flash.core.model.Task`.
+        learning_rate: The learning rate for the optimizer.
+        serializer: The :class:`~flash.core.data.process.Serializer` to use for prediction outputs.
+        lambda_loss_cls: The value to scale the classification loss.
+        lambda_loss_bbox: The value to scale the bounding box loss.
+        lambda_loss_dir: The value to scale the bounding box direction loss.
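+
+    Example (a minimal usage sketch; assumes ``datamodule`` is a
+    :class:`~flash.pointcloud.detection.data.PointCloudObjectDetectorData` built as in the folders example)::
+
+        model = PointCloudObjectDetector(
+            backbone="pointpillars_kitti",
+            num_classes=datamodule.num_classes,
+        )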
+ """ + + backbones: FlashRegistry = POINTCLOUD_OBJECT_DETECTION_BACKBONES + required_extras: str = "pointcloud" + + def __init__( + self, + num_classes: int, + backbone: Union[str, Tuple[nn.Module, int]] = "pointpillars_kitti", + backbone_kwargs: Optional[Dict] = None, + head: Optional[nn.Module] = None, + loss_fn: Optional[Callable] = None, + optimizer: Union[Type[torch.optim.Optimizer], torch.optim.Optimizer] = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler: Optional[Union[Type[_LRScheduler], str, _LRScheduler]] = None, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, + learning_rate: float = 1e-2, + serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = PointCloudObjectDetectorSerializer(), + lambda_loss_cls: float = 1., + lambda_loss_bbox: float = 1., + lambda_loss_dir: float = 1., + ): + + super().__init__( + model=None, + loss_fn=loss_fn, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, + metrics=metrics, + learning_rate=learning_rate, + serializer=serializer, + ) + + self.save_hyperparameters() + + if backbone_kwargs is None: + backbone_kwargs = {} + + if isinstance(backbone, tuple): + self.backbone, out_features = backbone + else: + self.model, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + self.backbone = self.model.backbone + self.neck = self.model.neck + self.set_state(CollateFn(collate_fn)) + self.set_state(CollateFn(collate_fn)) + self.set_state(CollateFn(collate_fn)) + self.loss_fn = get_callable_dict(self.model.loss) + + if __FILE_EXAMPLE__ not in sys.argv[0]: + self.model.bbox_head.conv_cls = self.head = nn.Conv2d( + out_features, num_classes, kernel_size=(1, 1), stride=(1, 1) + ) + + def compute_loss(self, losses: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + losses = losses["loss"] + return ( + self.hparams.lambda_loss_cls * losses["loss_cls"] + self.hparams.lambda_loss_bbox * losses["loss_bbox"] + + self.hparams.lambda_loss_dir * losses["loss_dir"] + ) + + def compute_logs(self, logs: Dict[str, Any], losses: Dict[str, torch.Tensor]): + logs.update({"loss": self.compute_loss(losses)}) + return logs + + def training_step(self, batch: Any, batch_idx: int) -> Any: + return super().training_step((batch, batch), batch_idx) + + def validation_step(self, batch: Any, batch_idx: int) -> Any: + super().validation_step((batch, batch), batch_idx) + + def test_step(self, batch: Any, batch_idx: int) -> Any: + super().validation_step((batch, batch), batch_idx) + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + results = self.model(batch) + boxes = self.model.inference_end(results, batch) + return { + DefaultDataKeys.INPUT: getattr(batch, "point", None), + DefaultDataKeys.PREDS: boxes, + DefaultDataKeys.METADATA: [a["name"] for a in batch.attr] + } + + def forward(self, x) -> torch.Tensor: + """First call the backbone, then the model head.""" + # hack to enable backbone to work properly. 
+        self.model.device = self.device
+        return self.model(x)
+
+    def _process_dataset(
+        self,
+        dataset: BaseAutoDataset,
+        batch_size: int,
+        num_workers: int,
+        pin_memory: bool,
+        collate_fn: Callable,
+        shuffle: bool = False,
+        drop_last: bool = True,
+        sampler: Optional[Sampler] = None,
+        convert_to_dataloader: bool = True,
+    ) -> Union[DataLoader, BaseAutoDataset]:
+
+        if not _POINTCLOUD_AVAILABLE:
+            raise ModuleNotFoundError("Please run `pip install flash[pointcloud]`.")
+
+        dataset.preprocess_fn = self.model.preprocess
+        dataset.transform_fn = self.model.transform
+
+        if convert_to_dataloader:
+            return DataLoader(
+                dataset,
+                batch_size=batch_size,
+                num_workers=num_workers,
+                pin_memory=pin_memory,
+                collate_fn=collate_fn,
+                shuffle=shuffle,
+                drop_last=drop_last,
+                sampler=sampler,
+            )
+
+        else:
+            return dataset

diff --git a/flash/pointcloud/detection/open3d_ml/app.py b/flash/pointcloud/detection/open3d_ml/app.py
new file mode 100644
index 0000000000..5578955d8a
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/app.py
@@ -0,0 +1,171 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+
+import flash
+from flash import DataModule
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+
+    from open3d._ml3d.vis.visualizer import LabelLUT, Visualizer
+    from open3d.visualization import gui
+
+    class Visualizer(Visualizer):
+
+        def visualize_dataset(self, dataset, split, indices=None, width=1024, height=768):
+            """Visualize a dataset.
+
+            Example:
+                Minimal example for visualizing a dataset::
+
+                    import open3d.ml.torch as ml3d  # or open3d.ml.tf as ml3d
+
+                    dataset = ml3d.datasets.SemanticKITTI(dataset_path='/path/to/SemanticKITTI/')
+                    vis = ml3d.vis.Visualizer()
+                    vis.visualize_dataset(dataset, 'all', indices=range(100))
+
+            Args:
+                dataset: The dataset to use for visualization.
+                split: The dataset split to be used, such as 'training'
+                indices: An iterable with a subset of the data points to visualize, such as [0,2,3,4].
+                width: The width of the visualization window.
+                height: The height of the visualization window.
+ """ + # Setup the labels + lut = LabelLUT() + for id, color in dataset.color_map.items(): + lut.add_label(id, id, color=color) + self.set_lut("label", lut) + + self._consolidate_bounding_boxes = True + self._init_dataset(dataset, split, indices) + + self._visualize("Open3D - " + dataset.name, width, height) + + def _visualize(self, title, width, height): + gui.Application.instance.initialize() + self._init_user_interface(title, width, height) + + # override just to set background color to back :) + bgcolor = gui.ColorEdit() + bgcolor.color_value = gui.Color(0, 0, 0) + self._on_bgcolor_changed(bgcolor.color_value) + + self._3d.scene.downsample_threshold = 400000 + + # Turn all the objects off except the first one + for name, node in self._name2treenode.items(): + node.checkbox.checked = False + self._3d.scene.show_geometry(name, False) + for name in [self._objects.data_names[0]]: + self._name2treenode[name].checkbox.checked = True + self._3d.scene.show_geometry(name, True) + + def on_done_ui(): + # Add bounding boxes here: bounding boxes belonging to the dataset + # will not be loaded until now. + self._update_bounding_boxes() + + self._update_datasource_combobox() + self._update_shaders_combobox() + + # Display "colors" by default if available, "points" if not + available_attrs = self._get_available_attrs() + self._set_shader(self.SOLID_NAME, force_update=True) + if "colors" in available_attrs: + self._datasource_combobox.selected_text = "colors" + elif "points" in available_attrs: + self._datasource_combobox.selected_text = "points" + + self._dont_update_geometry = True + self._on_datasource_changed( + self._datasource_combobox.selected_text, self._datasource_combobox.selected_index + ) + self._update_geometry_colors() + self._dont_update_geometry = False + # _datasource_combobox was empty, now isn't, re-layout. 
+                self.window.set_needs_layout()
+
+            self._update_geometry()
+            self.setup_camera()
+
+            self._load_geometries(self._objects.data_names, on_done_ui)
+            gui.Application.instance.run()
+
+    class VizDataset(Dataset):
+
+        name = "VizDataset"
+
+        def __init__(self, dataset):
+            self.dataset = dataset
+            self.label_to_names = getattr(dataset, "label_to_names", {})
+            self.path_list = getattr(dataset, "path_list", [])
+            self.color_map = getattr(dataset, "color_map", {})
+
+        def get_data(self, index):
+            data = self.dataset[index]["data"]
+            data["bounding_boxes"] = data["bbox_objs"]
+            data["color"] = np.ones_like(data["point"])
+            return data
+
+        def get_attr(self, index):
+            return self.dataset[index]["attr"]
+
+        def get_split(self, *_) -> 'VizDataset':
+            return self
+
+        def __len__(self) -> int:
+            return len(self.dataset)
+
+    class App:
+
+        def __init__(self, datamodule: DataModule):
+            self.datamodule = datamodule
+            self._enabled = not flash._IS_TESTING
+
+        def get_dataset(self, stage: str = "train"):
+            dataloader = getattr(self.datamodule, f"{stage}_dataloader")()
+            return VizDataset(dataloader.dataset)
+
+        def show_train_dataset(self, indices=None):
+            if self._enabled:
+                dataset = self.get_dataset("train")
+                viz = Visualizer()
+                viz.visualize_dataset(dataset, 'all', indices=indices)
+
+        def show_predictions(self, predictions):
+            if self._enabled:
+                dataset = self.get_dataset("train")
+
+                viz = Visualizer()
+                lut = LabelLUT()
+                for id, color in dataset.color_map.items():
+                    lut.add_label(id, id, color=color)
+                viz.set_lut("label", lut)
+
+                for pred in predictions:
+                    data = {
+                        "points": torch.stack(pred[DefaultDataKeys.INPUT])[:, :3],
+                        "name": pred[DefaultDataKeys.METADATA],
+                    }
+                    bounding_box = pred[DefaultDataKeys.PREDS]
+
+                    viz.visualize([data], bounding_boxes=bounding_box)
+
+
+def launch_app(datamodule: DataModule) -> 'App':
+    return App(datamodule)

diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py
new file mode 100644
index 0000000000..6dbb0acbb1
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/backbones.py
@@ -0,0 +1,81 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
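+# Registers the Open3D-ML PointPillars detection models with the Flash backbone
+# registry; pretrained ``pointpillars_kitti`` weights are fetched from the
+# Open3D model zoo (``ROOT_URL`` below).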
+import os
+from abc import ABC
+from typing import Callable
+
+import torch
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+
+from flash.core.registry import FlashRegistry
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+ROOT_URL = "https://storage.googleapis.com/open3d-releases/model-zoo/"
+
+if _POINTCLOUD_AVAILABLE:
+    import open3d
+    import open3d.ml as _ml3d
+    from open3d._ml3d.torch.dataloaders.concat_batcher import ConcatBatcher, ObjectDetectBatch
+    from open3d._ml3d.torch.models.point_pillars import PointPillars
+    from open3d.ml.torch.dataloaders import DefaultBatcher
+else:
+    ObjectDetectBatch = ABC
+    PointPillars = ABC
+
+
+class ObjectDetectBatchCollator(ObjectDetectBatch):
+    # Wraps ``ObjectDetectBatch`` so the collated batch remembers how many samples it holds.
+
+    def __init__(self, batches):
+        self.num_batches = len(batches)
+        super().__init__(batches)
+
+    def to(self, device):
+        super().to(device)
+        return self
+
+    def __len__(self):
+        return self.num_batches
+
+
+def register_open_3d_ml(register: FlashRegistry):
+
+    if _POINTCLOUD_AVAILABLE:
+
+        CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")
+
+        def get_collate_fn(model) -> Callable:
+            batcher_name = model.cfg.batcher
+            if batcher_name == 'DefaultBatcher':
+                batcher = DefaultBatcher()
+            elif batcher_name == 'ConcatBatcher':
+                batcher = ConcatBatcher(torch, model.__class__.__name__)
+            elif batcher_name == 'ObjectDetectBatchCollator':
+                return ObjectDetectBatchCollator
+            return batcher.collate_fn
+
+        @register(parameters=PointPillars.__init__)
+        def pointpillars_kitti(*args, **kwargs) -> PointPillars:
+            cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml"))
+            cfg.model.device = "cpu"
+            model = PointPillars(**cfg.model)
+            weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth")
+            model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'])
+            model.cfg.batcher = "ObjectDetectBatchCollator"
+            return model, 384, get_collate_fn(model)
+
+        @register(parameters=PointPillars.__init__)
+        def pointpillars(*args, **kwargs) -> PointPillars:
+            model = PointPillars(*args, **kwargs)
+            model.cfg.batcher = "ObjectDetectBatch"
+            # return the same (model, out_features, collate_fn) triple that callers unpack;
+            # 384 assumes the default PointPillars neck output size
+            return model, 384, get_collate_fn(model)

diff --git a/flash/pointcloud/detection/open3d_ml/data_sources.py b/flash/pointcloud/detection/open3d_ml/data_sources.py
new file mode 100644
index 0000000000..bd594ebe2f
--- /dev/null
+++ b/flash/pointcloud/detection/open3d_ml/data_sources.py
@@ -0,0 +1,244 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
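+# Loaders that read KITTI-formatted folders (scans, calibrations, labels and a
+# ``meta.yaml``) into the sample dictionaries consumed by the Open3D-ML models.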
+from os import listdir
+from os.path import basename, dirname, exists, isdir, isfile, join
+from typing import Any, Dict, List, Optional, Union
+
+import yaml
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+from flash.core.data.auto_dataset import BaseAutoDataset
+from flash.core.data.data_source import BaseDataFormat, DataSource
+from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
+
+if _POINTCLOUD_AVAILABLE:
+    from open3d._ml3d.datasets.kitti import DataProcessing, KITTI
+
+
+class PointCloudObjectDetectionDataFormat(BaseDataFormat):
+    KITTI = "kitti"
+
+
+class BasePointCloudObjectDetectorLoader:
+
+    pass
+
+
+class KITTIPointCloudObjectDetectorLoader(BasePointCloudObjectDetectorLoader):
+
+    def __init__(
+        self,
+        image_size: tuple = (375, 1242),
+        scans_folder_name: Optional[str] = "scans",
+        labels_folder_name: Optional[str] = "labels",
+        calibrations_folder_name: Optional[str] = "calibs",
+        **kwargs,
+    ):
+
+        self.image_size = image_size
+        self.scans_folder_name = scans_folder_name
+        self.labels_folder_name = labels_folder_name
+        self.calibrations_folder_name = calibrations_folder_name
+
+    def load_meta(self, root_dir, dataset: Optional[BaseAutoDataset]):
+        meta_file = join(root_dir, "meta.yaml")
+        if not exists(meta_file):
+            raise MisconfigurationException(f"The {root_dir} should contain a `meta.yaml` file describing the classes.")
+
+        with open(meta_file, 'r') as f:
+            self.meta = yaml.safe_load(f)
+
+        if "label_to_names" not in self.meta:
+            raise MisconfigurationException(
+                f"The {root_dir} should contain a `meta.yaml` file describing the classes with the field `label_to_names`."
+            )
+
+        dataset.num_classes = len(self.meta["label_to_names"])
+        dataset.label_to_names = self.meta["label_to_names"]
+        dataset.color_map = self.meta["color_map"]
+
+    def load_data(self, folder: str, dataset: Optional[BaseAutoDataset]):
+        sub_directories = listdir(folder)
+        if len(sub_directories) != 3:
+            raise MisconfigurationException(
+                f"Using KITTI Format, the {folder} should contain 3 directories "
+                "for ``calibrations``, ``labels`` and ``scans``."
+            )
+
+        assert self.scans_folder_name in sub_directories
+        assert self.labels_folder_name in sub_directories
+        assert self.calibrations_folder_name in sub_directories
+
+        scans_dir = join(folder, self.scans_folder_name)
+        labels_dir = join(folder, self.labels_folder_name)
+        calibrations_dir = join(folder, self.calibrations_folder_name)
+
+        scan_paths = [join(scans_dir, f) for f in listdir(scans_dir)]
+        label_paths = [join(labels_dir, f) for f in listdir(labels_dir)]
+        calibration_paths = [join(calibrations_dir, f) for f in listdir(calibrations_dir)]
+
+        assert len(scan_paths) == len(label_paths) == len(calibration_paths)
+
+        self.load_meta(dirname(folder), dataset)
+
+        dataset.path_list = scan_paths
+
+        return [{
+            "scan_path": scan_path,
+            "label_path": label_path,
+            "calibration_path": calibration_path
+        } for scan_path, label_path, calibration_path in zip(scan_paths, label_paths, calibration_paths)]
+
+    def load_sample(
+        self, sample: Dict[str, str], dataset: Optional[BaseAutoDataset] = None, has_label: bool = True
+    ) -> Any:
+        pc = KITTI.read_lidar(sample["scan_path"])
+        calib = KITTI.read_calib(sample["calibration_path"])
+        label = None
+        if has_label:
+            label = KITTI.read_label(sample["label_path"], calib)
+
+        reduced_pc = DataProcessing.remove_outside_points(pc, calib['world_cam'], calib['cam_img'], self.image_size)
+
+        attr = {
+            "name": basename(sample["scan_path"]),
+            "path": sample["scan_path"],
+            "calibration_path": sample["calibration_path"],
+            "label_path": sample["label_path"] if has_label else None,
+            "split": "val",
+        }
+
+        data = {
+            'point': reduced_pc,
+            'full_point': pc,
+            'feat': None,
+            'calib': calib,
+            'bounding_boxes': label if has_label else None,
+            'attr': attr
+        }
+        return data, attr
+
+    def load_files(self, scan_paths: Union[str, List[str]], dataset: Optional[BaseAutoDataset] = None):
+        if isinstance(scan_paths, str):
+            scan_paths = [scan_paths]
+
+        def clean_fn(path: str) -> str:
+            return path.replace(self.scans_folder_name, self.calibrations_folder_name).replace(".bin", ".txt")
+
+        dataset.path_list = scan_paths
+
+        return [{"scan_path": scan_path, "calibration_path": clean_fn(scan_path)} for scan_path in scan_paths]
+
+    def predict_load_data(self, data, dataset: Optional[BaseAutoDataset] = None):
+        if (isinstance(data, str) and isfile(data)) or (isinstance(data, list) and all(isfile(p) for p in data)):
+            return self.load_files(data, dataset)
+        elif isinstance(data, str) and isdir(data):
+            raise NotImplementedError
+
+    def predict_load_sample(self, data, dataset: Optional[BaseAutoDataset] = None):
+        data, attr = self.load_sample(data, dataset, has_label=False)
+        # hack to prevent manipulation of labels
+        attr["split"] = "test"
+        return data, attr
+
+
+class PointCloudObjectDetectorFoldersDataSource(DataSource):
+
+    def __init__(
+        self,
+        data_format: Optional[BaseDataFormat] = None,
+        image_size: tuple = (375, 1242),
+        **loader_kwargs,
+    ):
+        super().__init__()
+
+        self.loaders = {
+            PointCloudObjectDetectionDataFormat.KITTI: KITTIPointCloudObjectDetectorLoader(
+                **loader_kwargs, image_size=image_size
+            )
+        }
+
+        self.data_format = data_format or PointCloudObjectDetectionDataFormat.KITTI
+        self.loader = self.loaders[self.data_format]
+
+    def _validate_data(self, folder: str) -> None:
+        msg = f"The provided dataset for stage {self._running_stage} should be a folder. Found {folder}."
+        if not isinstance(folder, str):
+            raise MisconfigurationException(msg)
+
+        if isinstance(folder, str) and not isdir(folder):
+            raise MisconfigurationException(msg)
+
+    def load_data(
+        self,
+        data: Any,
+        dataset: Optional[BaseAutoDataset] = None,
+    ) -> Any:
+
+        self._validate_data(data)
+
+        return self.loader.load_data(data, dataset)
+
+    def load_sample(self, metadata: Dict[str, str], dataset: Optional[BaseAutoDataset] = None) -> Any:
+
+        data, metadata = self.loader.load_sample(metadata, dataset)
+
+        preprocess_fn = getattr(dataset, "preprocess_fn", None)
+        if preprocess_fn:
+            data = preprocess_fn(data, metadata)
+
+        transform_fn = getattr(dataset, "transform_fn", None)
+        if transform_fn:
+            data = transform_fn(data, metadata)
+
+        return {"data": data, "attr": metadata}
+
+    def _validate_predict_data(self, data: Union[str, List[str]]) -> None:
+        msg = f"The provided predict data should be either a folder or a single/list of scan path(s). Found {data}."
+        if not isinstance(data, str) and not isinstance(data, list):
+            raise MisconfigurationException(msg)
+
+        if isinstance(data, str) and not isfile(data) and not isdir(data):
+            raise MisconfigurationException(msg)
+
+        if isinstance(data, list) and not all(isfile(p) for p in data):
+            raise MisconfigurationException(msg)
+
+    def predict_load_data(
+        self,
+        data: Any,
+        dataset: Optional[BaseAutoDataset] = None,
+    ) -> Any:
+
+        self._validate_predict_data(data)
+
+        return self.loader.predict_load_data(data, dataset)
+
+    def predict_load_sample(
+        self,
+        metadata: Any,
+        dataset: Optional[BaseAutoDataset] = None,
+    ) -> Any:
+
+        data, metadata = self.loader.predict_load_sample(metadata, dataset)
+
+        preprocess_fn = getattr(dataset, "preprocess_fn", None)
+        if preprocess_fn:
+            data = preprocess_fn(data, metadata)
+
+        transform_fn = getattr(dataset, "transform_fn", None)
+        if transform_fn:
+            data = transform_fn(data, metadata)
+
+        return {"data": data, "attr": metadata}

diff --git a/flash/pointcloud/segmentation/__init__.py b/flash/pointcloud/segmentation/__init__.py
index bf7f46a89c..5d10606f79 100644
--- a/flash/pointcloud/segmentation/__init__.py
+++ b/flash/pointcloud/segmentation/__init__.py
@@ -1,2 +1,3 @@
 from flash.pointcloud.segmentation.data import PointCloudSegmentationData  # noqa: F401
 from flash.pointcloud.segmentation.model import PointCloudSegmentation  # noqa: F401
+from flash.pointcloud.segmentation.open3d_ml.app import launch_app  # noqa: F401

diff --git a/flash/pointcloud/segmentation/open3d_ml/app.py b/flash/pointcloud/segmentation/open3d_ml/app.py
index a226d6f5b2..879f45570e 100644
--- a/flash/pointcloud/segmentation/open3d_ml/app.py
+++ b/flash/pointcloud/segmentation/open3d_ml/app.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import torch
 
-import flash
 from flash import DataModule
 from flash.core.data.data_source import DefaultDataKeys
 from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE
@@ -58,7 +57,7 @@ class App:
 
     def __init__(self, datamodule: DataModule):
         self.datamodule = datamodule
-        self._enabled = not flash._IS_TESTING
+        self._enabled = True  # not flash._IS_TESTING
 
     def get_dataset(self, stage: str = "train"):
         dataloader = getattr(self.datamodule, f"{stage}_dataloader")()

diff --git a/flash/pointcloud/segmentation/open3d_ml/backbones.py b/flash/pointcloud/segmentation/open3d_ml/backbones.py
index 0fe44a72ce..aec3aa0123 100644
--- a/flash/pointcloud/segmentation/open3d_ml/backbones.py
+++ b/flash/pointcloud/segmentation/open3d_ml/backbones.py
@@ -27,8 +27,8 @@ def register_open_3d_ml(register: FlashRegistry):
     if _POINTCLOUD_AVAILABLE:
         import open3d
         import open3d.ml as _ml3d
-        from open3d.ml.torch.dataloaders import ConcatBatcher, DefaultBatcher
-        from open3d.ml.torch.models import RandLANet
+        from open3d._ml3d.torch.dataloaders import ConcatBatcher, DefaultBatcher
+        from open3d._ml3d.torch.models import RandLANet
 
         CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs")

diff --git a/flash_examples/pointcloud_detection.py b/flash_examples/pointcloud_detection.py
new file mode 100644
index 0000000000..6cd0409893
--- /dev/null
+++ b/flash_examples/pointcloud_detection.py
@@ -0,0 +1,41 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud import PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+    train_folder="data/KITTI_Tiny/Kitti/train",
+    val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Detect objects in a few point clouds
+predictions = model.predict([
+    "data/KITTI_Tiny/Kitti/predict/scans/000000.bin",
+    "data/KITTI_Tiny/Kitti/predict/scans/000001.bin",
+])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_detection_model.pt")

diff --git a/flash_examples/visualizations/pointcloud_detection.py b/flash_examples/visualizations/pointcloud_detection.py
new file mode 100644
index 0000000000..ebfb0eb5a0
--- /dev/null
+++ b/flash_examples/visualizations/pointcloud_detection.py
@@ -0,0 +1,43 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import flash
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import launch_app, PointCloudObjectDetector, PointCloudObjectDetectorData
+
+# 1. Create the DataModule
+# Dataset Credit: http://www.semantic-kitti.org/
+download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_tiny.zip", "data/")
+
+datamodule = PointCloudObjectDetectorData.from_folders(
+    train_folder="data/KITTI_Tiny/Kitti/train",
+    val_folder="data/KITTI_Tiny/Kitti/val",
+)
+
+# 2. Build the task
+model = PointCloudObjectDetector(backbone="pointpillars_kitti", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1, num_sanity_val_steps=0)
+trainer.fit(model, datamodule)
+
+# 4. Detect objects in a few point clouds
+predictions = model.predict(["data/KITTI_Tiny/Kitti/predict/scans/000000.bin"])
+
+# 5. Save the model!
+trainer.save_checkpoint("pointcloud_detection_model.pt")
+
+# 6. Optionally visualize
+app = launch_app(datamodule)
+# app.show_train_dataset()
+app.show_predictions(predictions)

diff --git a/flash_examples/visualizations/pointcloud_segmentation.py b/flash_examples/visualizations/pointcloud_segmentation.py
index e4859a8d90..85565a7027 100644
--- a/flash_examples/visualizations/pointcloud_segmentation.py
+++ b/flash_examples/visualizations/pointcloud_segmentation.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import flash
 from flash.core.data.utils import download_data
-from flash.pointcloud import launch_app, PointCloudSegmentation, PointCloudSegmentationData
+from flash.pointcloud.segmentation import launch_app, PointCloudSegmentation, PointCloudSegmentationData
 
 # 1. Create the DataModule
 # Dataset Credit: http://www.semantic-kitti.org/
@@ -42,4 +42,5 @@
 # 6. Optional Visualize
 app = launch_app(datamodule)
+# app.show_train_dataset()
 app.show_predictions(predictions)

diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py
index 68252601e5..ec6c4bb834 100644
--- a/tests/examples/test_scripts.py
+++ b/tests/examples/test_scripts.py
@@ -81,6 +81,10 @@
         "pointcloud_segmentation.py",
         marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
     ),
+    pytest.param(
+        "pointcloud_detection.py",
+        marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+    ),
     pytest.param(
         "graph_classification.py",
         marks=pytest.mark.skipif(not _GRAPH_TESTING, reason="graph libraries aren't installed")
@@ -89,3 +93,16 @@
 )
 def test_example(tmpdir, file):
     run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))
+
+
+@mock.patch.dict(os.environ, {"FLASH_TESTING": "1"})
+@pytest.mark.parametrize(
+    "file", [
+        pytest.param(
+            "pointcloud_detection.py",
+            marks=pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+        ),
+    ]
+)
+def test_example_2(tmpdir, file):
+    run_test(str(Path(flash.PROJECT_ROOT) / "flash_examples" / file))

diff --git a/tests/pointcloud/detection/__init__.py b/tests/pointcloud/detection/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

diff --git a/tests/pointcloud/detection/test_data.py b/tests/pointcloud/detection/test_data.py
new file mode 100644
index 0000000000..26484f476e
--- /dev/null
+++ b/tests/pointcloud/detection/test_data.py
@@ -0,0 +1,60 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from os.path import join
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from flash import Trainer
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.utils import download_data
+from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData
+from tests.helpers.utils import _POINTCLOUD_TESTING
+
+if _POINTCLOUD_TESTING:
+    from flash.pointcloud.detection.open3d_ml.backbones import ObjectDetectBatchCollator
+
+
+@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed")
+def test_pointcloud_object_detection_data(tmpdir):
+
+    seed_everything(52)
+
+    download_data("https://pl-flash-data.s3.amazonaws.com/KITTI_micro.zip", tmpdir)
+
+    dm = PointCloudObjectDetectorData.from_folders(train_folder=join(tmpdir, "KITTI_Micro", "Kitti", "train"))
+
+    class MockModel(PointCloudObjectDetector):
+
+        def training_step(self, batch, batch_idx: int):
+            assert isinstance(batch, ObjectDetectBatchCollator)
+            assert len(batch.point) == 2
+            assert batch.point[0][1].shape == torch.Size([4])
+            assert len(batch.bboxes) > 1
+            assert batch.attr[0]["name"] == '000000.bin'
+            assert batch.attr[1]["name"] == '000001.bin'
+
+    num_classes = 19
+    model = MockModel(backbone="pointpillars_kitti", num_classes=num_classes)
+    trainer = Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=0)
+    trainer.fit(model, dm)
+
+    predict_path = join(tmpdir, "KITTI_Micro", "Kitti", "predict")
+    model.eval()
+
+    predictions = model.predict([join(predict_path, "scans/000000.bin")])
+    assert torch.stack(predictions[0][DefaultDataKeys.INPUT]).shape[1] == 4
+    assert len(predictions[0][DefaultDataKeys.PREDS]) == 158
+    assert predictions[0][DefaultDataKeys.PREDS][0].__dict__["identifier"] == 'box:1'

diff --git a/tests/pointcloud/detection/test_model.py b/tests/pointcloud/detection/test_model.py
new file mode 100644
index 0000000000..b7d807c837
--- /dev/null
+++ b/tests/pointcloud/detection/test_model.py
@@ -0,0 +1,24 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest + +from flash.pointcloud.detection import PointCloudObjectDetector +from tests.helpers.utils import _POINTCLOUD_TESTING + + +@pytest.mark.skipif(not _POINTCLOUD_TESTING, reason="pointcloud libraries aren't installed") +def test_backbones(): + + backbones = PointCloudObjectDetector.available_backbones() + assert backbones == ['pointpillars', 'pointpillars_kitti'] From 6214983f7a2d2b8f828decb42a1c5404e47988fc Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Fri, 16 Jul 2021 23:17:46 +0300 Subject: [PATCH 3/4] Feature/task a thon audio classification spectrograms (#594) * added audio spectrogram classification data, transforms and tests based on image classification * added audio spectrogram classification data, transforms and tests based on image classification * added audio spectrogram classification example and notebook * fixed formatting issues about newlines and longlines * updated docs to include audio classification task * removed empty `model` package * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Updates * Update CHANGELOG.md * Updates * Updates * Try fix * Updates * Updates * Updates Co-authored-by: Ethan Harris Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ethan Harris --- .github/workflows/ci-testing.yml | 11 + CHANGELOG.md | 2 + docs/source/_templates/layout.html | 2 +- docs/source/index.rst | 6 + .../source/reference/audio_classification.rst | 73 ++++ flash/audio/__init__.py | 1 + flash/audio/classification/__init__.py | 1 + flash/audio/classification/data.py | 87 +++++ flash/audio/classification/transforms.py | 54 +++ flash/core/utilities/imports.py | 26 +- flash_examples/audio_classification.py | 45 +++ requirements/datatype_audio.txt | 1 + tests/audio/__init__.py | 0 tests/audio/classification/__init__.py | 0 tests/audio/classification/test_data.py | 340 ++++++++++++++++++ tests/examples/test_scripts.py | 5 + tests/helpers/utils.py | 3 + tests/image/classification/test_data.py | 2 +- 18 files changed, 650 insertions(+), 9 deletions(-) create mode 100644 docs/source/reference/audio_classification.rst create mode 100644 flash/audio/__init__.py create mode 100644 flash/audio/classification/__init__.py create mode 100644 flash/audio/classification/data.py create mode 100644 flash/audio/classification/transforms.py create mode 100644 flash_examples/audio_classification.py create mode 100644 tests/audio/__init__.py create mode 100644 tests/audio/classification/__init__.py create mode 100644 tests/audio/classification/test_data.py diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d26d8ecee2..21ac8fbd45 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -61,6 +61,10 @@ jobs: python-version: 3.8 requires: 'latest' topic: ['graph'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['audio'] # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 @@ -128,6 +132,13 @@ jobs: run: | pip install '.[all]' --pre --upgrade + - name: Install audio test dependencies + if: matrix.topic[0] == 'audio' + run: | + sudo apt-get install libsndfile1 + pip install matplotlib + pip install '.[image]' --pre --upgrade + - name: Cache datasets uses: actions/cache@v2 with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 54851b160e..cb7c1cb3b8 100644 --- a/CHANGELOG.md 
+++ b/CHANGELOG.md
@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for `field` parameter for loading JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))
+- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))
+
 ### Changed
 - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index d3312220d7..d050db39c5 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -4,7 +4,7 @@
 {% block footer %}
 {{ super() }}
 {% endblock %}
diff --git a/docs/source/index.rst b/docs/source/index.rst
index cf3917f11d..2ac114009c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -40,6 +40,12 @@ Lightning Flash
    reference/style_transfer
    reference/video_classification
 
+.. toctree::
+   :maxdepth: 1
+   :caption: Audio
+
+   reference/audio_classification
+
 .. toctree::
    :maxdepth: 1
    :caption: Tabular
diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst
new file mode 100644
index 0000000000..eb122e6995
--- /dev/null
+++ b/docs/source/reference/audio_classification.rst
@@ -0,0 +1,73 @@
+
+.. _audio_classification:
+
+####################
+Audio Classification
+####################
+
+********
+The Task
+********
+
+The task of identifying what is in an audio file is called audio classification.
+Typically, audio classification is used to identify audio files containing sounds or words.
+The task predicts which ‘class’ the sound or words most likely belong to, with a degree of certainty.
+A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’, etc.
+
+------
+
+*******
+Example
+*******
+
+Let's look at the task of predicting whether an audio file contains sounds of an air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, or street_music, using the UrbanSound8k spectrogram images dataset.
+The dataset contains ``train``, ``val`` and ``test`` folders.
+Each of these contains an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and likewise for the other classes.
+
+.. code-block::
+
+    urban8k_images
+    ├── train
+    │   ├── air_conditioner
+    │   ├── car_horn
+    │   ├── children_playing
+    │   ├── dog_bark
+    │   ├── drilling
+    │   ├── engine_idling
+    │   ├── gun_shot
+    │   ├── jackhammer
+    │   ├── siren
+    │   └── street_music
+    ├── test
+    │   ├── air_conditioner
+    │   ├── car_horn
+    │   ├── children_playing
+    │   ├── dog_bark
+    │   ├── drilling
+    │   ├── engine_idling
+    │   ├── gun_shot
+    │   ├── jackhammer
+    │   ├── siren
+    │   └── street_music
+    └── val
+        ├── air_conditioner
+        ├── car_horn
+        ├── children_playing
+        ├── dog_bark
+        ├── drilling
+        ├── engine_idling
+        ├── gun_shot
+        ├── jackhammer
+        ├── siren
+        └── street_music
+
+    ...
+
+Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
+We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data.
+We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/audio_classification.py + :language: python + :lines: 14- diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py new file mode 100644 index 0000000000..40eeaae124 --- /dev/null +++ b/flash/audio/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py new file mode 100644 index 0000000000..476a303d49 --- /dev/null +++ b/flash/audio/classification/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py new file mode 100644 index 0000000000..68678b2a1b --- /dev/null +++ b/flash/audio/classification/data.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.audio.classification.transforms import default_transforms, train_default_transforms +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import requires_extras +from flash.image.classification.data import MatplotlibVisualization +from flash.image.data import ImageDeserializer, ImagePathsDataSource + + +class AudioClassificationPreprocess(Preprocess): + + @requires_extras(["audio", "image"]) + def __init__( + self, + train_transform: Optional[Dict[str, Callable]], + val_transform: Optional[Dict[str, Callable]], + test_transform: Optional[Dict[str, Callable]], + predict_transform: Optional[Dict[str, Callable]], + spectrogram_size: Tuple[int, int] = (196, 196), + time_mask_param: int = 80, + freq_mask_param: int = 80, + deserializer: Optional['Deserializer'] = None, + ): + self.spectrogram_size = spectrogram_size + self.time_mask_param = time_mask_param + self.freq_mask_param = freq_mask_param + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource() + }, + deserializer=deserializer or ImageDeserializer(), + default_data_source=DefaultDataSources.FILES, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return { + **self.transforms, + "spectrogram_size": self.spectrogram_size, + "time_mask_param": self.time_mask_param, + "freq_mask_param": self.freq_mask_param, + } + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): 
+        return cls(**state_dict)
+
+    def default_transforms(self) -> Optional[Dict[str, Callable]]:
+        return default_transforms(self.spectrogram_size)
+
+    def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
+        return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)
+
+
+class AudioClassificationData(DataModule):
+    """Data module for audio classification."""
+
+    preprocess_cls = AudioClassificationPreprocess
+
+    def set_block_viz_window(self, value: bool) -> None:
+        """Setter to switch on/off matplotlib pop-up windows."""
+        self.data_fetcher.block_viz_window = value
+
+    @staticmethod
+    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
+        return MatplotlibVisualization(*args, **kwargs)
diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py
new file mode 100644
index 0000000000..02a9ed2cbc
--- /dev/null
+++ b/flash/audio/classification/transforms.py
@@ -0,0 +1,54 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Dict, Tuple
+
+import torch
+from torch import nn
+
+from flash.core.data.data_source import DefaultDataKeys
+from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
+from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE
+
+if _TORCHVISION_AVAILABLE:
+    import torchvision
+    from torchvision import transforms as T
+
+if _TORCHAUDIO_AVAILABLE:
+    from torchaudio import transforms as TAudio
+
+
+def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
+    """The default transforms for audio classification of spectrograms: resize the spectrogram,
+    convert the spectrogram and target to a tensor, and collate the batch."""
+    return {
+        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
+        "to_tensor_transform": nn.Sequential(
+            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
+            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
+        ),
+        "collate": kornia_collate,
+    }
+
+
+def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
+                             freq_mask_param: int) -> Dict[str, Callable]:
+    """During training, we apply the default transforms with additional ``TimeMasking`` and ``FrequencyMasking``."""
+    transforms = {
+        "post_tensor_transform": nn.Sequential(
+            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
+            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param))
+        )
+    }
+
+    return merge_transforms(default_transforms(spectrogram_size), transforms)
diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py
index 9922f49eba..80c6b6188c 100644
--- a/flash/core/utilities/imports.py
+++ b/flash/core/utilities/imports.py
@@ -16,6 +16,7 @@
 import operator
 import types
 from importlib.util import find_spec
+from typing import Callable, List, Union
 
 from
pkg_resources import DistributionNotFound @@ -89,6 +90,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") +_TORCHAUDIO_AVAILABLE = _module_available("torchaudio") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") @@ -108,6 +110,7 @@ def _compare_version(package: str, op, version) -> bool: _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE _AUDIO_AVAILABLE = all([ _ASTEROID_AVAILABLE, + _TORCHAUDIO_AVAILABLE, ]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE @@ -123,15 +126,22 @@ def _compare_version(package: str, op, version) -> bool: } -def _requires(module_path: str, module_available: bool): +def _requires( + module_paths: Union[str, List], + module_available: Callable[[str], bool], + formatter: Callable[[List[str]], str], +): + + if not isinstance(module_paths, list): + module_paths = [module_paths] def decorator(func): - if not module_available: + if not all(module_available(module_path) for module_path in module_paths): @functools.wraps(func) def wrapper(*args, **kwargs): raise ModuleNotFoundError( - f"Required dependencies not available. Please run: pip install '{module_path}'" + f"Required dependencies not available. Please run: pip install {formatter(module_paths)}" ) return wrapper @@ -141,12 +151,14 @@ def wrapper(*args, **kwargs): return decorator -def requires(module_path: str): - return _requires(module_path, _module_available(module_path)) +def requires(module_paths: Union[str, List]): + return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths)) -def requires_extras(extras: str): - return _requires(f"lightning-flash[{extras}]", _EXTRAS_AVAILABLE[extras]) +def requires_extras(extras: Union[str, List]): + return _requires( + extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'" + ) def lazy_import(module_name, callback=None): diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py new file mode 100644 index 0000000000..b8f0f8a312 --- /dev/null +++ b/flash_examples/audio_classification.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.audio import AudioClassificationData +from flash.core.data.utils import download_data +from flash.core.finetuning import FreezeUnfreeze +from flash.image import ImageClassifier + +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") + +datamodule = AudioClassificationData.from_folders( + train_folder="data/urban8k_images/train", + val_folder="data/urban8k_images/val", + spectrogram_size=(64, 64), +) + +# 2. Build the model. 
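+# Editor's note: spectrograms are treated as ordinary images here, so a
+# pretrained image backbone such as resnet18 transfers directly, and
+# num_classes is inferred from the datamodule's folder structure.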
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+
+# 3. Create the trainer and finetune the model
+trainer = flash.Trainer(max_epochs=3)
+trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))
+
+# 4. Predict what's in a few spectrogram images: air_conditioner, children_playing, siren, etc.
+predictions = model.predict([
+    "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
+    "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
+    "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
+])
+print(predictions)
+
+# 5. Save the model!
+trainer.save_checkpoint("audio_classification_model.pt")
diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt
index 03c90d99ec..e608a13b78 100644
--- a/requirements/datatype_audio.txt
+++ b/requirements/datatype_audio.txt
@@ -1 +1,2 @@
 asteroid>=0.5.1
+torchaudio
diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py
new file mode 100644
index 0000000000..a1c0ba0677
--- /dev/null
+++ b/tests/audio/classification/test_data.py
@@ -0,0 +1,340 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
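+# Editor's note: these tests exercise ``AudioClassificationData`` with randomly
+# generated spectrogram images, covering ``from_files`` / ``from_folders``,
+# visualisation hooks, train/val splits, and multi-label targets.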
+from pathlib import Path +from typing import Any, List, Tuple + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from flash.audio import AudioClassificationData +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from tests.helpers.utils import _AUDIO_TESTING + +if _TORCHVISION_AVAILABLE: + import torchvision + +if _PIL_AVAILABLE: + from PIL import Image + + +def _rand_image(size: Tuple[int, int] = None): + if size is None: + _size = np.random.choice([196, 244]) + size = (_size, _size) + return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_smoke(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a_1.png") + _rand_image().save(tmpdir / "b_1.png") + + train_images = [ + str(tmpdir / "a_1.png"), + str(tmpdir / "b_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[1, 2], + batch_size=2, + num_workers=0, + ) + assert spectrograms_data.train_dataloader() is not None + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert sorted(list(labels.numpy())) == [1, 2] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_list_image_paths(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here + assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here + + # check validation data + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [1, 4] + + # check test data + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [2, 5] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") +def test_from_filepaths_visualise(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + dm = AudioClassificationData.from_files( + 
train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + # dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") +def test_from_filepaths_visualise_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + dm = AudioClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[[0, 1, 0], [0, 1, 1]], + val_files=[image_b, image_a], + val_targets=[[1, 1, 0], [0, 0, 1]], + test_files=[image_b, image_b], + test_targets=[[0, 0, 1], [1, 1, 0]], + batch_size=2, + spectrogram_size=(64, 64), + ) + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch("to_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_val_batch("per_batch_transform") + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_splits(tmpdir): + tmpdir = Path(tmpdir) + + B, _, H, W = 2, 3, 224, 224 + img_size: Tuple[int, int] = (H, W) + + (tmpdir / "splits").mkdir() + _rand_image(img_size).save(tmpdir / "s.png") + + num_samples: int = 10 + val_split: float = .3 + + train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] + + train_labels: List[int] = list(range(num_samples)) + + assert len(train_filepaths) == len(train_labels) + + _to_tensor = { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ), + } + + def run(transform: Any = None): + dm = AudioClassificationData.from_files( + train_files=train_filepaths, + train_targets=train_labels, + train_transform=transform, + val_transform=transform, + batch_size=B, + num_workers=0, + val_split=val_split, + spectrogram_size=img_size, + ) + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (B, 3, H, W) + assert labels.shape == (B, ) + + run(_to_tensor) + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_folders_only_train(tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + + spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, 
labels = data['input'], data['target'] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, ) + + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_folders_train_val(tmpdir): + + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + spectrograms_data = AudioClassificationData.from_folders( + train_dir, + val_folder=train_dir, + test_folder=train_dir, + batch_size=2, + num_workers=0, + ) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [0, 0] + + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [0, 0] + + +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +def test_from_filepaths_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + _rand_image().save(tmpdir / "a1.png") + _rand_image().save(tmpdir / "a2.png") + + train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")] + train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]] + valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] + test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] + + dm = AudioClassificationData.from_files( + train_files=train_images, + train_targets=train_labels, + val_files=train_images, + val_targets=valid_labels, + test_files=train_images, + test_targets=test_labels, + batch_size=2, + num_workers=0, + ) + + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + + data = next(iter(dm.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) + + data = next(iter(dm.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(test_labels)) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ec6c4bb834..56b729e36e 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -21,6 +21,7 @@ from flash.core.utilities.imports import _SKLEARN_AVAILABLE from tests.examples.utils import run_test from tests.helpers.utils import ( + _AUDIO_TESTING, _GRAPH_TESTING, _IMAGE_TESTING, _POINTCLOUD_TESTING, @@ -37,6 +38,10 @@ pytest.param( "custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") ), + pytest.param( + "audio_classification.py", + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed") + ), pytest.param( "image_classification.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, 
reason="image libraries aren't installed") diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 5bb699b664..bd57cf570d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -14,6 +14,7 @@ import os from flash.core.utilities.imports import ( + _AUDIO_AVAILABLE, _GRAPH_AVAILABLE, _IMAGE_AVAILABLE, _POINTCLOUD_AVAILABLE, @@ -30,6 +31,7 @@ _SERVE_TESTING = _SERVE_AVAILABLE _POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE _GRAPH_TESTING = _GRAPH_AVAILABLE +_AUDIO_TESTING = _AUDIO_AVAILABLE if "FLASH_TEST_TOPIC" in os.environ: topic = os.environ["FLASH_TEST_TOPIC"] @@ -40,3 +42,4 @@ _SERVE_TESTING = topic == "serve" _POINTCLOUD_TESTING = topic == "pointcloud" _GRAPH_TESTING = topic == "graph" + _AUDIO_TESTING = topic == "audio" diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 6a80b5774a..87cb183504 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -168,7 +168,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) From ea4604ffafbdfa0a48cf231a4284bfeca76c91b8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 21:58:17 +0100 Subject: [PATCH 4/4] Fix docs build (#603) * Fix docs * Fixes * Fixes * Fixes --- docs/source/api/audio.rst | 21 ++++++++ docs/source/index.rst | 1 + flash/pointcloud/__init__.py | 6 +-- .../detection/open3d_ml/backbones.py | 50 +++++++++---------- 4 files changed, 49 insertions(+), 29 deletions(-) create mode 100644 docs/source/api/audio.rst diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst new file mode 100644 index 0000000000..79662fea87 --- /dev/null +++ b/docs/source/api/audio.rst @@ -0,0 +1,21 @@ +########### +flash.audio +########### + +.. contents:: + :depth: 1 + :local: + :backlinks: top + +.. currentmodule:: flash.audio + +Classification +______________ + +.. 
autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ~classification.data.AudioClassificationData + ~classification.data.AudioClassificationPreprocess diff --git a/docs/source/index.rst b/docs/source/index.rst index 2ac114009c..d12099d884 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -89,6 +89,7 @@ Lightning Flash api/data api/serve api/image + api/audio api/pointcloud api/tabular api/text diff --git a/flash/pointcloud/__init__.py b/flash/pointcloud/__init__.py index 8ad5b88538..766f2f2e89 100644 --- a/flash/pointcloud/__init__.py +++ b/flash/pointcloud/__init__.py @@ -1,4 +1,2 @@ -from flash.pointcloud.detection.data import PointCloudObjectDetectorData # noqa: F401 -from flash.pointcloud.detection.model import PointCloudObjectDetector # noqa: F401 -from flash.pointcloud.segmentation.data import PointCloudSegmentationData # noqa: F401 -from flash.pointcloud.segmentation.model import PointCloudSegmentation # noqa: F401 +from flash.pointcloud.detection import PointCloudObjectDetector, PointCloudObjectDetectorData # noqa: F401 +from flash.pointcloud.segmentation import PointCloudSegmentation, PointCloudSegmentationData # noqa: F401 diff --git a/flash/pointcloud/detection/open3d_ml/backbones.py b/flash/pointcloud/detection/open3d_ml/backbones.py index 6dbb0acbb1..622971299e 100644 --- a/flash/pointcloud/detection/open3d_ml/backbones.py +++ b/flash/pointcloud/detection/open3d_ml/backbones.py @@ -54,28 +54,28 @@ def register_open_3d_ml(register: FlashRegistry): CONFIG_PATH = os.path.join(os.path.dirname(open3d.__file__), "_ml3d/configs") - def get_collate_fn(model) -> Callable: - batcher_name = model.cfg.batcher - if batcher_name == 'DefaultBatcher': - batcher = DefaultBatcher() - elif batcher_name == 'ConcatBatcher': - batcher = ConcatBatcher(torch, model.__class__.__name__) - elif batcher_name == 'ObjectDetectBatchCollator': - return ObjectDetectBatchCollator - return batcher.collate_fn - - @register(parameters=PointPillars.__init__) - def pointpillars_kitti(*args, **kwargs) -> PointPillars: - cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml")) - cfg.model.device = "cpu" - model = PointPillars(**cfg.model) - weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth") - model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'], ) - model.cfg.batcher = "ObjectDetectBatchCollator" - return model, 384, get_collate_fn(model) - - @register(parameters=PointPillars.__init__) - def pointpillars(*args, **kwargs) -> PointPillars: - model = PointPillars(*args, **kwargs) - model.cfg.batcher = "ObjectDetectBatch" - return model, get_collate_fn(model) + def get_collate_fn(model) -> Callable: + batcher_name = model.cfg.batcher + if batcher_name == 'DefaultBatcher': + batcher = DefaultBatcher() + elif batcher_name == 'ConcatBatcher': + batcher = ConcatBatcher(torch, model.__class__.__name__) + elif batcher_name == 'ObjectDetectBatchCollator': + return ObjectDetectBatchCollator + return batcher.collate_fn + + @register(parameters=PointPillars.__init__) + def pointpillars_kitti(*args, **kwargs) -> PointPillars: + cfg = _ml3d.utils.Config.load_from_file(os.path.join(CONFIG_PATH, "pointpillars_kitti.yml")) + cfg.model.device = "cpu" + model = PointPillars(**cfg.model) + weight_url = os.path.join(ROOT_URL, "pointpillars_kitti_202012221652utc.pth") + model.load_state_dict(pl_load(weight_url, map_location='cpu')['model_state_dict'], ) + model.cfg.batcher = 
"ObjectDetectBatchCollator" + return model, 384, get_collate_fn(model) + + @register(parameters=PointPillars.__init__) + def pointpillars(*args, **kwargs) -> PointPillars: + model = PointPillars(*args, **kwargs) + model.cfg.batcher = "ObjectDetectBatch" + return model, get_collate_fn(model)