From 8867a7040eba3c56f829ffc5f6d34e53ff90eec4 Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Thu, 15 Jul 2021 10:19:41 +0300 Subject: [PATCH 01/16] added audio spectrogram classification data, transforms and tests based on image classification --- flash/core/utilities/imports.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index 8632802001..dd76a4998f 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -84,6 +84,7 @@ def _compare_version(package: str, op, version) -> bool: _CYTOOLZ_AVAILABLE = _module_available("cytoolz") _UVICORN_AVAILABLE = _module_available("uvicorn") _PIL_AVAILABLE = _module_available("PIL") +_AUDIO_AVAILABLE = _module_available("torchaudio") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") From 6daf545f0181f78f0b16e7a2142f96bc19a6a3dd Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Thu, 15 Jul 2021 10:22:34 +0300 Subject: [PATCH 02/16] added audio spectrogram classification data, transforms and tests based on image classification --- flash/audio/__init__.py | 1 + flash/audio/classification/__init__.py | 1 + flash/audio/classification/data.py | 80 ++++++ flash/audio/classification/model.py | 0 flash/audio/classification/transforms.py | 72 +++++ tests/audio/__init__.py | 1 + tests/audio/classification/__init__.py | 0 tests/audio/classification/test_data.py | 339 +++++++++++++++++++++++ 8 files changed, 494 insertions(+) create mode 100644 flash/audio/__init__.py create mode 100644 flash/audio/classification/__init__.py create mode 100644 flash/audio/classification/data.py create mode 100644 flash/audio/classification/model.py create mode 100644 flash/audio/classification/transforms.py create mode 100644 tests/audio/__init__.py create mode 100644 tests/audio/classification/__init__.py create mode 100644 tests/audio/classification/test_data.py diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py new file mode 100644 index 0000000000..6b7024f040 --- /dev/null +++ b/flash/audio/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py new file mode 100644 index 0000000000..0988071462 --- /dev/null +++ b/flash/audio/classification/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py new file mode 100644 index 0000000000..e3a0fa40c1 --- /dev/null +++ b/flash/audio/classification/data.py @@ -0,0 +1,80 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Dict, Optional, Tuple + +from flash.audio.classification.transforms import default_transforms, train_default_transforms +from flash.core.data.callback import BaseDataFetcher +from flash.core.data.data_module import DataModule +from flash.core.data.data_source import DefaultDataSources +from flash.core.data.process import Deserializer, Preprocess +from flash.image.classification.data import MatplotlibVisualization +from flash.image.data import ImageDeserializer, ImagePathsDataSource + + +class AudioClassificationPreprocess(Preprocess): + + def __init__( + self, + train_transform: Optional[Dict[str, Callable]], + val_transform: Optional[Dict[str, Callable]], + test_transform: Optional[Dict[str, Callable]], + predict_transform: Optional[Dict[str, Callable]], + spectrogram_size: Tuple[int, int] = (196, 196), + time_mask_param: int = 80, + freq_mask_param: int = 80, + deserializer: Optional['Deserializer'] = None, + ): + self.spectrogram_size = spectrogram_size + self.time_mask_param = time_mask_param + self.freq_mask_param = freq_mask_param + + super().__init__( + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_sources={ + DefaultDataSources.FILES: ImagePathsDataSource(), + DefaultDataSources.FOLDERS: ImagePathsDataSource() + }, + deserializer=deserializer or ImageDeserializer(), + default_data_source=DefaultDataSources.FILES, + ) + + def get_state_dict(self) -> Dict[str, Any]: + return {**self.transforms, "image_size": self.image_size} + + @classmethod + def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): + return cls(**state_dict) + + def default_transforms(self) -> Optional[Dict[str, Callable]]: + return default_transforms(self.spectrogram_size) + + def train_default_transforms(self) -> Optional[Dict[str, Callable]]: + return train_default_transforms(self.spectrogram_size, self.time_mask_param, self.freq_mask_param) + + +class AudioClassificationData(DataModule): + """Data module for audio classification.""" + + preprocess_cls = AudioClassificationPreprocess + + def set_block_viz_window(self, value: bool) -> None: + """Setter method to switch on/off matplotlib to pop up windows.""" + self.data_fetcher.block_viz_window = value + + @staticmethod + def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: + return MatplotlibVisualization(*args, **kwargs) diff --git a/flash/audio/classification/model.py b/flash/audio/classification/model.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py new file mode 100644 index 0000000000..e6439c841a --- /dev/null +++ b/flash/audio/classification/transforms.py @@ -0,0 +1,72 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, Dict, Tuple + +import torch +from torch import nn + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms +from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, _AUDIO_AVAILABLE + +if _KORNIA_AVAILABLE: + import kornia as K + +if _TORCHVISION_AVAILABLE: + import torchvision + from torchvision import transforms as T + +if _AUDIO_AVAILABLE: + from torchaudio import transforms as TAudio + + +def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: + """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, + and collate the batch.""" + if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": + # Better approach as all transforms are applied on tensor directly + return { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "post_tensor_transform": ApplyToKeys( + DefaultDataKeys.INPUT, + K.geometry.Resize(spectrogram_size), + ), + "collate": kornia_collate, + } + return { + "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)), + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), + ), + "collate": kornia_collate, + } + + +def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int, + freq_mask_param: int) -> Dict[str, Callable]: + """During training we apply the default transforms with aditional ``TimeMasking`` and ``Frequency Masking``""" + if os.getenv("FLASH_TESTING", "0") != 1: + transforms = { + "post_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) + ) + } + + return merge_transforms(default_transforms(spectrogram_size), transforms) \ No newline at end of file diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py new file mode 100644 index 0000000000..6b7024f040 --- /dev/null +++ b/tests/audio/__init__.py @@ -0,0 +1 @@ +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file diff --git a/tests/audio/classification/__init__.py b/tests/audio/classification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py new file mode 100644 index 0000000000..cb4798b706 --- /dev/null +++ b/tests/audio/classification/test_data.py @@ -0,0 +1,339 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any, List, Tuple + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from flash.core.data.data_source import DefaultDataKeys +from flash.core.data.transforms import ApplyToKeys +from flash.audio import AudioClassificationData +from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from tests.helpers.utils import _IMAGE_TESTING + +if _TORCHVISION_AVAILABLE: + import torchvision + +if _PIL_AVAILABLE: + from PIL import Image + + +def _rand_image(size: Tuple[int, int] = None): + if size is None: + _size = np.random.choice([196, 244]) + size = (_size, _size) + return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_smoke(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + _rand_image().save(tmpdir / "a_1.png") + _rand_image().save(tmpdir / "b_1.png") + + train_images = [ + str(tmpdir / "a_1.png"), + str(tmpdir / "b_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[1, 2], + batch_size=2, + num_workers=0, + ) + assert spectrograms_data.train_dataloader() is not None + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert sorted(list(labels.numpy())) == [1, 2] + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_list_image_paths(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + spectrograms_data = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # check training data + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert labels.numpy()[0] in [0, 3, 6] # data comes shuffled here + assert labels.numpy()[1] in [0, 3, 6] # data comes shuffled here + + # check validation data + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [1, 4] + + # check test data + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [2, 5] + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_visualise(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "e").mkdir() + _rand_image().save(tmpdir / "e_1.png") + + train_images = [ + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + str(tmpdir / "e_1.png"), + ] + + dm = AudioClassificationData.from_files( + train_files=train_images, + train_targets=[0, 3, 6], + val_files=train_images, + val_targets=[1, 4, 7], + test_files=train_images, + test_targets=[2, 5, 8], + batch_size=2, + num_workers=0, + ) + + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + # dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_visualise_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + (tmpdir / "b").mkdir() + + image_a = str(tmpdir / "a" / "a_1.png") + image_b = str(tmpdir / "b" / "b_1.png") + + _rand_image().save(image_a) + _rand_image().save(image_b) + + dm = AudioClassificationData.from_files( + train_files=[image_a, image_b], + train_targets=[[0, 1, 0], [0, 1, 1]], + val_files=[image_b, image_a], + val_targets=[[1, 1, 0], [0, 0, 1]], + test_files=[image_b, image_b], + test_targets=[[0, 0, 1], [1, 1, 0]], + batch_size=2, + spectrogram_size=(64, 64), + ) + # disable visualisation for testing + assert dm.data_fetcher.block_viz_window is True + dm.set_block_viz_window(False) + assert dm.data_fetcher.block_viz_window is False + + # call show functions + dm.show_train_batch() + dm.show_train_batch("pre_tensor_transform") + dm.show_train_batch("to_tensor_transform") + dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) + dm.show_val_batch("per_batch_transform") + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_splits(tmpdir): + tmpdir = Path(tmpdir) + + B, _, H, W = 2, 3, 224, 224 + img_size: Tuple[int, int] = (H, W) + + (tmpdir / "splits").mkdir() + _rand_image(img_size).save(tmpdir / "s.png") + + num_samples: int = 10 + val_split: float = .3 + + train_filepaths: List[str] = [str(tmpdir / "s.png") for _ in range(num_samples)] + + train_labels: List[int] = list(range(num_samples)) + + assert len(train_filepaths) == len(train_labels) + + _to_tensor = { + "to_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), + ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor) + ), + } + + def run(transform: Any = None): + dm = AudioClassificationData.from_files( + train_files=train_filepaths, + train_targets=train_labels, + train_transform=transform, + val_transform=transform, + batch_size=B, + num_workers=0, + val_split=val_split, + spectrogram_size=img_size, + ) + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (B, 3, H, W) + assert labels.shape == (B, ) + + run(_to_tensor) + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_folders_only_train(tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + + spectrograms_data = AudioClassificationData.from_folders(train_dir, train_transform=None, batch_size=1) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1, ) + + assert spectrograms_data.val_dataloader() is None + assert spectrograms_data.test_dataloader() is None + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_folders_train_val(tmpdir): + + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + (train_dir / "a").mkdir() + _rand_image().save(train_dir / "a" / "1.png") + _rand_image().save(train_dir / "a" / "2.png") + + (train_dir / "b").mkdir() + _rand_image().save(train_dir / "b" / "1.png") + _rand_image().save(train_dir / "b" / "2.png") + spectrograms_data = AudioClassificationData.from_folders( + train_dir, + val_folder=train_dir, + test_folder=train_dir, + batch_size=2, + num_workers=0, + ) + + data = next(iter(spectrograms_data.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + + data = next(iter(spectrograms_data.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [0, 0] + + data = next(iter(spectrograms_data.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, ) + assert list(labels.numpy()) == [0, 0] + + +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_filepaths_multilabel(tmpdir): + tmpdir = Path(tmpdir) + + (tmpdir / "a").mkdir() + _rand_image().save(tmpdir / "a1.png") + _rand_image().save(tmpdir / "a2.png") + + train_images = [str(tmpdir / "a1.png"), str(tmpdir / "a2.png")] + train_labels = [[1, 0, 1, 0], [0, 0, 1, 1]] + valid_labels = [[1, 1, 1, 0], [1, 0, 0, 1]] + test_labels = [[1, 0, 1, 0], [1, 1, 0, 1]] + + dm = AudioClassificationData.from_files( + train_files=train_images, + train_targets=train_labels, + val_files=train_images, + val_targets=valid_labels, + test_files=train_images, + test_targets=test_labels, + batch_size=2, + num_workers=0, + ) + + data = next(iter(dm.train_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + + data = next(iter(dm.val_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(valid_labels)) + + data = next(iter(dm.test_dataloader())) + imgs, labels = data['input'], data['target'] + assert imgs.shape == (2, 3, 196, 196) + assert labels.shape == (2, 4) + torch.testing.assert_allclose(labels, torch.tensor(test_labels)) From 58a3841e82d7eddb161d6525e1ed7c95e2222654 Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Fri, 16 Jul 2021 07:44:12 +0300 Subject: [PATCH 03/16] added audio spectrogram classification example and notebook --- flash_examples/audio_classification.py | 52 +++ flash_notebooks/audio_classification.ipynb | 354 +++++++++++++++++++++ 2 files changed, 406 insertions(+) create mode 100644 flash_examples/audio_classification.py create mode 100644 flash_notebooks/audio_classification.ipynb diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py new file mode 100644 index 0000000000..d0550ad2ee --- /dev/null +++ b/flash_examples/audio_classification.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import flash +from flash.core.data.utils import download_data +from flash.audio import AudioClassificationData +from flash.image import ImageClassifier +from flash.core.finetuning import FreezeUnfreeze + +# 1. Create the DataModule +download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") + +datamodule = AudioClassificationData.from_folders( + train_folder="data/urban8k_images/train", + val_folder="data/urban8k_images/val", + spectrogram_size=(64, 64), +) + +# 2. Build the model. +model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes) + +# 3. Create the trainer and finetune the model +trainer = flash.Trainer(max_epochs=3) +trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1)) + +# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c +predictions = model.predict([ + "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", + "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", + "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", + "data/urban8k_images/test/street_music/7390-9-0-6.wav.jpg", + "data/urban8k_images/test/car_horn/7389-1-0-6.wav.jpg", + "data/urban8k_images/test/dog_bark/344-3-4-0.wav.jpg", + "data/urban8k_images/test/drilling/22962-4-0-0.wav.jpg", + "data/urban8k_images/test/engine_idling/6988-5-0-2.wav.jpg", + "data/urban8k_images/test/gun_shot/7063-6-0-0.wav.jpg", + "data/urban8k_images/test/siren/22601-8-0-9.wav.jpg", +]) +print(predictions) + +# 5. Save the model! +trainer.save_checkpoint("audio_classification_model.pt") \ No newline at end of file diff --git a/flash_notebooks/audio_classification.ipynb b/flash_notebooks/audio_classification.ipynb new file mode 100644 index 0000000000..22c842e8d6 --- /dev/null +++ b/flash_notebooks/audio_classification.ipynb @@ -0,0 +1,354 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \"Open\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this notebook, we'll go over the basics of lightning Flash by finetuning/prediction with an ImageClassifier on [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images) containing mel spectrograms of urban sounds from 10 classes: *airconditioner, carhorn, childrenplaying, dogbark, drilling, engingeidling, gunshot, jackhammer, siren, and street_music*.\n", + "\n", + "# Finetuning\n", + "\n", + "Finetuning consists of four steps:\n", + " \n", + " - 1. Training a source neural network model on source dataset. In this notebook we can rely on [Torchvision](https://pytorch.org/docs/stable/torchvision/index.html) models, pretrained on the [ImageNet dataset](http://www.image-net.org) and finetune them to fit our dataset of Mel spectrograms. The specific architecture that will be used is the [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).\n", + " \n", + " - 2. Create a new neural network called the target model. Its architecture replicates the source model and parameters, expect the latest layer which is removed. This model without its latest layer is traditionally called a backbone\n", + " \n", + " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", + " \n", + " - 4. Train the target model on a target dataset, such as Urban Sound 8k Images. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`. The strategy that will be used in this notebook is the `freeze/unfreeze` strategy. Since our dataset deviates so much from the ImageNet dataset, we first train the head only for a couple of epochs, then later unfreeze the whole model, even the backbone, so we can better fit our dataset. The reason for freezing the head for a couple of epochs is to ensure that we don't propagate, random information to the backbone as training starts, due to random weight initialization of the head, and we can actually leverage features already learned by the backbone.\n", + " \n", + " \n", + "\n", + " \n", + "\n", + "---\n", + " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", + " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", + " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", + " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "! pip install git+https://github.com/PyTorchLightning/lightning-flash.git" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### The notebook runtime has to be re-started once Flash is installed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# https://github.com/streamlit/demo-self-driving/issues/17\n", + "if 'google.colab' in str(get_ipython()):\n", + " import os\n", + " os.kill(os.getpid(), 9)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import flash\n", + "from flash.core.data.utils import download_data\n", + "from flash.audio import AudioClassificationData\n", + "from flash.image import ImageClassifier\n", + "from flash.core.finetuning import FreezeUnfreeze" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Download data\n", + "The data are downloaded from a URL, and save in a 'data' directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "download_data(\"https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip\", \"./data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Load the data\n", + "\n", + "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", + "Creates a AudioClassificationData object from folders of images arranged in this way:\n", + "\n", + "\n", + " train/dog/xxx.png\n", + " train/dog/xxy.png\n", + " train/dog/xxz.png\n", + " train/cat/123.png\n", + " train/cat/nsdf3.png\n", + " train/cat/asd932.png\n", + "\n", + "\n", + "Note: Each sub-folder content will be considered as a new class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = AudioClassificationData.from_folders(\n", + " train_folder=\"data/urban8k_images/train\",\n", + " val_folder=\"data/urban8k_images/val\",\n", + " test_folder=\"data/urban8k_images/test\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Build the model\n", + "\n", + "Create the ImageClassifier task. By default, the ImageClassifier task uses a [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.\n", + "For [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images) ``datamodule.num_classes`` will be 10.\n", + "Backbone can easily be changed with `ImageClassifier(backbone=\"resnet50\")` or you could provide your own `ImageClassifier(backbone=my_backbone)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier(backbone=\"resnet18\", num_classes=datamodule.num_classes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 4. Create the trainer. Run once on data\n", + "\n", + "The trainer object can be used for training or fine-tuning tasks on new sets of data. \n", + "\n", + "You can pass in parameters to control the training routine- limit the number of epochs, run on GPUs or TPUs, etc.\n", + "\n", + "For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).\n", + "\n", + "In this demo, we will limit the fine-tuning to run just 3 epoch using max_epochs=2." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = flash.Trainer(max_epochs=3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 5. Finetune the model \n", + "\n", + "`FreezeUnfreeze` strategy unfreezes the backbone after 1 epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 6. Test the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.test()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 7. Save it!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_checkpoint(\"audio_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predicting" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load the model from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/audio_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2a. Predict what's on a few images!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict([\n", + " \"data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg\",\n", + " \"data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg\",\n", + " \"data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg\",\n", + " \"data/urban8k_images/test/street_music/7390-9-0-6.wav.jpg\",\n", + " \"data/urban8k_images/test/car_horn/7389-1-0-6.wav.jpg\",\n", + " \"data/urban8k_images/test/dog_bark/344-3-4-0.wav.jpg\",\n", + " \"data/urban8k_images/test/drilling/22962-4-0-0.wav.jpg\",\n", + " \"data/urban8k_images/test/engine_idling/6988-5-0-2.wav.jpg\",\n", + " \"data/urban8k_images/test/gun_shot/7063-6-0-0.wav.jpg\",\n", + " \"data/urban8k_images/test/siren/22601-8-0-9.wav.jpg\",\n", + "])\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2b. Or generate prediction with a whole folder!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = ImageClassificationData.from_folders(predict_folder=\"data/urban8k_images/test\")\n", + "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "

Congratulations - Time to Join the Community!

\n", + "
\n", + "\n", + "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", + "\n", + "### Help us build Flash by adding support for new data-types and new tasks.\n", + "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", + "If you are interested, please open a PR with your contributions !!! \n", + "\n", + "\n", + "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", + "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", + "\n", + "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", + "\n", + "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n", + "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", + "\n", + "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", + "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", + "\n", + "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", + "\n", + "### Contributions !\n", + "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", + "\n", + "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", + "* You can also contribute your own notebooks with useful examples !\n", + "\n", + "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", + "\n", + "" + ] + } + ], + "metadata": { + "interpreter": { + "hash": "bfd2fe4d77e9254f93414392fd32c2a0f6e7778a519d9ecdbf1751b4355012ab" + }, + "kernelspec": { + "display_name": "Python 3.9.1 64-bit ('lf_audio_spectrograms': conda)", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file From dc2ca9d9d9f204f97ff10c24c34ec3dc3318429c Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Fri, 16 Jul 2021 11:13:41 +0300 Subject: [PATCH 04/16] fixed formatting issues about newlines and longlines --- flash/audio/__init__.py | 2 +- flash/audio/classification/__init__.py | 2 +- flash/audio/classification/transforms.py | 6 +++--- flash_examples/audio_classification.py | 2 +- tests/audio/__init__.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py index 6b7024f040..120039c2eb 100644 --- a/flash/audio/__init__.py +++ b/flash/audio/__init__.py @@ -1 +1 @@ -from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py index 0988071462..a7db617f5b 100644 --- a/flash/audio/classification/__init__.py +++ b/flash/audio/classification/__init__.py @@ -1 +1 @@ -from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file +from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index e6439c841a..85bbd81a4a 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -33,8 +33,8 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, - and collate the batch.""" + """The default transforms for audio classification for spectrograms: resize the spectrogram, + convert the spectrogram and target to a tensor, and collate the batch.""" if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": # Better approach as all transforms are applied on tensor directly return { @@ -69,4 +69,4 @@ def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: ) } - return merge_transforms(default_transforms(spectrogram_size), transforms) \ No newline at end of file + return merge_transforms(default_transforms(spectrogram_size), transforms) diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py index d0550ad2ee..a5b5d91d04 100644 --- a/flash_examples/audio_classification.py +++ b/flash_examples/audio_classification.py @@ -49,4 +49,4 @@ print(predictions) # 5. Save the model! -trainer.save_checkpoint("audio_classification_model.pt") \ No newline at end of file +trainer.save_checkpoint("audio_classification_model.pt") diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py index 6b7024f040..120039c2eb 100644 --- a/tests/audio/__init__.py +++ b/tests/audio/__init__.py @@ -1 +1 @@ -from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess \ No newline at end of file +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess From 734bb3c20911f40837753c1722bf87d5087c5f1a Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Fri, 16 Jul 2021 17:42:22 +0300 Subject: [PATCH 05/16] updated docs to include audio classification task --- docs/source/index.rst | 3 +- .../source/reference/audio_classification.rst | 74 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 docs/source/reference/audio_classification.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 92fba5c46a..fe9c169155 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -27,8 +27,9 @@ Lightning Flash .. toctree:: :maxdepth: 1 - :caption: Image and Video + :caption: Audio, Image and Video + reference/audio_classification reference/image_classification reference/image_classification_multi_label reference/image_embedder diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst new file mode 100644 index 0000000000..086f6e177e --- /dev/null +++ b/docs/source/reference/audio_classification.rst @@ -0,0 +1,74 @@ + +.. _audio_classification: + +#################### +Audio Classification +#################### + +******** +The Task +******** + +The task of identifying what is in an audio file is called audio classification. +Typically, Audio Classification is used to identify audio files containing sounds or words. +The task predicts which ‘class’ the sound or words most likely belongs to with a degree of certainty. +A class is a label that describes the sounds in an audio file, such as ‘children_playing’, ‘jackhammer’, ‘siren’ etc. + +------ + +******* +Example +******* + +Let's look at the task of predicting whether audio file contains sounds of an airconditioner, carhorn, childrenplaying, dogbark, drilling, engingeidling, gunshot, jackhammer, siren, or street_music using the UrbanSound8k spectrogram images dataset. +The dataset contains ``train``, ``val`` and ``test`` folders, and then each folder contains a **airconditioner** folder, with spectrograms generated from air-conditioner sounds, **siren** folder with spectrograms generated from siren sounds and the same goes for the other classes. + +.. code-block:: + + urban8k_images + ├── test + │   ├── air_conditioner + │   ├── car_horn + │   ├── children_playing + │   ├── dog_bark + │   ├── drilling + │   ├── engine_idling + │   ├── gun_shot + │   ├── jackhammer + │   ├── siren + │   └── street_music + ├── train + │   ├── air_conditioner + │   ├── car_horn + │   ├── children_playing + │   ├── dog_bark + │   ├── drilling + │   ├── engine_idling + │   ├── gun_shot + │   ├── jackhammer + │   ├── siren + │   └── street_music + └── val + ├── air_conditioner + ├── car_horn + ├── children_playing + ├── dog_bark + ├── drilling + ├── engine_idling + ├── gun_shot + ├── jackhammer + ├── siren + └── street_music + + ... + +Once we've downloaded the data using :func:`~flash.core.data.download_data`, we create the :class:`~flash.audio.classification.data.AudioClassificationData`. +We select a pre-trained backbone to use for our :class:`~flash.image.classification.model.ImageClassifier` and fine-tune on the UrbanSound8k spectrogram images data. +We then use the trained :class:`~flash.image.classification.model.ImageClassifier` for inference. +Finally, we save the model. +Here's the full example: + +.. literalinclude:: ../../../flash_examples/audio_classification.py + :language: python + :lines: 14- + \ No newline at end of file From b65ca1165d4c9e9c1b6afe0507643b8524a83142 Mon Sep 17 00:00:00 2001 From: Kinyugo Date: Fri, 16 Jul 2021 17:50:39 +0300 Subject: [PATCH 06/16] removed empty `model` package --- flash/audio/classification/model.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 flash/audio/classification/model.py diff --git a/flash/audio/classification/model.py b/flash/audio/classification/model.py deleted file mode 100644 index e69de29bb2..0000000000 From b96ae0484bd72edbe305da481eeaa81ced71001e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jul 2021 15:25:02 +0000 Subject: [PATCH 07/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/reference/audio_classification.rst | 2 +- flash/audio/classification/transforms.py | 4 ++-- flash_examples/audio_classification.py | 4 ++-- flash_notebooks/audio_classification.ipynb | 5 ++--- tests/audio/classification/test_data.py | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index 086f6e177e..ac2f8c15f3 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -71,4 +71,4 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/audio_classification.py :language: python :lines: 14- - \ No newline at end of file + diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index 85bbd81a4a..e7f7c00d4c 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -19,7 +19,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, _AUDIO_AVAILABLE +from flash.core.utilities.imports import _AUDIO_AVAILABLE, _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE if _KORNIA_AVAILABLE: import kornia as K @@ -33,7 +33,7 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: - """The default transforms for audio classification for spectrograms: resize the spectrogram, + """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, and collate the batch.""" if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": # Better approach as all transforms are applied on tensor directly diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py index a5b5d91d04..363adb7f86 100644 --- a/flash_examples/audio_classification.py +++ b/flash_examples/audio_classification.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import flash -from flash.core.data.utils import download_data from flash.audio import AudioClassificationData -from flash.image import ImageClassifier +from flash.core.data.utils import download_data from flash.core.finetuning import FreezeUnfreeze +from flash.image import ImageClassifier # 1. Create the DataModule download_data("https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip", "./data") diff --git a/flash_notebooks/audio_classification.ipynb b/flash_notebooks/audio_classification.ipynb index 22c842e8d6..290901e8a9 100644 --- a/flash_notebooks/audio_classification.ipynb +++ b/flash_notebooks/audio_classification.ipynb @@ -346,9 +346,8 @@ "language_info": { "name": "python", "version": "" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index cb4798b706..bd36072c7a 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -20,9 +20,9 @@ import torch import torch.nn as nn +from flash.audio import AudioClassificationData from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys -from flash.audio import AudioClassificationData from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from tests.helpers.utils import _IMAGE_TESTING From 1532d296c94f2d3faafa3e04b9f8202d1eac82be Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 16 Jul 2021 17:31:22 +0000 Subject: [PATCH 08/16] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source/reference/audio_classification.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index ac2f8c15f3..0b623ef860 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -71,4 +71,3 @@ Here's the full example: .. literalinclude:: ../../../flash_examples/audio_classification.py :language: python :lines: 14- - From 4521959c8123b9a11bd27a964dcb338d5c1ef706 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:27:40 +0100 Subject: [PATCH 09/16] Updates --- .github/workflows/ci-testing.yml | 9 + docs/source/index.rst | 9 +- .../source/reference/audio_classification.rst | 42 +-- flash/audio/classification/data.py | 9 +- flash/audio/classification/transforms.py | 4 +- flash/core/utilities/imports.py | 28 +- flash_examples/audio_classification.py | 7 - flash_notebooks/audio_classification.ipynb | 353 ------------------ requirements/datatype_audio.txt | 1 + tests/audio/classification/test_data.py | 18 +- tests/examples/test_scripts.py | 5 + tests/helpers/utils.py | 3 + 12 files changed, 84 insertions(+), 404 deletions(-) delete mode 100644 flash_notebooks/audio_classification.ipynb diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d26d8ecee2..a1b13f6f22 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -61,6 +61,10 @@ jobs: python-version: 3.8 requires: 'latest' topic: ['graph'] + - os: ubuntu-20.04 + python-version: 3.8 + requires: 'latest' + topic: ['audio'] # Timeout: https://stackoverflow.com/a/59076067/4521646 timeout-minutes: 35 @@ -128,6 +132,11 @@ jobs: run: | pip install '.[all]' --pre --upgrade + - name: Install audio test dependencies + if: matrix.topic[0] == 'audio' + run: | + pip install '.[image]' --pre --upgrade + - name: Cache datasets uses: actions/cache@v2 with: diff --git a/docs/source/index.rst b/docs/source/index.rst index 23e393ac55..2ac114009c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,9 +30,8 @@ Lightning Flash .. toctree:: :maxdepth: 1 - :caption: Audio, Image and Video + :caption: Image and Video - reference/audio_classification reference/image_classification reference/image_classification_multi_label reference/image_embedder @@ -41,6 +40,12 @@ Lightning Flash reference/style_transfer reference/video_classification +.. toctree:: + :maxdepth: 1 + :caption: Audio + + reference/audio_classification + .. toctree:: :maxdepth: 1 :caption: Tabular diff --git a/docs/source/reference/audio_classification.rst b/docs/source/reference/audio_classification.rst index 0b623ef860..eb122e6995 100644 --- a/docs/source/reference/audio_classification.rst +++ b/docs/source/reference/audio_classification.rst @@ -26,28 +26,28 @@ The dataset contains ``train``, ``val`` and ``test`` folders, and then each fol .. code-block:: urban8k_images - ├── test - │   ├── air_conditioner - │   ├── car_horn - │   ├── children_playing - │   ├── dog_bark - │   ├── drilling - │   ├── engine_idling - │   ├── gun_shot - │   ├── jackhammer - │   ├── siren - │   └── street_music ├── train - │   ├── air_conditioner - │   ├── car_horn - │   ├── children_playing - │   ├── dog_bark - │   ├── drilling - │   ├── engine_idling - │   ├── gun_shot - │   ├── jackhammer - │   ├── siren - │   └── street_music + │ ├── air_conditioner + │ ├── car_horn + │ ├── children_playing + │ ├── dog_bark + │ ├── drilling + │ ├── engine_idling + │ ├── gun_shot + │ ├── jackhammer + │ ├── siren + │ └── street_music + ├── test + │ ├── air_conditioner + │ ├── car_horn + │ ├── children_playing + │ ├── dog_bark + │ ├── drilling + │ ├── engine_idling + │ ├── gun_shot + │ ├── jackhammer + │ ├── siren + │ └── street_music └── val ├── air_conditioner ├── car_horn diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index e3a0fa40c1..68678b2a1b 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -18,12 +18,14 @@ from flash.core.data.data_module import DataModule from flash.core.data.data_source import DefaultDataSources from flash.core.data.process import Deserializer, Preprocess +from flash.core.utilities.imports import requires_extras from flash.image.classification.data import MatplotlibVisualization from flash.image.data import ImageDeserializer, ImagePathsDataSource class AudioClassificationPreprocess(Preprocess): + @requires_extras(["audio", "image"]) def __init__( self, train_transform: Optional[Dict[str, Callable]], @@ -53,7 +55,12 @@ def __init__( ) def get_state_dict(self) -> Dict[str, Any]: - return {**self.transforms, "image_size": self.image_size} + return { + **self.transforms, + "spectrogram_size": self.spectrogram_size, + "time_mask_param": self.time_mask_param, + "freq_mask_param": self.freq_mask_param, + } @classmethod def load_state_dict(cls, state_dict: Dict[str, Any], strict: bool = False): diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index e7f7c00d4c..6189b07f16 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -19,7 +19,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms -from flash.core.utilities.imports import _AUDIO_AVAILABLE, _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE if _KORNIA_AVAILABLE: import kornia as K @@ -28,7 +28,7 @@ import torchvision from torchvision import transforms as T -if _AUDIO_AVAILABLE: +if _TORCHAUDIO_AVAILABLE: from torchaudio import transforms as TAudio diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index d738152ba3..80c6b6188c 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -16,6 +16,7 @@ import operator import types from importlib.util import find_spec +from typing import Callable, List, Union from pkg_resources import DistributionNotFound @@ -89,7 +90,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCH_SCATTER_AVAILABLE = _module_available("torch_scatter") _TORCH_SPARSE_AVAILABLE = _module_available("torch_sparse") _TORCH_GEOMETRIC_AVAILABLE = _module_available("torch_geometric") -_AUDIO_AVAILABLE = _module_available("torchaudio") +_TORCHAUDIO_AVAILABLE = _module_available("torchaudio") if Version: _TORCHVISION_GREATER_EQUAL_0_9 = _compare_version("torchvision", operator.ge, "0.9.0") @@ -109,7 +110,7 @@ def _compare_version(package: str, op, version) -> bool: _POINTCLOUD_AVAILABLE = _OPEN3D_AVAILABLE _AUDIO_AVAILABLE = all([ _ASTEROID_AVAILABLE, - _AUDIO_AVAILABLE, + _TORCHAUDIO_AVAILABLE, ]) _GRAPH_AVAILABLE = _TORCH_SCATTER_AVAILABLE and _TORCH_SPARSE_AVAILABLE and _TORCH_GEOMETRIC_AVAILABLE @@ -125,15 +126,22 @@ def _compare_version(package: str, op, version) -> bool: } -def _requires(module_path: str, module_available: bool): +def _requires( + module_paths: Union[str, List], + module_available: Callable[[str], bool], + formatter: Callable[[List[str]], str], +): + + if not isinstance(module_paths, list): + module_paths = [module_paths] def decorator(func): - if not module_available: + if not all(module_available(module_path) for module_path in module_paths): @functools.wraps(func) def wrapper(*args, **kwargs): raise ModuleNotFoundError( - f"Required dependencies not available. Please run: pip install '{module_path}'" + f"Required dependencies not available. Please run: pip install {formatter(module_paths)}" ) return wrapper @@ -143,12 +151,14 @@ def wrapper(*args, **kwargs): return decorator -def requires(module_path: str): - return _requires(module_path, _module_available(module_path)) +def requires(module_paths: Union[str, List]): + return _requires(module_paths, _module_available, lambda module_paths: " ".join(module_paths)) -def requires_extras(extras: str): - return _requires(f"lightning-flash[{extras}]", _EXTRAS_AVAILABLE[extras]) +def requires_extras(extras: Union[str, List]): + return _requires( + extras, lambda extras: _EXTRAS_AVAILABLE[extras], lambda extras: f"'lightning-flash[{','.join(extras)}]'" + ) def lazy_import(module_name, callback=None): diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py index 363adb7f86..b8f0f8a312 100644 --- a/flash_examples/audio_classification.py +++ b/flash_examples/audio_classification.py @@ -38,13 +38,6 @@ "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg", "data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg", "data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg", - "data/urban8k_images/test/street_music/7390-9-0-6.wav.jpg", - "data/urban8k_images/test/car_horn/7389-1-0-6.wav.jpg", - "data/urban8k_images/test/dog_bark/344-3-4-0.wav.jpg", - "data/urban8k_images/test/drilling/22962-4-0-0.wav.jpg", - "data/urban8k_images/test/engine_idling/6988-5-0-2.wav.jpg", - "data/urban8k_images/test/gun_shot/7063-6-0-0.wav.jpg", - "data/urban8k_images/test/siren/22601-8-0-9.wav.jpg", ]) print(predictions) diff --git a/flash_notebooks/audio_classification.ipynb b/flash_notebooks/audio_classification.ipynb deleted file mode 100644 index 290901e8a9..0000000000 --- a/flash_notebooks/audio_classification.ipynb +++ /dev/null @@ -1,353 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - " \"Open\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this notebook, we'll go over the basics of lightning Flash by finetuning/prediction with an ImageClassifier on [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images) containing mel spectrograms of urban sounds from 10 classes: *airconditioner, carhorn, childrenplaying, dogbark, drilling, engingeidling, gunshot, jackhammer, siren, and street_music*.\n", - "\n", - "# Finetuning\n", - "\n", - "Finetuning consists of four steps:\n", - " \n", - " - 1. Training a source neural network model on source dataset. In this notebook we can rely on [Torchvision](https://pytorch.org/docs/stable/torchvision/index.html) models, pretrained on the [ImageNet dataset](http://www.image-net.org) and finetune them to fit our dataset of Mel spectrograms. The specific architecture that will be used is the [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/).\n", - " \n", - " - 2. Create a new neural network called the target model. Its architecture replicates the source model and parameters, expect the latest layer which is removed. This model without its latest layer is traditionally called a backbone\n", - " \n", - " - 3. Add new layers after the backbone where the latest output size is the number of target dataset categories. Those new layers, traditionally called head will be randomly initialized while backbone will conserve its pre-trained weights from ImageNet.\n", - " \n", - " - 4. Train the target model on a target dataset, such as Urban Sound 8k Images. However, freezing some layers at training start such as the backbone tends to be more stable. In Flash, it can easily be done with `trainer.finetune(..., strategy=\"freeze\")`. It is also common to `freeze/unfreeze` the backbone. In `Flash`, it can be done with `trainer.finetune(..., strategy=\"freeze_unfreeze\")`. If one wants more control on the unfreeze flow, Flash supports `trainer.finetune(..., strategy=MyFinetuningStrategy())` where `MyFinetuningStrategy` is subclassing `pytorch_lightning.callbacks.BaseFinetuning`. The strategy that will be used in this notebook is the `freeze/unfreeze` strategy. Since our dataset deviates so much from the ImageNet dataset, we first train the head only for a couple of epochs, then later unfreeze the whole model, even the backbone, so we can better fit our dataset. The reason for freezing the head for a couple of epochs is to ensure that we don't propagate, random information to the backbone as training starts, due to random weight initialization of the head, and we can actually leverage features already learned by the backbone.\n", - " \n", - " \n", - "\n", - " \n", - "\n", - "---\n", - " - Give us a ⭐ [on Github](https://www.github.com/PytorchLightning/pytorch-lightning/)\n", - " - Check out [Flash documentation](https://lightning-flash.readthedocs.io/en/latest/)\n", - " - Check out [Lightning documentation](https://pytorch-lightning.readthedocs.io/en/latest/)\n", - " - Join us [on Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%capture\n", - "! pip install git+https://github.com/PyTorchLightning/lightning-flash.git" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### The notebook runtime has to be re-started once Flash is installed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# https://github.com/streamlit/demo-self-driving/issues/17\n", - "if 'google.colab' in str(get_ipython()):\n", - " import os\n", - " os.kill(os.getpid(), 9)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import flash\n", - "from flash.core.data.utils import download_data\n", - "from flash.audio import AudioClassificationData\n", - "from flash.image import ImageClassifier\n", - "from flash.core.finetuning import FreezeUnfreeze" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 1. Download data\n", - "The data are downloaded from a URL, and save in a 'data' directory." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "download_data(\"https://pl-flash-data.s3.amazonaws.com/urban8k_images.zip\", \"./data\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2. Load the data\n", - "\n", - "Flash Tasks have built-in DataModules that you can use to organize your data. Pass in a train, validation and test folders and Flash will take care of the rest.\n", - "Creates a AudioClassificationData object from folders of images arranged in this way:\n", - "\n", - "\n", - " train/dog/xxx.png\n", - " train/dog/xxy.png\n", - " train/dog/xxz.png\n", - " train/cat/123.png\n", - " train/cat/nsdf3.png\n", - " train/cat/asd932.png\n", - "\n", - "\n", - "Note: Each sub-folder content will be considered as a new class." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = AudioClassificationData.from_folders(\n", - " train_folder=\"data/urban8k_images/train\",\n", - " val_folder=\"data/urban8k_images/val\",\n", - " test_folder=\"data/urban8k_images/test\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3. Build the model\n", - "\n", - "Create the ImageClassifier task. By default, the ImageClassifier task uses a [resnet-18](https://pytorch.org/hub/pytorch_vision_resnet/) backbone to train or finetune your model.\n", - "For [Urban Sound 8k Images Dataset](https://www.kaggle.com/gokulrejith/urban-sound-8k-images) ``datamodule.num_classes`` will be 10.\n", - "Backbone can easily be changed with `ImageClassifier(backbone=\"resnet50\")` or you could provide your own `ImageClassifier(backbone=my_backbone)`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = ImageClassifier(backbone=\"resnet18\", num_classes=datamodule.num_classes)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 4. Create the trainer. Run once on data\n", - "\n", - "The trainer object can be used for training or fine-tuning tasks on new sets of data. \n", - "\n", - "You can pass in parameters to control the training routine- limit the number of epochs, run on GPUs or TPUs, etc.\n", - "\n", - "For more details, read the [Trainer Documentation](https://pytorch-lightning.readthedocs.io/en/latest/trainer.html).\n", - "\n", - "In this demo, we will limit the fine-tuning to run just 3 epoch using max_epochs=2." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer = flash.Trainer(max_epochs=3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 5. Finetune the model \n", - "\n", - "`FreezeUnfreeze` strategy unfreezes the backbone after 1 epoch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 6. Test the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 7. Save it!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trainer.save_checkpoint(\"audio_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Predicting" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1. Load the model from a checkpoint" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/audio_classification_model.pt\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2a. Predict what's on a few images!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "predictions = model.predict([\n", - " \"data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg\",\n", - " \"data/urban8k_images/test/children_playing/9223-2-0-15.wav.jpg\",\n", - " \"data/urban8k_images/test/jackhammer/22883-7-10-0.wav.jpg\",\n", - " \"data/urban8k_images/test/street_music/7390-9-0-6.wav.jpg\",\n", - " \"data/urban8k_images/test/car_horn/7389-1-0-6.wav.jpg\",\n", - " \"data/urban8k_images/test/dog_bark/344-3-4-0.wav.jpg\",\n", - " \"data/urban8k_images/test/drilling/22962-4-0-0.wav.jpg\",\n", - " \"data/urban8k_images/test/engine_idling/6988-5-0-2.wav.jpg\",\n", - " \"data/urban8k_images/test/gun_shot/7063-6-0-0.wav.jpg\",\n", - " \"data/urban8k_images/test/siren/22601-8-0-9.wav.jpg\",\n", - "])\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2b. Or generate prediction with a whole folder!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datamodule = ImageClassificationData.from_folders(predict_folder=\"data/urban8k_images/test\")\n", - "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", - "print(predictions)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "

