This repository has been archived by the owner on Oct 9, 2023. It is now read-only.
Commit: Merge branch 'master' into feature/icevision
Showing 43 changed files with 1,930 additions and 21 deletions.
@@ -0,0 +1,21 @@
###########
flash.audio
###########

.. contents::
    :depth: 1
    :local:
    :backlinks: top

.. currentmodule:: flash.audio

Classification
______________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~classification.data.AudioClassificationData
    ~classification.data.AudioClassificationPreprocess
@@ -0,0 +1,73 @@
.. _audio_classification:

####################
Audio Classification
####################

********
The Task
********

The task of identifying what is in an audio file is called audio classification.
Typically, audio classification is used to identify audio files containing sounds or words.
The task predicts which ‘class’ the sound or words most likely belong to, with a degree of certainty.
A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’, etc.

------

*******
Example
*******

Let's look at the task of predicting whether an audio file contains the sound of an air conditioner, car horn, children playing, dog bark, drilling, engine idling, gun shot, jackhammer, siren, or street music, using the UrbanSound8k spectrogram images dataset.
The dataset contains ``train``, ``val`` and ``test`` folders, and each of those contains an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and so on for the other classes:

.. code-block::

    urban8k_images
    ├── train
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    ├── test
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    └── val
        ├── air_conditioner
        ├── car_horn
        ├── children_playing
        ├── dog_bark
        ├── drilling
        ├── engine_idling
        ├── gun_shot
        ├── jackhammer
        ├── siren
        └── street_music

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data.
We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/audio_classification.py
    :language: python
    :lines: 14-
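The tree above follows the common one-folder-per-class convention, from which class labels are inferred. As a rough illustration of that inference (a hypothetical helper for this doc, not Flash's actual implementation), label discovery over such a layout can be sketched as:

```python
import os


def scan_class_folders(split_dir):
    """Map each file under split_dir/<class_name>/ to a class label.

    Returns (samples, class_names): a list of (path, class_index) pairs and
    the sorted class-name list that defines the index ordering.
    """
    class_names = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    samples = []
    for idx, name in enumerate(class_names):
        folder = os.path.join(split_dir, name)
        for fname in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, fname), idx))
    return samples, class_names
```

Flash's ``from_folders`` constructors rely on this same layout; the helper above only sketches the idea behind it.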
@@ -0,0 +1,82 @@
.. _pointcloud_object_detection:

############################
Point Cloud Object Detection
############################

********
The Task
********

A point cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.

Point cloud object detection is the task of identifying 3D objects in point clouds, along with their associated classes and 3D bounding boxes.

The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.

------

*******
Example
*******

Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
The data is a tiny subset of the original dataset and contains sequences of point clouds.

The data contains:

* one folder for scans
* one folder for scan calibrations
* one folder for labels
* a ``meta.yaml`` file describing the classes and their official associated color map

Here's the structure:

.. code-block::

    data
    ├── meta.yaml
    ├── train
    │   ├── scans
    │   │   ├── 00000.bin
    │   │   ├── 00001.bin
    │   │   ...
    │   ├── calibs
    │   │   ├── 00000.txt
    │   │   ├── 00001.txt
    │   │   ...
    │   ├── labels
    │   │   ├── 00000.txt
    │   │   └── 00001.txt
    │   ...
    ├── val
    │   ...
    └── predict
        ├── scans
        │   ├── 00000.bin
        │   └── 00001.bin
        ├── calibs
        │   ├── 00000.txt
        │   └── 00001.txt
        └── meta.yaml

Learn more: http://www.semantic-kitti.org/dataset.html

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.image.detection.data.PointCloudObjectDetectorData`.
We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.image.detection.model.PointCloudObjectDetector` task.
We then use the trained :class:`~flash.image.detection.model.PointCloudObjectDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/pointcloud_detection.py
    :language: python
    :lines: 14-

.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png
    :width: 100%
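Each file in the ``labels`` folder describes one object per line. Assuming the standard KITTI object-label format (type, truncation, occlusion, observation angle, 2D bbox, 3D dimensions, location, and yaw), a single line can be decoded as below. This parser is illustrative only; Open3D-ML ships its own dataset readers.

```python
def parse_kitti_label_line(line):
    """Parse one line of a KITTI label file into a dict (standard 15-field format)."""
    fields = line.split()
    return {
        "type": fields[0],                               # e.g. "Car", "Pedestrian"
        "truncated": float(fields[1]),                   # 0.0 (visible) .. 1.0 (truncated)
        "occluded": int(fields[2]),                      # 0..3 occlusion level
        "alpha": float(fields[3]),                       # observation angle (rad)
        "bbox_2d": [float(v) for v in fields[4:8]],      # left, top, right, bottom (px)
        "dimensions": [float(v) for v in fields[8:11]],  # height, width, length (m)
        "location": [float(v) for v in fields[11:14]],   # x, y, z in camera coords (m)
        "rotation_y": float(fields[14]),                 # yaw around camera Y axis (rad)
    }


line = "Car 0.00 0 -1.58 587.01 173.33 614.12 200.12 1.65 1.67 3.64 -0.65 1.71 46.70 -1.59"
obj = parse_kitti_label_line(line)
```

The 3D box is fully determined by ``dimensions``, ``location`` and ``rotation_y``, which is what a point cloud detector predicts.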
@@ -0,0 +1 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
@@ -0,0 +1 @@
from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource


class AudioClassificationPreprocess(Preprocess):

    @requires_extras(["audio", "image"])
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]],
        val_transform: Optional[Dict[str, Callable]],
        test_transform: Optional[Dict[str, Callable]],
        predict_transform: Optional[Dict[str, Callable]],
        spectrogram_size: Tuple[int, int] = (196, 196),
        time_mask_param: int = 80,
        freq_mask_param: int = 80,
        deserializer: Optional['Deserializer'] = None,
    ):
        self.spectrogram_size = spectrogram_size
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FILES: ImagePathsDataSource(),
                DefaultDataSources.FOLDERS: ImagePathsDataSource(),
            },
            deserializer=deserializer or ImageDeserializer(),
            default_data_source=DefaultDataSources.FILES,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {
            **self.transforms,
            "spectrogram_size": self.spectrogram_size,
            "time_mask_param": self.time_mask_param,
            "freq_mask_param": self.freq_mask_param,
        }

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    def default_transforms(self) -> Optional[Dict[str, Callable]]:
        return default_transforms(self.spectrogram_size)

    def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
        return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
    """Data module for audio classification."""

    preprocess_cls = AudioClassificationPreprocess

    def set_block_viz_window(self, value: bool) -> None:
        """Setter method to switch on/off matplotlib to pop up windows."""
        self.data_fetcher.block_viz_window = value

    @staticmethod
    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
        return MatplotlibVisualization(*args, **kwargs)
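The ``time_mask_param`` and ``freq_mask_param`` arguments above configure SpecAugment-style masking of the training spectrograms: random bands of time steps or frequency bins are zeroed out as augmentation. A minimal pure-Python sketch of the idea (illustrative only; the real transforms are built in ``flash.audio.classification.transforms``):

```python
import random


def freq_mask(spec, max_width):
    """Zero out a random band of rows (frequency bins), SpecAugment-style."""
    n_freq = len(spec)
    width = random.randint(0, min(max_width, n_freq))
    start = random.randint(0, n_freq - width)
    return [
        [0.0] * len(row) if start <= i < start + width else list(row)
        for i, row in enumerate(spec)
    ]


def time_mask(spec, max_width):
    """Zero out a random band of columns (time steps)."""
    n_time = len(spec[0])
    width = random.randint(0, min(max_width, n_time))
    start = random.randint(0, n_time - width)
    return [
        [0.0 if start <= j < start + width else v for j, v in enumerate(row)]
        for row in spec
    ]


spec = [[1.0] * 8 for _ in range(4)]  # toy 4x8 spectrogram, all ones
masked = time_mask(freq_mask(spec, max_width=2), max_width=3)
```

Both helpers return new lists, leaving the input spectrogram untouched, which mirrors how augmentation transforms are typically applied per batch.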