This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Merge branch 'master' into feature/icevision
ethanwharris committed Jul 19, 2021
2 parents 5802dcf + ea4604f commit 37044c2
Showing 43 changed files with 1,930 additions and 21 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/ci-testing.yml
@@ -61,6 +61,10 @@ jobs:
python-version: 3.8
requires: 'latest'
topic: ['graph']
- os: ubuntu-20.04
python-version: 3.8
requires: 'latest'
topic: ['audio']

# Timeout: https://stackoverflow.com/a/59076067/4521646
timeout-minutes: 35
@@ -128,6 +132,13 @@ jobs:
run: |
pip install '.[all]' --pre --upgrade
- name: Install audio test dependencies
if: matrix.topic[0] == 'audio'
run: |
sudo apt-get install libsndfile1
pip install matplotlib
pip install '.[image]' --pre --upgrade
- name: Cache datasets
uses: actions/cache@v2
with:
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -20,12 +20,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `PointCloudSegmentation` Task ([#566](https://github.com/PyTorchLightning/lightning-flash/pull/566))

- Added `PointCloudObjectDetection` Task ([#600](https://github.com/PyTorchLightning/lightning-flash/pull/600))

- Added a `GraphClassifier` task ([#73](https://github.com/PyTorchLightning/lightning-flash/pull/73))

- Added the option to pass `pretrained` as a string to `SemanticSegmentation` to change pretrained weights to load from `segmentation-models.pytorch` ([#587](https://github.com/PyTorchLightning/lightning-flash/pull/587))

- Added support for `field` parameter for loading JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585))

- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594))

### Changed

- Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560))
2 changes: 1 addition & 1 deletion docs/source/_templates/layout.html
@@ -4,7 +4,7 @@
{% block footer %}
{{ super() }}
<script script type="text/javascript">
var collapsedSections = ['Guides', 'Image and Video', 'Tabular', 'Text', 'Point Cloud', 'Graph', 'Integrations', 'API Reference', 'Contributing a Task'];
var collapsedSections = ['Guides', 'Image and Video', 'Audio', 'Tabular', 'Text', 'Point Cloud', 'Graph', 'Integrations', 'API Reference', 'Contributing a Task'];
</script>

{% endblock %}
21 changes: 21 additions & 0 deletions docs/source/api/audio.rst
@@ -0,0 +1,21 @@
###########
flash.audio
###########

.. contents::
    :depth: 1
    :local:
    :backlinks: top

.. currentmodule:: flash.audio

Classification
______________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~classification.data.AudioClassificationData
    ~classification.data.AudioClassificationPreprocess
16 changes: 16 additions & 0 deletions docs/source/api/pointcloud.rst
@@ -23,3 +23,19 @@ ____________
    segmentation.data.PointCloudSegmentationPreprocess
    segmentation.data.PointCloudSegmentationFoldersDataSource
    segmentation.data.PointCloudSegmentationDatasetDataSource


Object Detection
________________

.. autosummary::
    :toctree: generated/
    :nosignatures:
    :template: classtemplate.rst

    ~detection.model.PointCloudObjectDetector
    ~detection.data.PointCloudObjectDetectorData

    detection.data.PointCloudObjectDetectorPreprocess
    detection.data.PointCloudObjectDetectorFoldersDataSource
    detection.data.PointCloudObjectDetectorDatasetDataSource
2 changes: 1 addition & 1 deletion docs/source/general/registry.rst
@@ -100,7 +100,7 @@ Example::

from flash.image.backbones import IMAGE_CLASSIFIER_BACKBONES, OBJ_DETECTION_BACKBONES

print(IMAGE_CLASSIFIER_BACKBONES.available_models())
print(IMAGE_CLASSIFIER_BACKBONES.available_keys())
""" out:
['adv_inception_v3', 'cspdarknet53', 'cspdarknet53_iabn', 430+.., 'xception71']
"""
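
For readers unfamiliar with the registry pattern this hunk touches, the following is a minimal sketch of a name-to-function registry that exposes an `available_keys()` method. `ToyRegistry` and the registered names are hypothetical and are not Flash's registry implementation; only the method name `available_keys` is taken from the diff above.

```python
from typing import Callable, Dict, List


class ToyRegistry:
    """Hypothetical stand-in for a backbone registry (not Flash's implementation)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def register(self, key: str) -> Callable[[Callable], Callable]:
        """Return a decorator that stores the decorated function under `key`."""

        def decorator(fn: Callable) -> Callable:
            self._functions[key] = fn
            return fn

        return decorator

    def get(self, key: str) -> Callable:
        """Look up a registered function by its key."""
        return self._functions[key]

    def available_keys(self) -> List[str]:
        """List the registered keys in sorted order."""
        return sorted(self._functions)


BACKBONES = ToyRegistry("backbones")


@BACKBONES.register("tiny")
def tiny_backbone() -> str:
    return "tiny-backbone"


@BACKBONES.register("large")
def large_backbone() -> str:
    return "large-backbone"


print(BACKBONES.available_keys())
```

Returning keys rather than the registered objects keeps the listing cheap: nothing is constructed until a caller asks for a specific key with `get`.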
8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -40,6 +40,12 @@ Lightning Flash
reference/style_transfer
reference/video_classification

.. toctree::
:maxdepth: 1
:caption: Audio

reference/audio_classification

.. toctree::
:maxdepth: 1
:caption: Tabular
@@ -60,6 +66,7 @@ Lightning Flash
:caption: Point Cloud

reference/pointcloud_segmentation
reference/pointcloud_object_detection

.. toctree::
:maxdepth: 1
@@ -82,6 +89,7 @@ Lightning Flash
api/data
api/serve
api/image
api/audio
api/pointcloud
api/tabular
api/text
73 changes: 73 additions & 0 deletions docs/source/reference/audio_classification.rst
@@ -0,0 +1,73 @@

.. _audio_classification:

####################
Audio Classification
####################

********
The Task
********

Audio classification is the task of identifying what an audio file contains, typically particular sounds or spoken words.
The task predicts which ‘class’ the audio most likely belongs to, with a degree of certainty.
A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’ or ‘siren’.

------

*******
Example
*******

Let's look at the task of predicting whether an audio file contains the sound of an air_conditioner, car_horn, children_playing, dog_bark, drilling, engine_idling, gun_shot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset.
The dataset contains ``train``, ``val`` and ``test`` folders. Each of these contains one folder per class: an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and so on for the other classes.

.. code-block::

    urban8k_images
    ├── train
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    ├── test
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    └── val
        ├── air_conditioner
        ├── car_horn
        ├── children_playing
        ├── dog_bark
        ├── drilling
        ├── engine_idling
        ├── gun_shot
        ├── jackhammer
        ├── siren
        └── street_music
    ...

Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data.
We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/audio_classification.py
    :language: python
    :lines: 14-
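
The folder layout described above implies a simple rule: each sub-folder name is a class, and every file inside it is a sample of that class. The sketch below illustrates that rule using only the standard library. It is an assumption about the general convention, not Flash's data source code; `index_class_folders` and the `0.jpg` file names are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Tuple


def index_class_folders(split_dir: Path) -> Tuple[List[Tuple[str, int]], Dict[str, int]]:
    """Map each file under `split_dir/<class_name>/` to an integer label."""
    # Sort the class names so the label assignment is deterministic.
    class_names = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    samples = [
        (str(path), class_to_idx[name])
        for name in class_names
        for path in sorted((split_dir / name).iterdir())
        if path.is_file()
    ]
    return samples, class_to_idx


# Build a throwaway `train` split with two of the ten classes, then index it.
with tempfile.TemporaryDirectory() as tmp:
    train_dir = Path(tmp) / "urban8k_images" / "train"
    for class_name in ("siren", "dog_bark"):
        (train_dir / class_name).mkdir(parents=True)
        (train_dir / class_name / "0.jpg").touch()  # made-up file name
    samples, class_to_idx = index_class_folders(train_dir)

print(class_to_idx)
```

Sorting the folder names before assigning indices means the same class always receives the same label, regardless of the order in which the file system lists the directories.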
82 changes: 82 additions & 0 deletions docs/source/reference/pointcloud_object_detection.rst
@@ -0,0 +1,82 @@

.. _pointcloud_object_detection:

############################
Point Cloud Object Detection
############################

********
The Task
********

A Point Cloud is a set of data points in space, usually described by ``x``, ``y`` and ``z`` coordinates.

Point Cloud Object Detection is the task of identifying 3D objects in point clouds, together with their associated classes and 3D bounding boxes.

The current integration builds on top of `Open3D-ML <https://github.com/intel-isl/Open3D-ML>`_.

------

*******
Example
*******

Let's look at an example using a data set generated from the `KITTI Vision Benchmark <http://www.semantic-kitti.org/dataset.html>`_.
The data is a tiny subset of the original dataset and contains sequences of point clouds.

The data contains:

* one folder for scans
* one folder for scan calibrations
* one folder for labels
* a meta.yaml file describing the classes and their official associated color map.

Here's the structure:

.. code-block::

    data
    ├── meta.yaml
    ├── train
    │   ├── scans
    │   │   ├── 00000.bin
    │   │   ├── 00001.bin
    │   │   ...
    │   ├── calibs
    │   │   ├── 00000.txt
    │   │   ├── 00001.txt
    │   │   ...
    │   ├── labels
    │   │   ├── 00000.txt
    │   │   ├── 00001.txt
    │   ...
    ├── val
    │   ...
    ├── predict
        ├── scans
        │   ├── 00000.bin
        │   ├── 00001.bin
        │
        ├── calibs
        │   ├── 00000.txt
        │   ├── 00001.txt
        ├── meta.yaml

Learn more: http://www.semantic-kitti.org/dataset.html


Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.pointcloud.detection.data.PointCloudObjectDetectorData`.
We select a pre-trained ``randlanet_semantic_kitti`` backbone for our :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` task.
We then use the trained :class:`~flash.pointcloud.detection.model.PointCloudObjectDetector` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/pointcloud_detection.py
    :language: python
    :lines: 14-



.. image:: https://raw.githubusercontent.com/intel-isl/Open3D-ML/master/docs/images/visualizer_BoundingBoxes.png
    :width: 100%
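
To make the idea of a point cloud and a 3D bounding box concrete, here is a small standard-library sketch that reads a raw ``.bin`` scan and computes the axis-aligned box enclosing its points. It assumes the scans follow the KITTI Velodyne convention of four little-endian float32 values per point (``x``, ``y``, ``z``, reflectance); the page above does not state the file layout, so treat that as an assumption. `read_scan` and `axis_aligned_box` are hypothetical helpers, and enclosing every point in one box only illustrates the geometry, not what a detector predicts.

```python
import struct
import tempfile
from pathlib import Path
from typing import List, Tuple

# Assumed record layout: four little-endian float32 values per point.
RECORD = struct.Struct("<4f")
Point = Tuple[float, float, float, float]  # x, y, z, reflectance
Corner = Tuple[float, float, float]


def read_scan(path: Path) -> List[Point]:
    """Read every point from a raw binary scan file."""
    raw = path.read_bytes()
    if len(raw) % RECORD.size != 0:
        raise ValueError(f"{path} is not a whole number of {RECORD.size}-byte records")
    return list(RECORD.iter_unpack(raw))


def axis_aligned_box(points: List[Point]) -> Tuple[Corner, Corner]:
    """Return the (min corner, max corner) of the box enclosing all points."""
    xs, ys, zs, _ = zip(*points)
    return (min(xs), min(ys), min(zs)), (max(xs), max(ys), max(zs))


# Write a two-point demo scan, then read it back.
demo_points = [(1.0, 2.0, 3.0, 0.5), (-1.0, 0.0, 4.0, 0.25)]
with tempfile.TemporaryDirectory() as tmp:
    scan_path = Path(tmp) / "00000.bin"
    scan_path.write_bytes(b"".join(RECORD.pack(*p) for p in demo_points))
    points = read_scan(scan_path)

box_min, box_max = axis_aligned_box(points)
print(len(points), box_min, box_max)
```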
1 change: 1 addition & 0 deletions flash/audio/__init__.py
@@ -0,0 +1 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
1 change: 1 addition & 0 deletions flash/audio/classification/__init__.py
@@ -0,0 +1 @@
from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401
87 changes: 87 additions & 0 deletions flash/audio/classification/data.py
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource


class AudioClassificationPreprocess(Preprocess):

    @requires_extras(["audio", "image"])
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]],
        val_transform: Optional[Dict[str, Callable]],
        test_transform: Optional[Dict[str, Callable]],
        predict_transform: Optional[Dict[str, Callable]],
        spectrogram_size: Tuple[int, int] = (196, 196),
        time_mask_param: int = 80,
        freq_mask_param: int = 80,
        deserializer: Optional['Deserializer'] = None,
    ):
        self.spectrogram_size = spectrogram_size
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FILES: ImagePathsDataSource(),
                DefaultDataSources.FOLDERS: ImagePathsDataSource()
            },
            deserializer=deserializer or ImageDeserializer(),
            default_data_source=DefaultDataSources.FILES,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {
            **self.transforms,
            "spectrogram_size": self.spectrogram_size,
            "time_mask_param": self.time_mask_param,
            "freq_mask_param": self.freq_mask_param,
        }

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    def default_transforms(self) -> Optional[Dict[str, Callable]]:
        return default_transforms(self.spectrogram_size)

    def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
        return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
    """Data module for audio classification."""

    preprocess_cls = AudioClassificationPreprocess

    def set_block_viz_window(self, value: bool) -> None:
        """Setter method to switch on/off matplotlib to pop up windows."""
        self.data_fetcher.block_viz_window = value

    @staticmethod
    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
        return MatplotlibVisualization(*args, **kwargs)
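
The `get_state_dict` / `load_state_dict` pair in `AudioClassificationPreprocess` relies on a simple contract: every key in the state dict must be a constructor argument, so that `cls(**state_dict)` rebuilds an equivalent object. The sketch below shows that round trip with a standalone class. `ToyPreprocess` is hypothetical and omits the transforms and the Flash base class; only the three parameter names and their defaults are taken from the file above.

```python
from typing import Any, Dict, Tuple


class ToyPreprocess:
    """Hypothetical, dependency-free stand-in for the preprocess above."""

    def __init__(
        self,
        spectrogram_size: Tuple[int, int] = (196, 196),
        time_mask_param: int = 80,
        freq_mask_param: int = 80,
    ) -> None:
        self.spectrogram_size = spectrogram_size
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param

    def get_state_dict(self) -> Dict[str, Any]:
        # Keys mirror the constructor arguments one-to-one.
        return {
            "spectrogram_size": self.spectrogram_size,
            "time_mask_param": self.time_mask_param,
            "freq_mask_param": self.freq_mask_param,
        }

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any]) -> "ToyPreprocess":
        # Rebuild the object by passing the saved values back to __init__.
        return cls(**state_dict)


original = ToyPreprocess(spectrogram_size=(128, 128), time_mask_param=40)
restored = ToyPreprocess.load_state_dict(original.get_state_dict())
print(restored.get_state_dict())
```

A consequence of this design is that adding a key to the state dict without a matching constructor argument makes `load_state_dict` raise a `TypeError`, so the two must be kept in sync.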