Congratulations - Time to Join the Community!

\n", - "
\n", - "\n", - "Congratulations on completing this notebook tutorial! If you enjoyed it and would like to join the Lightning movement, you can do so in the following ways!\n", - "\n", - "### Help us build Flash by adding support for new data-types and new tasks.\n", - "Flash aims at becoming the first task hub, so anyone can get started to great amazing application using deep learning. \n", - "If you are interested, please open a PR with your contributions !!! \n", - "\n", - "\n", - "### Star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) on GitHub\n", - "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we're building.\n", - "\n", - "* Please, star [Lightning](https://github.com/PyTorchLightning/pytorch-lightning)\n", - "\n", - "### Join our [Slack](https://join.slack.com/t/pytorch-lightning/shared_invite/zt-pw5v393p-qRaDgEk24~EjiZNBpSQFgQ)!\n", - "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in `#general` channel\n", - "\n", - "### Interested by SOTA AI models ! Check out [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "Bolts has a collection of state-of-the-art models, all implemented in [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) and can be easily integrated within your own projects.\n", - "\n", - "* Please, star [Bolt](https://github.com/PyTorchLightning/lightning-bolts)\n", - "\n", - "### Contributions !\n", - "The best way to contribute to our community is to become a code contributor! At any time you can go to [Lightning](https://github.com/PyTorchLightning/pytorch-lightning) or [Bolt](https://github.com/PyTorchLightning/lightning-bolts) GitHub Issues page and filter for \"good first issue\". \n", - "\n", - "* [Lightning good first issue](https://github.com/PyTorchLightning/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* [Bolt good first issue](https://github.com/PyTorchLightning/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", - "* You can also contribute your own notebooks with useful examples !\n", - "\n", - "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", - "\n", - "" - ] - } - ], - "metadata": { - "interpreter": { - "hash": "bfd2fe4d77e9254f93414392fd32c2a0f6e7778a519d9ecdbf1751b4355012ab" - }, - "kernelspec": { - "display_name": "Python 3.9.1 64-bit ('lf_audio_spectrograms': conda)", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/requirements/datatype_audio.txt b/requirements/datatype_audio.txt index 03c90d99ec..e608a13b78 100644 --- a/requirements/datatype_audio.txt +++ b/requirements/datatype_audio.txt @@ -1 +1,2 @@ asteroid>=0.5.1 +torchaudio diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index bd36072c7a..491eb52efd 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -24,7 +24,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE -from tests.helpers.utils import _IMAGE_TESTING +from tests.helpers.utils import _AUDIO_TESTING if _TORCHVISION_AVAILABLE: import torchvision @@ -40,7 +40,7 @@ def _rand_image(size: Tuple[int, int] = None): return Image.fromarray(np.random.randint(0, 255, (*size, 3), dtype="uint8")) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_smoke(tmpdir): tmpdir = Path(tmpdir) @@ -71,7 +71,7 @@ def test_from_filepaths_smoke(tmpdir): assert sorted(list(labels.numpy())) == [1, 2] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_list_image_paths(tmpdir): tmpdir = Path(tmpdir) @@ -118,7 +118,7 @@ def test_from_filepaths_list_image_paths(tmpdir): assert list(labels.numpy()) == [2, 5] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -153,7 +153,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) @@ -189,7 +189,7 @@ def test_from_filepaths_visualise_multilabel(tmpdir): dm.show_val_batch("per_batch_transform") -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_splits(tmpdir): tmpdir = Path(tmpdir) @@ -234,7 +234,7 @@ def run(transform: Any = None): run(_to_tensor) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_folders_only_train(tmpdir): train_dir = Path(tmpdir / "train") train_dir.mkdir() @@ -258,7 +258,7 @@ def test_from_folders_only_train(tmpdir): assert spectrograms_data.test_dataloader() is None -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_folders_train_val(tmpdir): train_dir = Path(tmpdir / "train") @@ -297,7 +297,7 @@ def test_from_folders_train_val(tmpdir): assert list(labels.numpy()) == [0, 0] -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") def test_from_filepaths_multilabel(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/examples/test_scripts.py b/tests/examples/test_scripts.py index ec6c4bb834..56b729e36e 100644 --- a/tests/examples/test_scripts.py +++ b/tests/examples/test_scripts.py @@ -21,6 +21,7 @@ from flash.core.utilities.imports import _SKLEARN_AVAILABLE from tests.examples.utils import run_test from tests.helpers.utils import ( + _AUDIO_TESTING, _GRAPH_TESTING, _IMAGE_TESTING, _POINTCLOUD_TESTING, @@ -37,6 +38,10 @@ pytest.param( "custom_task.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed") ), + pytest.param( + "audio_classification.py", + marks=pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed") + ), pytest.param( "image_classification.py", marks=pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed") diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 5bb699b664..bd57cf570d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -14,6 +14,7 @@ import os from flash.core.utilities.imports import ( + _AUDIO_AVAILABLE, _GRAPH_AVAILABLE, _IMAGE_AVAILABLE, _POINTCLOUD_AVAILABLE, @@ -30,6 +31,7 @@ _SERVE_TESTING = _SERVE_AVAILABLE _POINTCLOUD_TESTING = _POINTCLOUD_AVAILABLE _GRAPH_TESTING = _GRAPH_AVAILABLE +_AUDIO_TESTING = _AUDIO_AVAILABLE if "FLASH_TEST_TOPIC" in os.environ: topic = os.environ["FLASH_TEST_TOPIC"] @@ -40,3 +42,4 @@ _SERVE_TESTING = topic == "serve" _POINTCLOUD_TESTING = topic == "pointcloud" _GRAPH_TESTING = topic == "graph" + _AUDIO_TESTING = topic == "audio" From ec1b83284b007aa25f99a5821048170188774a15 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:29:15 +0100 Subject: [PATCH 10/16] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 54851b160e..cb7c1cb3b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `field` parameter for loadng JSON based datasets in text tasks. ([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585)) +- Added `AudioClassificationData` and an example for classifying audio spectrograms ([#594](https://github.com/PyTorchLightning/lightning-flash/pull/594)) + ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) From 7ea6d77656ac31651d58808a5a9a4dc43ece25e1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:33:52 +0100 Subject: [PATCH 11/16] Updates --- flash/audio/__init__.py | 2 +- flash/audio/classification/__init__.py | 2 +- tests/audio/classification/test_data.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/flash/audio/__init__.py b/flash/audio/__init__.py index 120039c2eb..40eeaae124 100644 --- a/flash/audio/__init__.py +++ b/flash/audio/__init__.py @@ -1 +1 @@ -from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess +from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 diff --git a/flash/audio/classification/__init__.py b/flash/audio/classification/__init__.py index a7db617f5b..476a303d49 100644 --- a/flash/audio/classification/__init__.py +++ b/flash/audio/classification/__init__.py @@ -1 +1 @@ -from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess +from flash.audio.classification.data import AudioClassificationData, AudioClassificationPreprocess # noqa: F401 diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 491eb52efd..53cb4d09e5 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from pathlib import Path from typing import Any, List, Tuple From aa9c89809b9d9d41806281a29c1fa905b2d4036e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:35:09 +0100 Subject: [PATCH 12/16] Updates --- tests/audio/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/audio/__init__.py b/tests/audio/__init__.py index 120039c2eb..e69de29bb2 100644 --- a/tests/audio/__init__.py +++ b/tests/audio/__init__.py @@ -1 +0,0 @@ -from flash.audio.classification import AudioClassificationData, AudioClassificationPreprocess From eb2ffedf0c50a8abc153bde9808c528fd2529fca Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:43:48 +0100 Subject: [PATCH 13/16] Try fix --- .github/workflows/ci-testing.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index a1b13f6f22..cf5c0fe143 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -135,6 +135,7 @@ jobs: - name: Install audio test dependencies if: matrix.topic[0] == 'audio' run: | + sudo apt-get install libsndfile1 pip install '.[image]' --pre --upgrade - name: Cache datasets From a0e41c48a4f08b92d797259498c3e34c5e949f19 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:53:23 +0100 Subject: [PATCH 14/16] Updates --- .github/workflows/ci-testing.yml | 1 + tests/audio/classification/test_data.py | 4 +++- tests/image/classification/test_data.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index cf5c0fe143..21ac8fbd45 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -136,6 +136,7 @@ jobs: if: matrix.topic[0] == 'audio' run: | sudo apt-get install libsndfile1 + pip install matplotlib pip install '.[image]' --pre --upgrade - name: Cache datasets diff --git a/tests/audio/classification/test_data.py b/tests/audio/classification/test_data.py index 53cb4d09e5..a1c0ba0677 100644 --- a/tests/audio/classification/test_data.py +++ b/tests/audio/classification/test_data.py @@ -22,7 +22,7 @@ from flash.audio import AudioClassificationData from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys -from flash.core.utilities.imports import _PIL_AVAILABLE, _TORCHVISION_AVAILABLE +from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _TORCHVISION_AVAILABLE from tests.helpers.utils import _AUDIO_TESTING if _TORCHVISION_AVAILABLE: @@ -118,6 +118,7 @@ def test_from_filepaths_list_image_paths(tmpdir): @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise(tmpdir): tmpdir = Path(tmpdir) @@ -153,6 +154,7 @@ def test_from_filepaths_visualise(tmpdir): @pytest.mark.skipif(not _AUDIO_TESTING, reason="audio libraries aren't installed.") +@pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 6a80b5774a..87cb183504 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -168,7 +168,7 @@ def test_from_filepaths_visualise(tmpdir): dm.show_train_batch(["pre_tensor_transform", "post_tensor_transform"]) -@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +@pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed.") @pytest.mark.skipif(not _MATPLOTLIB_AVAILABLE, reason="matplotlib isn't installed.") def test_from_filepaths_visualise_multilabel(tmpdir): tmpdir = Path(tmpdir) From 97e200bd2f7adf92c195af0b75ab67106e658780 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 20:57:20 +0100 Subject: [PATCH 15/16] Updates --- docs/source/_templates/layout.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index d3312220d7..d050db39c5 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -4,7 +4,7 @@ {% block footer %} {{ super() }} {% endblock %} From a3297c9d29568d2c89da819296b4029357d1234b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 16 Jul 2021 21:05:04 +0100 Subject: [PATCH 16/16] Updates --- flash/audio/classification/transforms.py | 34 ++++++------------------ 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/flash/audio/classification/transforms.py b/flash/audio/classification/transforms.py index 6189b07f16..02a9ed2cbc 100644 --- a/flash/audio/classification/transforms.py +++ b/flash/audio/classification/transforms.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Callable, Dict, Tuple import torch @@ -19,10 +18,7 @@ from flash.core.data.data_source import DefaultDataKeys from flash.core.data.transforms import ApplyToKeys, kornia_collate, merge_transforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE - -if _KORNIA_AVAILABLE: - import kornia as K +from flash.core.utilities.imports import _TORCHAUDIO_AVAILABLE, _TORCHVISION_AVAILABLE if _TORCHVISION_AVAILABLE: import torchvision @@ -35,19 +31,6 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable]: """The default transforms for audio classification for spectrograms: resize the spectrogram, convert the spectrogram and target to a tensor, and collate the batch.""" - if _KORNIA_AVAILABLE and os.getenv("FLASH_TESTING", "0") != "1": - # Better approach as all transforms are applied on tensor directly - return { - "to_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, torchvision.transforms.ToTensor()), - ApplyToKeys(DefaultDataKeys.TARGET, torch.as_tensor), - ), - "post_tensor_transform": ApplyToKeys( - DefaultDataKeys.INPUT, - K.geometry.Resize(spectrogram_size), - ), - "collate": kornia_collate, - } return { "pre_tensor_transform": ApplyToKeys(DefaultDataKeys.INPUT, T.Resize(spectrogram_size)), "to_tensor_transform": nn.Sequential( @@ -60,13 +43,12 @@ def default_transforms(spectrogram_size: Tuple[int, int]) -> Dict[str, Callable] def train_default_transforms(spectrogram_size: Tuple[int, int], time_mask_param: int, freq_mask_param: int) -> Dict[str, Callable]: - """During training we apply the default transforms with aditional ``TimeMasking`` and ``Frequency Masking``""" - if os.getenv("FLASH_TESTING", "0") != 1: - transforms = { - "post_tensor_transform": nn.Sequential( - ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), - ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) - ) - } + """During training we apply the default transforms with additional ``TimeMasking`` and ``Frequency Masking``""" + transforms = { + "post_tensor_transform": nn.Sequential( + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.TimeMasking(time_mask_param=time_mask_param)), + ApplyToKeys(DefaultDataKeys.INPUT, TAudio.FrequencyMasking(freq_mask_param=freq_mask_param)) + ) + } return merge_transforms(default_transforms(spectrogram_size), transforms)