Feature/task a thon audio classification spectrograms (#594)

* added audio spectrogram classification data, transforms and tests based on image classification
* added audio spectrogram classification data, transforms and tests based on image classification
* added audio spectrogram classification example and notebook
* fixed formatting issues about newlines and longlines
* updated docs to include audio classification task
* removed empty `model` package
* [pre-commit.ci] auto fixes from pre-commit.com hooks, for more information see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks, for more information see https://pre-commit.ci
* Updates
* Update CHANGELOG.md
* Updates
* Updates
* Try fix
* Updates
* Updates
* Updates

Co-authored-by: Ethan Harris <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <[email protected]>
1 parent 5b853c2, commit 6214983
Showing 18 changed files with 650 additions and 9 deletions.
@@ -0,0 +1,73 @@

.. _audio_classification:

####################
Audio Classification
####################

********
The Task
********

The task of identifying what is in an audio file is called audio classification.
Typically, audio classification is used to identify audio files containing sounds or spoken words.
The task predicts which 'class' a sound or word most likely belongs to, with a degree of certainty.
A class is a label that describes the sounds in an audio file, such as 'children_playing', 'jackhammer', or 'siren'.

------

*******
Example
*******

Let's look at the task of predicting whether an audio file contains the sound of an air conditioner, car horn, children playing, dog bark, drilling, engine idling, gun shot, jackhammer, siren, or street music, using the UrbanSound8k spectrogram images dataset.
The dataset contains ``train``, ``val`` and ``test`` folders. Each of these folders contains an **air_conditioner** folder with spectrograms generated from air-conditioner sounds, a **siren** folder with spectrograms generated from siren sounds, and so on for the other classes.

.. code-block::

    urban8k_images
    ├── train
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    ├── test
    │   ├── air_conditioner
    │   ├── car_horn
    │   ├── children_playing
    │   ├── dog_bark
    │   ├── drilling
    │   ├── engine_idling
    │   ├── gun_shot
    │   ├── jackhammer
    │   ├── siren
    │   └── street_music
    └── val
        ├── air_conditioner
        ├── car_horn
        ├── children_playing
        ├── dog_bark
        ├── drilling
        ├── engine_idling
        ├── gun_shot
        ├── jackhammer
        ├── siren
        └── street_music
    ...

Once we've downloaded the data using :func:`~flash.core.data.utils.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`.
We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune it on the UrbanSound8k spectrogram images data.
We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference.
Finally, we save the model.
Here's the full example:

.. literalinclude:: ../../../flash_examples/audio_classification.py
    :language: python
    :lines: 14-
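
The docs and example above assume the UrbanSound8k recordings have already been rendered as spectrogram images. The snippet below is a minimal, illustrative sketch (not part of this commit) of how raw ``.wav`` files could be converted into spectrogram images matching the folder layout above, using ``torchaudio`` and ``matplotlib``; the file paths and mel parameters are assumptions.

.. code-block:: python

    # Illustrative sketch: render a .wav file into a spectrogram image that matches
    # the folder layout above. Paths and mel parameters are assumptions.
    import matplotlib.pyplot as plt
    import torchaudio

    waveform, sample_rate = torchaudio.load("path/to/dog_bark_example.wav")  # hypothetical file
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_mels=128)(waveform)
    mel_db = torchaudio.transforms.AmplitudeToDB()(mel)

    # Save the first channel as an image under the matching class folder.
    plt.imsave("urban8k_images/train/dog_bark/dog_bark_example.wav.jpg", mel_db[0].numpy())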
@@ -0,0 +1 @@
from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
@@ -0,0 +1 @@
from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess  # noqa: F401
@@ -0,0 +1,87 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, Optional, Tuple

from flash.audio.classification.transforms import default_transforms, train_default_transforms
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DefaultDataSources
from flash.core.data.process import Deserializer, Preprocess
from flash.core.utilities.imports import requires_extras
from flash.image.classification.data import MatplotlibVisualization
from flash.image.data import ImageDeserializer, ImagePathsDataSource


class AudioClassificationPreprocess(Preprocess):
    """Preprocess for spectrogram-based audio classification: spectrogram images are loaded with the
    image path data sources and prepared with the audio classification transforms."""

    @requires_extras(["audio", "image"])
    def __init__(
        self,
        train_transform: Optional[Dict[str, Callable]],
        val_transform: Optional[Dict[str, Callable]],
        test_transform: Optional[Dict[str, Callable]],
        predict_transform: Optional[Dict[str, Callable]],
        spectrogram_size: Tuple[int, int] = (196, 196),
        time_mask_param: int = 80,
        freq_mask_param: int = 80,
        deserializer: Optional['Deserializer'] = None,
    ):
        self.spectrogram_size = spectrogram_size
        self.time_mask_param = time_mask_param
        self.freq_mask_param = freq_mask_param

        super().__init__(
            train_transform=train_transform,
            val_transform=val_transform,
            test_transform=test_transform,
            predict_transform=predict_transform,
            data_sources={
                DefaultDataSources.FILES: ImagePathsDataSource(),
                DefaultDataSources.FOLDERS: ImagePathsDataSource()
            },
            deserializer=deserializer or ImageDeserializer(),
            default_data_source=DefaultDataSources.FILES,
        )

    def get_state_dict(self) -> Dict[str, Any]:
        return {
            **self.transforms,
            "spectrogram_size": self.spectrogram_size,
            "time_mask_param": self.time_mask_param,
            "freq_mask_param": self.freq_mask_param,
        }

    @classmethod
    def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False):
        return cls(**state_dict)

    def default_transforms(self) -> Optional[Dict[str, Callable]]:
        return default_transforms(self.spectrogram_size)

    def train_default_transforms(self) -> Optional[Dict[str, Callable]]:
        return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param)


class AudioClassificationData(DataModule):
    """Data module for audio classification."""

    preprocess_cls = AudioClassificationPreprocess

    def set_block_viz_window(self, value: bool) -> None:
        """Setter method to switch on/off matplotlib to pop up windows."""
        self.data_fetcher.block_viz_window = value

    @staticmethod
    def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher:
        return MatplotlibVisualization(*args, **kwargs)
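
For orientation (not part of the diff), here is a small usage sketch of the data module defined above. The preprocess keyword arguments are expected to be forwarded through ``from_folders``; the folder paths below match the example later in this commit, while the parameter values are illustrative.

# Usage sketch (assumptions: paths as in the example below, illustrative parameter values).
from flash.audio import AudioClassificationData

datamodule = AudioClassificationData.from_folders(
    train_folder="data/urban8k_images/train",
    val_folder="data/urban8k_images/val",
    spectrogram_size=(128, 128),  # resize spectrogram images to 128x128
    time_mask_param=40,           # SpecAugment-style time masking width
    freq_mask_param=40,           # SpecAugment-style frequency masking width
)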
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, Tuple

import torch
from torch import nn

from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms
from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE

if _TORCHVISION_AVAILABLE:
    import torchvision
    from torchvision import transforms as T

if _TORCHAUDIO_AVAILABLE:
    from torchaudio import transforms as TAudio


def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]:
    """The default transforms for audio classification for spectrograms: resize the spectrogram,
    convert the spectrogram and target to a tensor, and collate the batch."""
    return {
        "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)),
        "to_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()),
            ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor),
        ),
        "collate": kornia_collate,
    }


def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int,
                             freq_mask_param: int) -> Dict[str, Callable]:
    """During training we apply the default transforms with additional ``TimeMasking`` and ``FrequencyMasking``."""
    transforms = {
        "post_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)),
            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)),
        )
    }

    return merge_transforms(default_transforms(spectrogram_size), transforms)
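
As a sketch of how these helpers compose (not part of this commit), a custom training transform can be built by merging the defaults above with a different set of masking transforms and then passed as ``train_transform`` when constructing the data module; the masking values here are illustrative assumptions.

# Sketch: merge the default transforms with stronger, illustrative masking values.
from torch import nn
from torchaudio import transforms as TAudio

from flash.audio.classification.transforms import default_transforms
from flash.core.data.data_source import DefaultDataKeys
from flash.core.data.transforms import ApplyToKeys, merge_transforms

custom_train_transform = merge_transforms(
    default_transforms((196, 196)),
    {
        "post_tensor_transform": nn.Sequential(
            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=120)),
            ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=120)),
        )
    },
)
# ``custom_train_transform`` can then be passed as ``train_transform`` to
# ``AudioClassificationData.from_folders``.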
@@ -0,0 +1,45 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flash
from flash.audio import AudioClassificationData
from flash.core.data.utils import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.image import ImageClassifier

# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data")

datamodule = AudioClassificationData.from_folders(
    train_folder="data/urban8k_images/train",
    val_folder="data/urban8k_images/val",
    spectrogram_size=(64, 64),
)

# 2. Build the model
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))

# 4. Predict what's in a few spectrogram images: air_conditioner, children_playing, siren, etc.
predictions = model.predict([
    "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
    "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg",
    "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg",
])
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("audio_classification_model.pt")
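
A possible follow-up (not shown in the example): reload the saved checkpoint later for inference. ``load_from_checkpoint`` is inherited from PyTorch Lightning; the test spectrogram path below is an illustrative assumption.

# Sketch: reload the checkpoint saved above and run inference on a spectrogram image.
from flash.image import ImageClassifier

loaded_model = ImageClassifier.load_from_checkpoint("audio_classification_model.pt")
predictions = loaded_model.predict(["data/urban8k_images/test/siren/example.wav.jpg"])  # illustrative path
print(predictions)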
@@ -1 +1,2 @@
asteroid>=0.5.1
torchaudio
Empty file.
Empty file.